33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/LogicalResult.h"
38 #include "llvm/Support/MathExtras.h"
43 #define DEBUG_TYPE "mlir-spirv-conversion"
53 static std::optional<SmallVector<int64_t>>
getTargetShape(VectorType vecType) {
54 LLVM_DEBUG(llvm::dbgs() <<
"Get target shape\n");
55 if (vecType.isScalable()) {
56 LLVM_DEBUG(llvm::dbgs()
57 <<
"--scalable vectors are not supported -> BAIL\n");
64 LLVM_DEBUG(llvm::dbgs() <<
"--no unrolling target shape defined\n");
68 if (!maybeShapeRatio) {
69 LLVM_DEBUG(llvm::dbgs()
70 <<
"--could not compute integral shape ratio -> BAIL\n");
73 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) {
return v == 1; })) {
74 LLVM_DEBUG(llvm::dbgs() <<
"--no unrolling needed -> SKIP\n");
77 LLVM_DEBUG(llvm::dbgs()
78 <<
"--found an integral shape ratio to unroll to -> SUCCESS\n");
88 template <
typename LabelT>
89 static LogicalResult checkExtensionRequirements(
92 for (
const auto &ors : candidates) {
98 for (spirv::Extension ext : ors)
99 extStrings.push_back(spirv::stringifyExtension(ext));
101 llvm::dbgs() << label <<
" illegal: requires at least one extension in ["
102 << llvm::join(extStrings,
", ")
103 <<
"] but none allowed in target environment\n";
116 template <
typename LabelT>
117 static LogicalResult checkCapabilityRequirements(
120 for (
const auto &ors : candidates) {
121 if (targetEnv.
allows(ors))
126 for (spirv::Capability cap : ors)
127 capStrings.push_back(spirv::stringifyCapability(cap));
129 llvm::dbgs() << label <<
" illegal: requires at least one capability in ["
130 << llvm::join(capStrings,
", ")
131 <<
"] but none allowed in target environment\n";
140 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
141 switch (storageClass) {
142 case spirv::StorageClass::PhysicalStorageBuffer:
143 case spirv::StorageClass::PushConstant:
144 case spirv::StorageClass::StorageBuffer:
145 case spirv::StorageClass::Uniform:
155 wrapInStructAndGetPointer(
Type elementType, spirv::StorageClass storageClass) {
156 auto structType = needsExplicitLayout(storageClass)
168 return cast<spirv::ScalarType>(
174 static std::optional<int64_t>
176 if (isa<spirv::ScalarType>(type)) {
189 if (
auto complexType = dyn_cast<ComplexType>(type)) {
190 auto elementSize = getTypeNumBytes(
options, complexType.getElementType());
193 return 2 * *elementSize;
196 if (
auto vecType = dyn_cast<VectorType>(type)) {
197 auto elementSize = getTypeNumBytes(
options, vecType.getElementType());
200 return vecType.getNumElements() * *elementSize;
203 if (
auto memRefType = dyn_cast<MemRefType>(type)) {
208 if (!memRefType.hasStaticShape() ||
215 auto elementSize = getTypeNumBytes(
options, memRefType.getElementType());
219 if (memRefType.getRank() == 0)
222 auto dims = memRefType.getShape();
223 if (llvm::is_contained(dims, ShapedType::kDynamic) ||
224 ShapedType::isDynamic(offset) ||
225 llvm::is_contained(strides, ShapedType::kDynamic))
228 int64_t memrefSize = -1;
229 for (
const auto &shape :
enumerate(dims))
230 memrefSize =
std::max(memrefSize, shape.value() * strides[shape.index()]);
232 return (offset + memrefSize) * *elementSize;
235 if (
auto tensorType = dyn_cast<TensorType>(type)) {
236 if (!tensorType.hasStaticShape())
239 auto elementSize = getTypeNumBytes(
options, tensorType.getElementType());
243 int64_t size = *elementSize;
244 for (
auto shape : tensorType.getShape())
258 std::optional<spirv::StorageClass> storageClass = {}) {
262 type.getExtensions(extensions, storageClass);
263 type.getCapabilities(capabilities, storageClass);
266 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
267 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
272 if (!
options.emulateLT32BitScalarTypes)
277 LLVM_DEBUG(llvm::dbgs()
279 <<
" not converted to 32-bit for SPIR-V to avoid truncation\n");
283 if (
auto floatType = dyn_cast<FloatType>(type)) {
284 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
288 auto intType = cast<IntegerType>(type);
289 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
291 intType.getSignedness());
304 if (type.getWidth() > 8) {
305 LLVM_DEBUG(llvm::dbgs() <<
"not a subbyte type\n");
309 LLVM_DEBUG(llvm::dbgs() <<
"unsupported sub-byte storage kind\n");
313 if (!llvm::isPowerOf2_32(type.getWidth())) {
314 LLVM_DEBUG(llvm::dbgs()
315 <<
"unsupported non-power-of-two bitwidth in sub-byte" << type
320 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
322 type.getSignedness());
329 convertIndexElementType(ShapedType type,
331 Type indexType = dyn_cast<IndexType>(type.getElementType());
342 std::optional<spirv::StorageClass> storageClass = {}) {
343 type = cast<VectorType>(convertIndexElementType(type,
options));
344 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
348 auto intType = dyn_cast<IntegerType>(type.getElementType());
350 LLVM_DEBUG(llvm::dbgs()
352 <<
" illegal: cannot convert non-scalar element type\n");
356 Type elementType = convertSubByteIntegerType(
options, intType);
360 if (type.getRank() <= 1 && type.getNumElements() == 1)
363 if (type.getNumElements() > 4) {
364 LLVM_DEBUG(llvm::dbgs()
365 << type <<
" illegal: > 4-element unimplemented\n");
372 if (type.getRank() <= 1 && type.getNumElements() == 1)
373 return convertScalarType(targetEnv,
options, scalarType, storageClass);
376 LLVM_DEBUG(llvm::dbgs()
377 << type <<
" illegal: not a valid composite type\n");
384 cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
385 cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
388 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
389 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
393 convertScalarType(targetEnv,
options, scalarType, storageClass);
402 std::optional<spirv::StorageClass> storageClass = {}) {
403 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
405 LLVM_DEBUG(llvm::dbgs()
406 << type <<
" illegal: cannot convert non-scalar element type\n");
411 convertScalarType(targetEnv,
options, scalarType, storageClass);
414 if (elementType != type.getElementType()) {
415 LLVM_DEBUG(llvm::dbgs()
416 << type <<
" illegal: complex type emulation unsupported\n");
433 if (!type.hasStaticShape()) {
434 LLVM_DEBUG(llvm::dbgs()
435 << type <<
" illegal: dynamic shape unimplemented\n");
439 type = cast<TensorType>(convertIndexElementType(type,
options));
440 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.
getElementType());
442 LLVM_DEBUG(llvm::dbgs()
443 << type <<
" illegal: cannot convert non-scalar element type\n");
447 std::optional<int64_t> scalarSize = getTypeNumBytes(
options, scalarType);
448 std::optional<int64_t> tensorSize = getTypeNumBytes(
options, type);
449 if (!scalarSize || !tensorSize) {
450 LLVM_DEBUG(llvm::dbgs()
451 << type <<
" illegal: cannot deduce element count\n");
455 int64_t arrayElemCount = *tensorSize / *scalarSize;
456 if (arrayElemCount == 0) {
457 LLVM_DEBUG(llvm::dbgs()
458 << type <<
" illegal: cannot handle zero-element tensors\n");
462 Type arrayElemType = convertScalarType(targetEnv,
options, scalarType);
465 std::optional<int64_t> arrayElemSize =
466 getTypeNumBytes(
options, arrayElemType);
467 if (!arrayElemSize) {
468 LLVM_DEBUG(llvm::dbgs()
469 << type <<
" illegal: cannot deduce converted element size\n");
479 spirv::StorageClass storageClass) {
480 unsigned numBoolBits =
options.boolNumBits;
481 if (numBoolBits != 8) {
482 LLVM_DEBUG(llvm::dbgs()
483 <<
"using non-8-bit storage for bool types unimplemented");
486 auto elementType = dyn_cast<spirv::ScalarType>(
491 convertScalarType(targetEnv,
options, elementType, storageClass);
494 std::optional<int64_t> arrayElemSize =
495 getTypeNumBytes(
options, arrayElemType);
496 if (!arrayElemSize) {
497 LLVM_DEBUG(llvm::dbgs()
498 << type <<
" illegal: cannot deduce converted element size\n");
502 if (!type.hasStaticShape()) {
505 if (targetEnv.
allows(spirv::Capability::Kernel))
507 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
511 return wrapInStructAndGetPointer(arrayType, storageClass);
514 if (type.getNumElements() == 0) {
515 LLVM_DEBUG(llvm::dbgs()
516 << type <<
" illegal: zero-element memrefs are not supported\n");
520 int64_t memrefSize =
llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
522 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
524 if (targetEnv.
allows(spirv::Capability::Kernel))
526 return wrapInStructAndGetPointer(arrayType, storageClass);
532 spirv::StorageClass storageClass) {
533 IntegerType elementType = cast<IntegerType>(type.getElementType());
534 Type arrayElemType = convertSubByteIntegerType(
options, elementType);
537 int64_t arrayElemSize = *getTypeNumBytes(
options, arrayElemType);
539 if (!type.hasStaticShape()) {
542 if (targetEnv.
allows(spirv::Capability::Kernel))
544 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
548 return wrapInStructAndGetPointer(arrayType, storageClass);
551 if (type.getNumElements() == 0) {
552 LLVM_DEBUG(llvm::dbgs()
553 << type <<
" illegal: zero-element memrefs are not supported\n");
560 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
562 if (targetEnv.
allows(spirv::Capability::Kernel))
564 return wrapInStructAndGetPointer(arrayType, storageClass);
570 auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
575 <<
" illegal: expected memory space to be a SPIR-V storage class "
576 "attribute; please use MemorySpaceToStorageClassConverter to map "
577 "numeric memory spaces beforehand\n");
580 spirv::StorageClass storageClass = attr.getValue();
582 if (isa<IntegerType>(type.getElementType())) {
583 if (type.getElementTypeBitWidth() == 1)
584 return convertBoolMemrefType(targetEnv,
options, type, storageClass);
585 if (type.getElementTypeBitWidth() < 8)
586 return convertSubByteMemrefType(targetEnv,
options, type, storageClass);
590 Type elementType = type.getElementType();
591 if (
auto vecType = dyn_cast<VectorType>(elementType)) {
593 convertVectorType(targetEnv,
options, vecType, storageClass);
594 }
else if (
auto complexType = dyn_cast<ComplexType>(elementType)) {
596 convertComplexType(targetEnv,
options, complexType, storageClass);
597 }
else if (
auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
599 convertScalarType(targetEnv,
options, scalarType, storageClass);
600 }
else if (
auto indexType = dyn_cast<IndexType>(elementType)) {
601 type = cast<MemRefType>(convertIndexElementType(type,
options));
602 arrayElemType = type.getElementType();
607 <<
" unhandled: can only convert scalar or vector element type\n");
613 std::optional<int64_t> arrayElemSize =
614 getTypeNumBytes(
options, arrayElemType);
615 if (!arrayElemSize) {
616 LLVM_DEBUG(llvm::dbgs()
617 << type <<
" illegal: cannot deduce converted element size\n");
621 if (!type.hasStaticShape()) {
624 if (targetEnv.
allows(spirv::Capability::Kernel))
626 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
630 return wrapInStructAndGetPointer(arrayType, storageClass);
633 std::optional<int64_t> memrefSize = getTypeNumBytes(
options, type);
635 LLVM_DEBUG(llvm::dbgs()
636 << type <<
" illegal: cannot deduce element count\n");
640 if (*memrefSize == 0) {
641 LLVM_DEBUG(llvm::dbgs()
642 << type <<
" illegal: zero-element memrefs are not supported\n");
647 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
649 if (targetEnv.
allows(spirv::Capability::Kernel))
651 return wrapInStructAndGetPointer(arrayType, storageClass);
675 if (inputs.size() != 1) {
676 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
679 Value input = inputs.front();
682 if (!isa<IntegerType>(type)) {
683 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
686 auto inputType = cast<IntegerType>(input.
getType());
688 auto scalarType = dyn_cast<spirv::ScalarType>(type);
690 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
697 if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
698 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
704 Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
705 return builder.
create<spirv::IEqualOp>(loc, input, one);
711 scalarType.getExtensions(exts);
712 scalarType.getCapabilities(caps);
713 if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
714 failed(checkExtensionRequirements(type, targetEnv, exts))) {
715 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
723 return builder.
create<spirv::SConvertOp>(loc, type, input);
725 return builder.
create<spirv::UConvertOp>(loc, type, input);
732 static spirv::GlobalVariableOp getBuiltinVariable(
Block &body,
733 spirv::BuiltIn builtin) {
736 for (
auto varOp : body.
getOps<spirv::GlobalVariableOp>()) {
737 if (
auto builtinAttr = varOp->getAttrOfType<StringAttr>(
738 spirv::SPIRVDialect::getAttributeName(
739 spirv::Decoration::BuiltIn))) {
740 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
741 if (varBuiltIn && *varBuiltIn == builtin) {
750 std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
752 return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
756 static spirv::GlobalVariableOp
757 getOrInsertBuiltinVariable(
Block &body,
Location loc, spirv::BuiltIn builtin,
759 StringRef prefix, StringRef suffix) {
760 if (
auto varOp = getBuiltinVariable(body, builtin))
766 spirv::GlobalVariableOp newVarOp;
768 case spirv::BuiltIn::NumWorkgroups:
769 case spirv::BuiltIn::WorkgroupSize:
770 case spirv::BuiltIn::WorkgroupId:
771 case spirv::BuiltIn::LocalInvocationId:
772 case spirv::BuiltIn::GlobalInvocationId: {
774 spirv::StorageClass::Input);
775 std::string name = getBuiltinVarName(builtin, prefix, suffix);
777 builder.
create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
780 case spirv::BuiltIn::SubgroupId:
781 case spirv::BuiltIn::NumSubgroups:
782 case spirv::BuiltIn::SubgroupSize: {
785 std::string name = getBuiltinVarName(builtin, prefix, suffix);
787 builder.
create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
791 emitError(loc,
"unimplemented builtin variable generation for ")
792 << stringifyBuiltIn(builtin);
814 static spirv::GlobalVariableOp getPushConstantVariable(
Block &body,
815 unsigned elementCount) {
816 for (
auto varOp : body.
getOps<spirv::GlobalVariableOp>()) {
817 auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
824 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
825 auto numElements = cast<spirv::ArrayType>(
826 cast<spirv::StructType>(ptrType.getPointeeType())
829 if (numElements == elementCount)
838 static spirv::GlobalVariableOp
842 if (
auto varOp = getPushConstantVariable(block, elementCount))
846 auto type = getPushConstantStorageType(elementCount, builder, indexType);
847 const char *name =
"__push_constant_var__";
848 return builder.
create<spirv::GlobalVariableOp>(loc, type, name,
862 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
864 FunctionType fnType = funcOp.getFunctionType();
865 if (fnType.getNumResults() > 1)
869 fnType.getNumInputs());
870 for (
const auto &argType :
enumerate(fnType.getInputs())) {
871 auto convertedType = getTypeConverter()->convertType(argType.value());
874 signatureConverter.
addInputs(argType.index(), convertedType);
878 if (fnType.getNumResults() == 1) {
879 resultType = getTypeConverter()->convertType(fnType.getResult(0));
885 auto newFuncOp = rewriter.
create<spirv::FuncOp>(
886 funcOp.getLoc(), funcOp.getName(),
892 for (
const auto &namedAttr : funcOp->getAttrs()) {
893 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
895 newFuncOp->
setAttr(namedAttr.getName(), namedAttr.getValue());
901 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
913 LogicalResult matchAndRewrite(func::FuncOp funcOp,
915 FunctionType fnType = funcOp.getFunctionType();
918 if (funcOp.isDeclaration()) {
919 LLVM_DEBUG(llvm::dbgs()
920 << fnType <<
" illegal: declarations are unsupported\n");
925 auto newFuncOp = rewriter.
create<func::FuncOp>(funcOp.getLoc(),
926 funcOp.getName(), fnType);
930 Location loc = newFuncOp.getBody().getLoc();
932 Block &entryBlock = newFuncOp.getBlocks().
front();
942 size_t newInputNo = 0;
948 llvm::SmallDenseMap<Operation *, size_t> tmpOps;
951 size_t newOpCount = 0;
954 for (
auto [origInputNo, origType] :
enumerate(fnType.getInputs())) {
956 auto origVecType = dyn_cast<VectorType>(origType);
963 oneToNTypeMapping.
addInputs(origInputNo, origType);
976 oneToNTypeMapping.
addInputs(origInputNo, origType);
981 VectorType unrolledType =
984 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
988 loc, origVecType, rewriter.
getZeroAttr(origVecType));
992 loc, unrolledType, rewriter.
getZeroAttr(unrolledType));
1000 result = rewriter.
create<vector::InsertStridedSliceOp>(
1001 loc, dummy, result, offsets, strides);
1002 newTypes.push_back(unrolledType);
1003 unrolledInputNums.push_back(newInputNo);
1008 oneToNTypeMapping.
addInputs(origInputNo, newTypes);
1013 auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
1015 [&] { newFuncOp.setFunctionType(newFnType); });
1024 size_t unrolledInputIdx = 0;
1029 for (
auto [operandIdx, operandVal] :
llvm::enumerate(op.getOperands())) {
1030 Operation *operandOp = operandVal.getDefiningOp();
1031 if (
auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
1032 size_t idx = operandIdx;
1034 curOp.
setOperand(idx, newFuncOp.getArgument(it->second));
1041 if (count >= newOpCount)
1043 if (
auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1044 size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1046 curOp.
setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1068 LogicalResult matchAndRewrite(func::ReturnOp returnOp,
1071 auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
1075 FunctionType fnType = funcOp.getFunctionType();
1083 for (
auto [origResultNo, origType] :
enumerate(fnType.getResults())) {
1085 auto origVecType = dyn_cast<VectorType>(origType);
1087 oneToNTypeMapping.
addInputs(origResultNo, origType);
1088 newOperands.push_back(returnOp.getOperand(origResultNo));
1095 oneToNTypeMapping.
addInputs(origResultNo, origType);
1096 newOperands.push_back(returnOp.getOperand(origResultNo));
1099 VectorType unrolledType =
1104 auto originalShape =
1105 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1108 extractShape.back() = targetShape->back();
1110 Value returnValue = returnOp.getOperand(origResultNo);
1113 Value result = rewriter.
create<vector::ExtractStridedSliceOp>(
1114 loc, returnValue, offsets, extractShape, strides);
1115 if (originalShape.size() > 1) {
1118 rewriter.
create<vector::ExtractOp>(loc, result, extractIndices);
1120 newOperands.push_back(result);
1121 newTypes.push_back(unrolledType);
1123 oneToNTypeMapping.
addInputs(origResultNo, newTypes);
1131 [&] { funcOp.setFunctionType(newFnType); });
1136 rewriter.
create<func::ReturnOp>(loc, newOperands));
1149 spirv::BuiltIn builtin,
1151 StringRef prefix, StringRef suffix) {
1154 op->
emitError(
"expected operation to be within a module-like op");
1158 spirv::GlobalVariableOp varOp =
1160 builtin, integerType, builder, prefix, suffix);
1162 return builder.
create<spirv::LoadOp>(op->
getLoc(), ptr);
1170 unsigned offset,
Type integerType,
1175 op->
emitError(
"expected operation to be within a module-like op");
1179 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1180 loc, parent->
getRegion(0).
front(), elementCount, builder, integerType);
1183 Value offsetOp = builder.
create<spirv::ConstantOp>(
1185 auto addrOp = builder.
create<spirv::AddressOfOp>(loc, varOp);
1186 auto acOp = builder.
create<spirv::AccessChainOp>(
1188 return builder.
create<spirv::LoadOp>(loc, acOp);
1196 int64_t offset,
Type integerType,
1198 assert(indices.size() == strides.size() &&
1199 "must provide indices for all dimensions");
1213 builder.
createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1215 builder.
createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1217 return linearizedIndex;
1221 MemRefType baseType,
Value basePtr,
1229 llvm::is_contained(strides, ShapedType::kDynamic) ||
1230 ShapedType::isDynamic(offset)) {
1240 linearizedIndices.push_back(zero);
1242 if (baseType.getRank() == 0) {
1243 linearizedIndices.push_back(zero);
1245 linearizedIndices.push_back(
1246 linearizeIndex(indices, strides, offset, indexType, loc, builder));
1248 return builder.
create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
1252 MemRefType baseType,
Value basePtr,
1260 llvm::is_contained(strides, ShapedType::kDynamic) ||
1261 ShapedType::isDynamic(offset)) {
1269 if (baseType.getRank() == 0) {
1273 linearizeIndex(indices, strides, offset, indexType, loc, builder);
1276 cast<spirv::PointerType>(basePtr.
getType()).getPointeeType();
1277 if (isa<spirv::ArrayType>(pointeeType)) {
1278 linearizedIndices.push_back(linearIndex);
1279 return builder.
create<spirv::AccessChainOp>(loc, basePtr,
1282 return builder.
create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
1287 MemRefType baseType,
Value basePtr,
1291 if (typeConverter.
allows(spirv::Capability::Kernel)) {
1305 for (
int i : {4, 3, 2}) {
1314 VectorType srcVectorType = op.getSourceVectorType();
1315 assert(srcVectorType.getRank() == 1);
1316 int64_t vectorSize =
1318 return {vectorSize};
1323 VectorType vectorType = op.getResultVectorType();
1330 std::optional<SmallVector<int64_t>>
1333 if (
auto vecType = dyn_cast<VectorType>(op->
getResultTypes()[0])) {
1342 .Case<vector::ReductionOp, vector::TransposeOp>(
1344 .Default([](
Operation *) {
return std::nullopt; });
1378 vector::VectorTransposeLowering::EltWise);
1391 vector::ReductionOp::getCanonicalizationPatterns(
patterns, context);
1392 vector::TransposeOp::getCanonicalizationPatterns(
patterns, context);
1398 vector::InsertOp::getCanonicalizationPatterns(
patterns, context);
1399 vector::ExtractOp::getCanonicalizationPatterns(
patterns, context);
1403 vector::BroadcastOp::getCanonicalizationPatterns(
patterns, context);
1404 vector::ShapeCastOp::getCanonicalizationPatterns(
patterns, context);
1433 addConversion([
this](IntegerType intType) -> std::optional<Type> {
1434 if (
auto scalarType = dyn_cast<spirv::ScalarType>(intType))
1435 return convertScalarType(this->targetEnv, this->options, scalarType);
1436 if (intType.getWidth() < 8)
1437 return convertSubByteIntegerType(this->options, intType);
1442 if (
auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
1443 return convertScalarType(this->targetEnv, this->options, scalarType);
1448 return convertComplexType(this->targetEnv, this->options, complexType);
1452 return convertVectorType(this->targetEnv, this->options, vectorType);
1456 return convertTensorType(this->targetEnv, this->options, tensorType);
1460 return convertMemrefType(this->targetEnv, this->options, memRefType);
1466 return castToSourceType(this->targetEnv, builder, type, inputs, loc);
1470 auto cast = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
1476 return ::getIndexType(getContext(), options);
1479 MLIRContext *SPIRVTypeConverter::getContext()
const {
1480 return targetEnv.
getAttr().getContext();
1484 return targetEnv.
allows(capability);
1491 std::unique_ptr<SPIRVConversionTarget>
1493 std::unique_ptr<SPIRVConversionTarget> target(
1497 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1500 [targetPtr](
Operation *op) {
return targetPtr->isLegalOp(op); });
1507 bool SPIRVConversionTarget::isLegalOp(
Operation *op) {
1511 if (
auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1512 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1513 if (minVersion && *minVersion > this->targetEnv.
getVersion()) {
1514 LLVM_DEBUG(llvm::dbgs()
1515 << op->
getName() <<
" illegal: requiring min version "
1516 << spirv::stringifyVersion(*minVersion) <<
"\n");
1520 if (
auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1521 std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1522 if (maxVersion && *maxVersion < this->targetEnv.
getVersion()) {
1523 LLVM_DEBUG(llvm::dbgs()
1524 << op->
getName() <<
" illegal: requiring max version "
1525 << spirv::stringifyVersion(*maxVersion) <<
"\n");
1533 if (
auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1534 if (failed(checkExtensionRequirements(op->
getName(), this->targetEnv,
1535 extensions.getExtensions())))
1541 if (
auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1542 if (failed(checkCapabilityRequirements(op->
getName(), this->targetEnv,
1543 capabilities.getCapabilities())))
1551 if (llvm::any_of(valueTypes,
1552 [](
Type t) {
return !isa<spirv::SPIRVType>(t); }))
1557 if (
auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1558 valueTypes.push_back(globalVar.getType());
1564 for (
Type valueType : valueTypes) {
1565 typeExtensions.clear();
1566 cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1567 if (failed(checkExtensionRequirements(op->
getName(), this->targetEnv,
1571 typeCapabilities.clear();
1572 cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
1573 if (failed(checkCapabilityRequirements(op->
getName(), this->targetEnv,
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
Block represents an ordered list of Operations.
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
OpListType & getOperations()
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Stores a 1:N mapping of types and provides several useful accessors.
TypeRange getConvertedTypes(unsigned originalTypeNo) const
Returns the list of types that corresponds to the original type at the given index.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
operand_type_iterator operand_type_end()
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
result_type_iterator result_type_end()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_type_iterator result_type_begin()
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
operand_type_iterator operand_type_begin()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a replacement value back ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ArrayType get(Type elementType, unsigned elementCount)
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
MLIRContext * getContext() const
Returns the MLIRContext.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
OpFoldResult linearizeIndex(ArrayRef< OpFoldResult > multiIndex, ArrayRef< OpFoldResult > basis, ImplicitLocOpBuilder &builder)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
LogicalResult unrollVectorsInFuncBodies(Operation *op)
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)
int getComputeVectorSize(int64_t size)
LogicalResult unrollVectorsInSignatures(Operation *op)
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of pattern to unroll vector operations to a smaller shapes.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of leading one dimension removal patterns.
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
@ ExistingOps
Only pre-existing ops are processed.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)