32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/LogicalResult.h"
37 #include "llvm/Support/MathExtras.h"
42 #define DEBUG_TYPE "mlir-spirv-conversion"
52 static std::optional<SmallVector<int64_t>>
getTargetShape(VectorType vecType) {
53 LLVM_DEBUG(llvm::dbgs() <<
"Get target shape\n");
54 if (vecType.isScalable()) {
55 LLVM_DEBUG(llvm::dbgs()
56 <<
"--scalable vectors are not supported -> BAIL\n");
63 LLVM_DEBUG(llvm::dbgs() <<
"--no unrolling target shape defined\n");
67 if (!maybeShapeRatio) {
68 LLVM_DEBUG(llvm::dbgs()
69 <<
"--could not compute integral shape ratio -> BAIL\n");
72 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) {
return v == 1; })) {
73 LLVM_DEBUG(llvm::dbgs() <<
"--no unrolling needed -> SKIP\n");
76 LLVM_DEBUG(llvm::dbgs()
77 <<
"--found an integral shape ratio to unroll to -> SUCCESS\n");
87 template <
typename LabelT>
88 static LogicalResult checkExtensionRequirements(
91 for (
const auto &ors : candidates) {
97 for (spirv::Extension ext : ors)
98 extStrings.push_back(spirv::stringifyExtension(ext));
100 llvm::dbgs() << label <<
" illegal: requires at least one extension in ["
101 << llvm::join(extStrings,
", ")
102 <<
"] but none allowed in target environment\n";
115 template <
typename LabelT>
116 static LogicalResult checkCapabilityRequirements(
119 for (
const auto &ors : candidates) {
120 if (targetEnv.
allows(ors))
125 for (spirv::Capability cap : ors)
126 capStrings.push_back(spirv::stringifyCapability(cap));
128 llvm::dbgs() << label <<
" illegal: requires at least one capability in ["
129 << llvm::join(capStrings,
", ")
130 <<
"] but none allowed in target environment\n";
139 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
140 switch (storageClass) {
141 case spirv::StorageClass::PhysicalStorageBuffer:
142 case spirv::StorageClass::PushConstant:
143 case spirv::StorageClass::StorageBuffer:
144 case spirv::StorageClass::Uniform:
154 wrapInStructAndGetPointer(
Type elementType, spirv::StorageClass storageClass) {
155 auto structType = needsExplicitLayout(storageClass)
167 return cast<spirv::ScalarType>(
173 static std::optional<int64_t>
175 if (isa<spirv::ScalarType>(type)) {
188 if (
auto complexType = dyn_cast<ComplexType>(type)) {
189 auto elementSize = getTypeNumBytes(
options, complexType.getElementType());
192 return 2 * *elementSize;
195 if (
auto vecType = dyn_cast<VectorType>(type)) {
196 auto elementSize = getTypeNumBytes(
options, vecType.getElementType());
199 return vecType.getNumElements() * *elementSize;
202 if (
auto memRefType = dyn_cast<MemRefType>(type)) {
207 if (!memRefType.hasStaticShape() ||
208 failed(memRefType.getStridesAndOffset(strides, offset)))
214 auto elementSize = getTypeNumBytes(
options, memRefType.getElementType());
218 if (memRefType.getRank() == 0)
221 auto dims = memRefType.getShape();
222 if (llvm::is_contained(dims, ShapedType::kDynamic) ||
223 ShapedType::isDynamic(offset) ||
224 llvm::is_contained(strides, ShapedType::kDynamic))
227 int64_t memrefSize = -1;
228 for (
const auto &shape :
enumerate(dims))
229 memrefSize =
std::max(memrefSize, shape.value() * strides[shape.index()]);
231 return (offset + memrefSize) * *elementSize;
234 if (
auto tensorType = dyn_cast<TensorType>(type)) {
235 if (!tensorType.hasStaticShape())
238 auto elementSize = getTypeNumBytes(
options, tensorType.getElementType());
242 int64_t size = *elementSize;
243 for (
auto shape : tensorType.getShape())
257 std::optional<spirv::StorageClass> storageClass = {}) {
261 type.getExtensions(extensions, storageClass);
262 type.getCapabilities(capabilities, storageClass);
265 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
266 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
271 if (!
options.emulateLT32BitScalarTypes)
276 LLVM_DEBUG(llvm::dbgs()
278 <<
" not converted to 32-bit for SPIR-V to avoid truncation\n");
282 if (
auto floatType = dyn_cast<FloatType>(type)) {
283 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
287 auto intType = cast<IntegerType>(type);
288 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
290 intType.getSignedness());
303 if (type.getWidth() > 8) {
304 LLVM_DEBUG(llvm::dbgs() <<
"not a subbyte type\n");
308 LLVM_DEBUG(llvm::dbgs() <<
"unsupported sub-byte storage kind\n");
312 if (!llvm::isPowerOf2_32(type.getWidth())) {
313 LLVM_DEBUG(llvm::dbgs()
314 <<
"unsupported non-power-of-two bitwidth in sub-byte" << type
319 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
321 type.getSignedness());
328 convertIndexElementType(ShapedType type,
330 Type indexType = dyn_cast<IndexType>(type.getElementType());
341 std::optional<spirv::StorageClass> storageClass = {}) {
342 type = cast<VectorType>(convertIndexElementType(type,
options));
343 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
347 auto intType = dyn_cast<IntegerType>(type.getElementType());
349 LLVM_DEBUG(llvm::dbgs()
351 <<
" illegal: cannot convert non-scalar element type\n");
355 Type elementType = convertSubByteIntegerType(
options, intType);
359 if (type.getRank() <= 1 && type.getNumElements() == 1)
362 if (type.getNumElements() > 4) {
363 LLVM_DEBUG(llvm::dbgs()
364 << type <<
" illegal: > 4-element unimplemented\n");
371 if (type.getRank() <= 1 && type.getNumElements() == 1)
372 return convertScalarType(targetEnv,
options, scalarType, storageClass);
375 LLVM_DEBUG(llvm::dbgs()
376 << type <<
" illegal: not a valid composite type\n");
383 cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
384 cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
387 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
388 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
392 convertScalarType(targetEnv,
options, scalarType, storageClass);
401 std::optional<spirv::StorageClass> storageClass = {}) {
402 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
404 LLVM_DEBUG(llvm::dbgs()
405 << type <<
" illegal: cannot convert non-scalar element type\n");
410 convertScalarType(targetEnv,
options, scalarType, storageClass);
413 if (elementType != type.getElementType()) {
414 LLVM_DEBUG(llvm::dbgs()
415 << type <<
" illegal: complex type emulation unsupported\n");
432 if (!type.hasStaticShape()) {
433 LLVM_DEBUG(llvm::dbgs()
434 << type <<
" illegal: dynamic shape unimplemented\n");
438 type = cast<TensorType>(convertIndexElementType(type,
options));
439 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.
getElementType());
441 LLVM_DEBUG(llvm::dbgs()
442 << type <<
" illegal: cannot convert non-scalar element type\n");
446 std::optional<int64_t> scalarSize = getTypeNumBytes(
options, scalarType);
447 std::optional<int64_t> tensorSize = getTypeNumBytes(
options, type);
448 if (!scalarSize || !tensorSize) {
449 LLVM_DEBUG(llvm::dbgs()
450 << type <<
" illegal: cannot deduce element count\n");
454 int64_t arrayElemCount = *tensorSize / *scalarSize;
455 if (arrayElemCount == 0) {
456 LLVM_DEBUG(llvm::dbgs()
457 << type <<
" illegal: cannot handle zero-element tensors\n");
461 Type arrayElemType = convertScalarType(targetEnv,
options, scalarType);
464 std::optional<int64_t> arrayElemSize =
465 getTypeNumBytes(
options, arrayElemType);
466 if (!arrayElemSize) {
467 LLVM_DEBUG(llvm::dbgs()
468 << type <<
" illegal: cannot deduce converted element size\n");
478 spirv::StorageClass storageClass) {
479 unsigned numBoolBits =
options.boolNumBits;
480 if (numBoolBits != 8) {
481 LLVM_DEBUG(llvm::dbgs()
482 <<
"using non-8-bit storage for bool types unimplemented");
485 auto elementType = dyn_cast<spirv::ScalarType>(
490 convertScalarType(targetEnv,
options, elementType, storageClass);
493 std::optional<int64_t> arrayElemSize =
494 getTypeNumBytes(
options, arrayElemType);
495 if (!arrayElemSize) {
496 LLVM_DEBUG(llvm::dbgs()
497 << type <<
" illegal: cannot deduce converted element size\n");
501 if (!type.hasStaticShape()) {
504 if (targetEnv.
allows(spirv::Capability::Kernel))
506 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
510 return wrapInStructAndGetPointer(arrayType, storageClass);
513 if (type.getNumElements() == 0) {
514 LLVM_DEBUG(llvm::dbgs()
515 << type <<
" illegal: zero-element memrefs are not supported\n");
519 int64_t memrefSize =
llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
521 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
523 if (targetEnv.
allows(spirv::Capability::Kernel))
525 return wrapInStructAndGetPointer(arrayType, storageClass);
531 spirv::StorageClass storageClass) {
532 IntegerType elementType = cast<IntegerType>(type.getElementType());
533 Type arrayElemType = convertSubByteIntegerType(
options, elementType);
536 int64_t arrayElemSize = *getTypeNumBytes(
options, arrayElemType);
538 if (!type.hasStaticShape()) {
541 if (targetEnv.
allows(spirv::Capability::Kernel))
543 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
547 return wrapInStructAndGetPointer(arrayType, storageClass);
550 if (type.getNumElements() == 0) {
551 LLVM_DEBUG(llvm::dbgs()
552 << type <<
" illegal: zero-element memrefs are not supported\n");
559 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
561 if (targetEnv.
allows(spirv::Capability::Kernel))
563 return wrapInStructAndGetPointer(arrayType, storageClass);
569 auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
574 <<
" illegal: expected memory space to be a SPIR-V storage class "
575 "attribute; please use MemorySpaceToStorageClassConverter to map "
576 "numeric memory spaces beforehand\n");
579 spirv::StorageClass storageClass = attr.getValue();
581 if (isa<IntegerType>(type.getElementType())) {
582 if (type.getElementTypeBitWidth() == 1)
583 return convertBoolMemrefType(targetEnv,
options, type, storageClass);
584 if (type.getElementTypeBitWidth() < 8)
585 return convertSubByteMemrefType(targetEnv,
options, type, storageClass);
589 Type elementType = type.getElementType();
590 if (
auto vecType = dyn_cast<VectorType>(elementType)) {
592 convertVectorType(targetEnv,
options, vecType, storageClass);
593 }
else if (
auto complexType = dyn_cast<ComplexType>(elementType)) {
595 convertComplexType(targetEnv,
options, complexType, storageClass);
596 }
else if (
auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
598 convertScalarType(targetEnv,
options, scalarType, storageClass);
599 }
else if (
auto indexType = dyn_cast<IndexType>(elementType)) {
600 type = cast<MemRefType>(convertIndexElementType(type,
options));
601 arrayElemType = type.getElementType();
606 <<
" unhandled: can only convert scalar or vector element type\n");
612 std::optional<int64_t> arrayElemSize =
613 getTypeNumBytes(
options, arrayElemType);
614 if (!arrayElemSize) {
615 LLVM_DEBUG(llvm::dbgs()
616 << type <<
" illegal: cannot deduce converted element size\n");
620 if (!type.hasStaticShape()) {
623 if (targetEnv.
allows(spirv::Capability::Kernel))
625 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
629 return wrapInStructAndGetPointer(arrayType, storageClass);
632 std::optional<int64_t> memrefSize = getTypeNumBytes(
options, type);
634 LLVM_DEBUG(llvm::dbgs()
635 << type <<
" illegal: cannot deduce element count\n");
639 if (*memrefSize == 0) {
640 LLVM_DEBUG(llvm::dbgs()
641 << type <<
" illegal: zero-element memrefs are not supported\n");
646 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
648 if (targetEnv.
allows(spirv::Capability::Kernel))
650 return wrapInStructAndGetPointer(arrayType, storageClass);
674 if (inputs.size() != 1) {
675 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
678 Value input = inputs.front();
681 if (!isa<IntegerType>(type)) {
682 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
685 auto inputType = cast<IntegerType>(input.
getType());
687 auto scalarType = dyn_cast<spirv::ScalarType>(type);
689 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
696 if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
697 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
703 Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
704 return builder.
create<spirv::IEqualOp>(loc, input, one);
710 scalarType.getExtensions(exts);
711 scalarType.getCapabilities(caps);
712 if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
713 failed(checkExtensionRequirements(type, targetEnv, exts))) {
714 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
722 return builder.
create<spirv::SConvertOp>(loc, type, input);
724 return builder.
create<spirv::UConvertOp>(loc, type, input);
731 static spirv::GlobalVariableOp getBuiltinVariable(
Block &body,
732 spirv::BuiltIn builtin) {
735 for (
auto varOp : body.
getOps<spirv::GlobalVariableOp>()) {
736 if (
auto builtinAttr = varOp->getAttrOfType<StringAttr>(
737 spirv::SPIRVDialect::getAttributeName(
738 spirv::Decoration::BuiltIn))) {
739 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
740 if (varBuiltIn == builtin) {
749 std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
751 return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
755 static spirv::GlobalVariableOp
756 getOrInsertBuiltinVariable(
Block &body,
Location loc, spirv::BuiltIn builtin,
758 StringRef prefix, StringRef suffix) {
759 if (
auto varOp = getBuiltinVariable(body, builtin))
765 spirv::GlobalVariableOp newVarOp;
767 case spirv::BuiltIn::NumWorkgroups:
768 case spirv::BuiltIn::WorkgroupSize:
769 case spirv::BuiltIn::WorkgroupId:
770 case spirv::BuiltIn::LocalInvocationId:
771 case spirv::BuiltIn::GlobalInvocationId: {
773 spirv::StorageClass::Input);
774 std::string name = getBuiltinVarName(builtin, prefix, suffix);
776 builder.
create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
779 case spirv::BuiltIn::SubgroupId:
780 case spirv::BuiltIn::NumSubgroups:
781 case spirv::BuiltIn::SubgroupSize:
782 case spirv::BuiltIn::SubgroupLocalInvocationId: {
785 std::string name = getBuiltinVarName(builtin, prefix, suffix);
787 builder.
create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
791 emitError(loc,
"unimplemented builtin variable generation for ")
792 << stringifyBuiltIn(builtin);
814 static spirv::GlobalVariableOp getPushConstantVariable(
Block &body,
815 unsigned elementCount) {
816 for (
auto varOp : body.
getOps<spirv::GlobalVariableOp>()) {
817 auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
824 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
825 auto numElements = cast<spirv::ArrayType>(
826 cast<spirv::StructType>(ptrType.getPointeeType())
829 if (numElements == elementCount)
838 static spirv::GlobalVariableOp
842 if (
auto varOp = getPushConstantVariable(block, elementCount))
846 auto type = getPushConstantStorageType(elementCount, builder, indexType);
847 const char *name =
"__push_constant_var__";
848 return builder.
create<spirv::GlobalVariableOp>(loc, type, name,
862 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
864 FunctionType fnType = funcOp.getFunctionType();
865 if (fnType.getNumResults() > 1)
869 fnType.getNumInputs());
870 for (
const auto &argType :
enumerate(fnType.getInputs())) {
871 auto convertedType = getTypeConverter()->convertType(argType.value());
874 signatureConverter.
addInputs(argType.index(), convertedType);
878 if (fnType.getNumResults() == 1) {
879 resultType = getTypeConverter()->convertType(fnType.getResult(0));
885 auto newFuncOp = rewriter.
create<spirv::FuncOp>(
886 funcOp.getLoc(), funcOp.getName(),
892 for (
const auto &namedAttr : funcOp->getAttrs()) {
893 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
895 newFuncOp->
setAttr(namedAttr.getName(), namedAttr.getValue());
901 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
913 LogicalResult matchAndRewrite(func::FuncOp funcOp,
915 FunctionType fnType = funcOp.getFunctionType();
918 if (funcOp.isDeclaration()) {
919 LLVM_DEBUG(llvm::dbgs()
920 << fnType <<
" illegal: declarations are unsupported\n");
925 auto newFuncOp = rewriter.
create<func::FuncOp>(funcOp.getLoc(),
926 funcOp.getName(), fnType);
930 Location loc = newFuncOp.getBody().getLoc();
932 Block &entryBlock = newFuncOp.getBlocks().
front();
937 fnType.getInputs().size());
943 size_t newInputNo = 0;
949 llvm::SmallDenseMap<Operation *, size_t> tmpOps;
952 size_t newOpCount = 0;
955 for (
auto [origInputNo, origType] :
enumerate(fnType.getInputs())) {
957 auto origVecType = dyn_cast<VectorType>(origType);
964 oneToNTypeMapping.
addInputs(origInputNo, origType);
977 oneToNTypeMapping.
addInputs(origInputNo, origType);
982 VectorType unrolledType =
985 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
989 loc, origVecType, rewriter.
getZeroAttr(origVecType));
993 loc, unrolledType, rewriter.
getZeroAttr(unrolledType));
1001 result = rewriter.
create<vector::InsertStridedSliceOp>(
1002 loc, dummy, result, offsets, strides);
1003 newTypes.push_back(unrolledType);
1004 unrolledInputNums.push_back(newInputNo);
1009 oneToNTypeMapping.
addInputs(origInputNo, newTypes);
1014 auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
1016 [&] { newFuncOp.setFunctionType(newFnType); });
1025 for (
auto &[placeholderOp, argIdx] : tmpOps) {
1028 Value replacement = newFuncOp.getArgument(argIdx);
1036 size_t unrolledInputIdx = 0;
1042 if (count >= newOpCount)
1044 if (
auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1045 size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1047 curOp.
setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1069 LogicalResult matchAndRewrite(func::ReturnOp returnOp,
1072 auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
1076 FunctionType fnType = funcOp.getFunctionType();
1078 fnType.getResults().size());
1085 for (
auto [origResultNo, origType] :
enumerate(fnType.getResults())) {
1087 auto origVecType = dyn_cast<VectorType>(origType);
1089 oneToNTypeMapping.
addInputs(origResultNo, origType);
1090 newOperands.push_back(returnOp.getOperand(origResultNo));
1097 oneToNTypeMapping.
addInputs(origResultNo, origType);
1098 newOperands.push_back(returnOp.getOperand(origResultNo));
1101 VectorType unrolledType =
1106 auto originalShape =
1107 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1110 extractShape.back() = targetShape->back();
1112 Value returnValue = returnOp.getOperand(origResultNo);
1115 Value result = rewriter.
create<vector::ExtractStridedSliceOp>(
1116 loc, returnValue, offsets, extractShape, strides);
1117 if (originalShape.size() > 1) {
1120 rewriter.
create<vector::ExtractOp>(loc, result, extractIndices);
1122 newOperands.push_back(result);
1123 newTypes.push_back(unrolledType);
1125 oneToNTypeMapping.
addInputs(origResultNo, newTypes);
1133 [&] { funcOp.setFunctionType(newFnType); });
1138 rewriter.
create<func::ReturnOp>(loc, newOperands));
1151 spirv::BuiltIn builtin,
1153 StringRef prefix, StringRef suffix) {
1156 op->
emitError(
"expected operation to be within a module-like op");
1160 spirv::GlobalVariableOp varOp =
1162 builtin, integerType, builder, prefix, suffix);
1164 return builder.
create<spirv::LoadOp>(op->
getLoc(), ptr);
1172 unsigned offset,
Type integerType,
1177 op->
emitError(
"expected operation to be within a module-like op");
1181 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1182 loc, parent->
getRegion(0).
front(), elementCount, builder, integerType);
1185 Value offsetOp = builder.
create<spirv::ConstantOp>(
1187 auto addrOp = builder.
create<spirv::AddressOfOp>(loc, varOp);
1188 auto acOp = builder.
create<spirv::AccessChainOp>(
1190 return builder.
create<spirv::LoadOp>(loc, acOp);
1198 int64_t offset,
Type integerType,
1200 assert(indices.size() == strides.size() &&
1201 "must provide indices for all dimensions");
1215 builder.
createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1217 builder.
createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1219 return linearizedIndex;
1223 MemRefType baseType,
Value basePtr,
1230 if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1231 llvm::is_contained(strides, ShapedType::kDynamic) ||
1232 ShapedType::isDynamic(offset)) {
1242 linearizedIndices.push_back(zero);
1244 if (baseType.getRank() == 0) {
1245 linearizedIndices.push_back(zero);
1247 linearizedIndices.push_back(
1248 linearizeIndex(indices, strides, offset, indexType, loc, builder));
1250 return builder.
create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
1254 MemRefType baseType,
Value basePtr,
1261 if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1262 llvm::is_contained(strides, ShapedType::kDynamic) ||
1263 ShapedType::isDynamic(offset)) {
1271 if (baseType.getRank() == 0) {
1275 linearizeIndex(indices, strides, offset, indexType, loc, builder);
1278 cast<spirv::PointerType>(basePtr.
getType()).getPointeeType();
1279 if (isa<spirv::ArrayType>(pointeeType)) {
1280 linearizedIndices.push_back(linearIndex);
1281 return builder.
create<spirv::AccessChainOp>(loc, basePtr,
1284 return builder.
create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
1289 MemRefType baseType,
Value basePtr,
1293 if (typeConverter.
allows(spirv::Capability::Kernel)) {
1307 for (
int i : {4, 3, 2}) {
1316 VectorType srcVectorType = op.getSourceVectorType();
1317 assert(srcVectorType.getRank() == 1);
1318 int64_t vectorSize =
1320 return {vectorSize};
1325 VectorType vectorType = op.getResultVectorType();
1332 std::optional<SmallVector<int64_t>>
1335 if (
auto vecType = dyn_cast<VectorType>(op->
getResultTypes()[0])) {
1344 .Case<vector::ReductionOp, vector::TransposeOp>(
1346 .Default([](
Operation *) {
return std::nullopt; });
1380 patterns, vector::VectorTransposeLowering::EltWise);
1391 vector::populateCastAwayVectorLeadingOneDimPatterns(
patterns);
1392 vector::ReductionOp::getCanonicalizationPatterns(
patterns, context);
1393 vector::TransposeOp::getCanonicalizationPatterns(
patterns, context);
1397 vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
1399 vector::InsertOp::getCanonicalizationPatterns(
patterns, context);
1400 vector::ExtractOp::getCanonicalizationPatterns(
patterns, context);
1404 vector::BroadcastOp::getCanonicalizationPatterns(
patterns, context);
1405 vector::ShapeCastOp::getCanonicalizationPatterns(
patterns, context);
1434 addConversion([
this](IntegerType intType) -> std::optional<Type> {
1435 if (
auto scalarType = dyn_cast<spirv::ScalarType>(intType))
1436 return convertScalarType(this->targetEnv, this->options, scalarType);
1437 if (intType.getWidth() < 8)
1438 return convertSubByteIntegerType(this->options, intType);
1442 addConversion([
this](FloatType floatType) -> std::optional<Type> {
1443 if (
auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
1444 return convertScalarType(this->targetEnv, this->options, scalarType);
1449 return convertComplexType(this->targetEnv, this->options, complexType);
1453 return convertVectorType(this->targetEnv, this->options, vectorType);
1457 return convertTensorType(this->targetEnv, this->options, tensorType);
1461 return convertMemrefType(this->targetEnv, this->options, memRefType);
1467 return castToSourceType(this->targetEnv, builder, type, inputs, loc);
1471 auto cast = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
1477 return ::getIndexType(getContext(), options);
1480 MLIRContext *SPIRVTypeConverter::getContext()
const {
1481 return targetEnv.
getAttr().getContext();
1485 return targetEnv.
allows(capability);
1492 std::unique_ptr<SPIRVConversionTarget>
1494 std::unique_ptr<SPIRVConversionTarget> target(
1498 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1501 [targetPtr](
Operation *op) {
return targetPtr->isLegalOp(op); });
1508 bool SPIRVConversionTarget::isLegalOp(
Operation *op) {
1512 if (
auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1513 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1514 if (minVersion && *minVersion > this->targetEnv.
getVersion()) {
1515 LLVM_DEBUG(llvm::dbgs()
1516 << op->
getName() <<
" illegal: requiring min version "
1517 << spirv::stringifyVersion(*minVersion) <<
"\n");
1521 if (
auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1522 std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1523 if (maxVersion && *maxVersion < this->targetEnv.
getVersion()) {
1524 LLVM_DEBUG(llvm::dbgs()
1525 << op->
getName() <<
" illegal: requiring max version "
1526 << spirv::stringifyVersion(*maxVersion) <<
"\n");
1534 if (
auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1535 if (failed(checkExtensionRequirements(op->
getName(), this->targetEnv,
1536 extensions.getExtensions())))
1542 if (
auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1543 if (failed(checkCapabilityRequirements(op->
getName(), this->targetEnv,
1544 capabilities.getCapabilities())))
1552 if (llvm::any_of(valueTypes,
1553 [](
Type t) {
return !isa<spirv::SPIRVType>(t); }))
1558 if (
auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1559 valueTypes.push_back(globalVar.getType());
1565 for (
Type valueType : valueTypes) {
1566 typeExtensions.clear();
1567 cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1568 if (failed(checkExtensionRequirements(op->
getName(), this->targetEnv,
1572 typeCapabilities.clear();
1573 cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
1574 if (failed(checkCapabilityRequirements(op->
getName(), this->targetEnv,
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
Block represents an ordered list of Operations.
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
OpListType & getOperations()
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
operand_type_iterator operand_type_end()
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
result_type_iterator result_type_end()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_type_iterator result_type_begin()
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
operand_type_iterator operand_type_begin()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ArrayType get(Type elementType, unsigned elementCount)
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
MLIRContext * getContext() const
Returns the MLIRContext.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
OpFoldResult linearizeIndex(ArrayRef< OpFoldResult > multiIndex, ArrayRef< OpFoldResult > basis, ImplicitLocOpBuilder &builder)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
LogicalResult unrollVectorsInFuncBodies(Operation *op)
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)
int getComputeVectorSize(int64_t size)
LogicalResult unrollVectorsInSignatures(Operation *op)
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
@ ExistingOps
Only pre-existing ops are processed.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)