21#include "llvm/ADT/STLExtras.h"
22#include "llvm/ADT/Sequence.h"
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/ADT/bit.h"
26#include "llvm/Support/Debug.h"
30#define DEBUG_TYPE "spirv-serialization"
37 if (
auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38 return selectionOp.getMergeBlock();
39 if (
auto loopOp = dyn_cast<spirv::LoopOp>(op))
40 return loopOp.getMergeBlock();
51 if (
auto loopOp = dyn_cast<spirv::LoopOp>(block->
getParentOp())) {
55 while ((op = op->getPrevNode()) !=
nullptr)
74 if (
auto floatAttr = dyn_cast<FloatAttr>(attr)) {
75 return floatAttr.getValue().isZero();
77 if (
auto boolAttr = dyn_cast<BoolAttr>(attr)) {
78 return !boolAttr.getValue();
80 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
81 return intAttr.getValue().isZero();
83 if (
auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) {
86 if (
auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
102 for (
Operation &op : llvm::drop_begin(ops))
103 if (
auto funcOp = dyn_cast<spirv::FuncOp>(op))
104 if (funcOp.getBody().empty())
115 uint32_t wordCount = 1 + operands.size();
117 binary.append(operands.begin(), operands.end());
122 : module(module), mlirBuilder(module.
getContext()), options(options) {}
125 LLVM_DEBUG(llvm::dbgs() <<
"+++ starting serialization +++\n");
127 if (failed(module.verifyInvariants()))
132 if (failed(processExtension())) {
135 processMemoryModel();
142 for (
auto &op : *module.getBody()) {
143 if (failed(processOperation(&op))) {
148 LLVM_DEBUG(llvm::dbgs() <<
"+++ completed serialization +++\n");
154 extensions.size() + extendedSets.size() +
155 memoryModel.size() + entryPoints.size() +
156 executionModes.size() + decorations.size() +
157 typesGlobalValues.size() + functions.size() + graphs.size();
160 binary.reserve(moduleSize);
164 binary.append(capabilities.begin(), capabilities.end());
165 binary.append(extensions.begin(), extensions.end());
166 binary.append(extendedSets.begin(), extendedSets.end());
167 binary.append(memoryModel.begin(), memoryModel.end());
168 binary.append(entryPoints.begin(), entryPoints.end());
169 binary.append(executionModes.begin(), executionModes.end());
170 binary.append(debug.begin(), debug.end());
171 binary.append(names.begin(), names.end());
172 binary.append(decorations.begin(), decorations.end());
173 binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
174 binary.append(functions.begin(), functions.end());
175 binary.append(graphs.begin(), graphs.end());
180 os <<
"\n= Value <id> Map =\n\n";
181 for (
auto valueIDPair : valueIDMap) {
182 Value val = valueIDPair.first;
183 os <<
" " << val <<
" "
184 <<
"id = " << valueIDPair.second <<
' ';
186 os <<
"from op '" << op->getName() <<
"'";
187 }
else if (
auto arg = dyn_cast<BlockArgument>(val)) {
188 Block *block = arg.getOwner();
189 os <<
"from argument of block " << block <<
' ';
201uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
202 auto funcID = funcIDMap.lookup(fnName);
204 funcID = getNextID();
205 funcIDMap[fnName] = funcID;
210void Serializer::processCapability() {
211 for (
auto cap : module.getVceTriple()->getCapabilities())
213 {
static_cast<uint32_t
>(cap)});
216void Serializer::processDebugInfo() {
217 if (!options.emitDebugInfo)
219 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
220 auto fileName = fileLoc ? fileLoc.getFilename().strref() :
"<unknown>";
221 fileID = getNextID();
222 SmallVector<uint32_t, 16> operands;
223 operands.push_back(fileID);
229LogicalResult Serializer::processExtension() {
230 llvm::SmallVector<uint32_t, 16> extName;
231 llvm::SmallSet<Extension, 4> deducedExts(
232 llvm::from_range, module.getVceTriple()->getExtensions());
233 auto nonSemanticInfoExt = spirv::Extension::SPV_KHR_non_semantic_info;
234 if (options.emitDebugInfo && !deducedExts.contains(nonSemanticInfoExt)) {
236 if (!is_contained(targetEnvAttr.getExtensions(), nonSemanticInfoExt))
237 return module.emitError(
238 "SPV_KHR_non_semantic_info extension not available");
239 deducedExts.insert(nonSemanticInfoExt);
241 for (spirv::Extension ext : deducedExts) {
249void Serializer::processMemoryModel() {
250 StringAttr memoryModelName =
module.getMemoryModelAttrName();
251 auto mm =
static_cast<uint32_t
>(
252 module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
255 StringAttr addressingModelName =
module.getAddressingModelAttrName();
256 auto am =
static_cast<uint32_t
>(
257 module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
266 if (attrName ==
"fp_fast_math_mode")
267 return "FPFastMathMode";
269 if (attrName ==
"fp_rounding_mode")
270 return "FPRoundingMode";
272 if (attrName ==
"cache_control_load_intel")
273 return "CacheControlLoadINTEL";
274 if (attrName ==
"cache_control_store_intel")
275 return "CacheControlStoreINTEL";
277 return llvm::convertToCamelFromSnakeCase(attrName,
true);
280template <
typename AttrTy,
typename EmitF>
283 StringRef attrName, EmitF emitter) {
284 auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
286 return emitError(loc,
"expecting array attribute of ")
287 << attrName <<
" for " << stringifyDecoration(decoration);
289 if (arrayAttr.empty()) {
290 return emitError(loc,
"expecting non-empty array attribute of ")
291 << attrName <<
" for " << stringifyDecoration(decoration);
293 for (
Attribute attr : arrayAttr.getValue()) {
294 auto cacheControlAttr = dyn_cast<AttrTy>(attr);
295 if (!cacheControlAttr) {
296 return emitError(loc,
"expecting array attribute of ")
297 << attrName <<
" for " << stringifyDecoration(decoration);
301 if (failed(emitter(cacheControlAttr)))
307LogicalResult Serializer::processDecorationAttr(
Location loc, uint32_t resultID,
308 Decoration decoration,
311 switch (decoration) {
312 case spirv::Decoration::LinkageAttributes: {
315 auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
316 auto linkageName = linkageAttr.getLinkageName();
317 auto linkageType = linkageAttr.getLinkageType().getValue();
321 args.push_back(
static_cast<uint32_t
>(linkageType));
324 case spirv::Decoration::FPFastMathMode:
325 if (
auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
326 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
329 return emitError(loc,
"expected FPFastMathModeAttr attribute for ")
330 << stringifyDecoration(decoration);
331 case spirv::Decoration::FPRoundingMode:
332 if (
auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
333 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
336 return emitError(loc,
"expected FPRoundingModeAttr attribute for ")
337 << stringifyDecoration(decoration);
338 case spirv::Decoration::Binding:
339 case spirv::Decoration::DescriptorSet:
340 case spirv::Decoration::Location:
341 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
342 args.push_back(intAttr.getValue().getZExtValue());
345 return emitError(loc,
"expected integer attribute for ")
346 << stringifyDecoration(decoration);
347 case spirv::Decoration::BuiltIn:
348 if (
auto strAttr = dyn_cast<StringAttr>(attr)) {
349 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
351 args.push_back(
static_cast<uint32_t
>(*enumVal));
355 << stringifyDecoration(decoration) <<
" decoration attribute "
356 << strAttr.getValue();
358 return emitError(loc,
"expected string attribute for ")
359 << stringifyDecoration(decoration);
360 case spirv::Decoration::Aliased:
361 case spirv::Decoration::AliasedPointer:
362 case spirv::Decoration::Flat:
363 case spirv::Decoration::NonReadable:
364 case spirv::Decoration::NonWritable:
365 case spirv::Decoration::NoPerspective:
366 case spirv::Decoration::NoSignedWrap:
367 case spirv::Decoration::NoUnsignedWrap:
368 case spirv::Decoration::RelaxedPrecision:
369 case spirv::Decoration::Restrict:
370 case spirv::Decoration::RestrictPointer:
371 case spirv::Decoration::NoContraction:
372 case spirv::Decoration::Constant:
373 case spirv::Decoration::Block:
374 case spirv::Decoration::Invariant:
375 case spirv::Decoration::Patch:
378 if (isa<UnitAttr, DecorationAttr>(attr))
381 "expected unit attribute or decoration attribute for ")
382 << stringifyDecoration(decoration);
383 case spirv::Decoration::CacheControlLoadINTEL:
385 loc, decoration, attr,
"CacheControlLoadINTEL",
386 [&](CacheControlLoadINTELAttr attr) {
387 unsigned cacheLevel = attr.getCacheLevel();
388 LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
389 return emitDecoration(
390 resultID, decoration,
391 {cacheLevel,
static_cast<uint32_t
>(loadCacheControl)});
393 case spirv::Decoration::CacheControlStoreINTEL:
395 loc, decoration, attr,
"CacheControlStoreINTEL",
396 [&](CacheControlStoreINTELAttr attr) {
397 unsigned cacheLevel = attr.getCacheLevel();
398 StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
399 return emitDecoration(
400 resultID, decoration,
401 {cacheLevel,
static_cast<uint32_t
>(storeCacheControl)});
404 return emitError(loc,
"unhandled decoration ")
405 << stringifyDecoration(decoration);
407 return emitDecoration(resultID, decoration, args);
410LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
411 NamedAttribute attr) {
412 StringRef attrName = attr.
getName().strref();
414 std::optional<Decoration> decoration =
415 spirv::symbolizeDecoration(decorationName);
418 loc,
"non-argument attributes expected to have snake-case-ified "
419 "decoration name, unhandled attribute with name : ")
422 return processDecorationAttr(loc, resultID, *decoration, attr.
getValue());
425LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
426 assert(!name.empty() &&
"unexpected empty string for OpName");
427 if (!options.emitSymbolName)
430 SmallVector<uint32_t, 4> nameOperands;
431 nameOperands.push_back(resultID);
438LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
442 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
448LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
452 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
457LogicalResult Serializer::processMemberDecoration(
462 static_cast<uint32_t
>(memberDecoration.
decoration)});
478bool Serializer::isInterfaceStructPtrType(Type type)
const {
479 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
480 switch (ptrType.getStorageClass()) {
481 case spirv::StorageClass::PhysicalStorageBuffer:
482 case spirv::StorageClass::PushConstant:
483 case spirv::StorageClass::StorageBuffer:
484 case spirv::StorageClass::Uniform:
485 return isa<spirv::StructType>(ptrType.getPointeeType());
493LogicalResult Serializer::processType(Location loc, Type type,
498 return processTypeImpl(loc, type, typeID, serializationCtx);
502Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
514 IntegerType::SignednessSemantics::Signless);
517 typeID = getTypeID(type);
521 typeID = getNextID();
522 SmallVector<uint32_t, 4> operands;
524 operands.push_back(typeID);
525 auto typeEnum = spirv::Opcode::OpTypeVoid;
526 bool deferSerialization =
false;
528 if ((isa<FunctionType>(type) &&
529 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
531 (isa<GraphType>(type) &&
533 prepareGraphType(loc, cast<GraphType>(type), typeEnum, operands))) ||
534 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
535 deferSerialization, serializationCtx))) {
536 if (deferSerialization)
539 typeIDMap[type] = typeID;
543 if (recursiveStructInfos.count(type) != 0) {
546 for (
auto &ptrInfo : recursiveStructInfos[type]) {
549 SmallVector<uint32_t, 4> ptrOperands;
550 ptrOperands.push_back(ptrInfo.pointerTypeID);
551 ptrOperands.push_back(
static_cast<uint32_t
>(ptrInfo.storageClass));
552 ptrOperands.push_back(typeIDMap[type]);
558 recursiveStructInfos[type].clear();
564 return emitError(loc,
"failed to process type: ") << type;
567LogicalResult Serializer::prepareBasicType(
568 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
569 SmallVectorImpl<uint32_t> &operands,
bool &deferSerialization,
571 deferSerialization =
false;
573 if (isVoidType(type)) {
574 typeEnum = spirv::Opcode::OpTypeVoid;
578 if (
auto intType = dyn_cast<IntegerType>(type)) {
579 if (intType.getWidth() == 1) {
580 typeEnum = spirv::Opcode::OpTypeBool;
584 typeEnum = spirv::Opcode::OpTypeInt;
585 operands.push_back(intType.getWidth());
590 operands.push_back(intType.isSigned() ? 1 : 0);
594 if (
auto floatType = dyn_cast<FloatType>(type)) {
595 typeEnum = spirv::Opcode::OpTypeFloat;
596 operands.push_back(floatType.getWidth());
597 if (floatType.isBF16()) {
598 operands.push_back(
static_cast<uint32_t
>(spirv::FPEncoding::BFloat16KHR));
603 if (
auto vectorType = dyn_cast<VectorType>(type)) {
604 uint32_t elementTypeID = 0;
605 if (
failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
606 serializationCtx))) {
609 typeEnum = spirv::Opcode::OpTypeVector;
610 operands.push_back(elementTypeID);
611 operands.push_back(vectorType.getNumElements());
615 if (
auto imageType = dyn_cast<spirv::ImageType>(type)) {
616 typeEnum = spirv::Opcode::OpTypeImage;
617 uint32_t sampledTypeID = 0;
618 if (
failed(processType(loc, imageType.getElementType(), sampledTypeID)))
621 llvm::append_values(operands, sampledTypeID,
622 static_cast<uint32_t
>(imageType.getDim()),
623 static_cast<uint32_t
>(imageType.getDepthInfo()),
624 static_cast<uint32_t
>(imageType.getArrayedInfo()),
625 static_cast<uint32_t
>(imageType.getSamplingInfo()),
626 static_cast<uint32_t
>(imageType.getSamplerUseInfo()),
627 static_cast<uint32_t
>(imageType.getImageFormat()));
631 if (
auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
632 typeEnum = spirv::Opcode::OpTypeArray;
633 uint32_t elementTypeID = 0;
634 if (
failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
635 serializationCtx))) {
638 operands.push_back(elementTypeID);
639 if (
auto elementCountID = prepareConstantInt(
640 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
641 operands.push_back(elementCountID);
643 return processTypeDecoration(loc, arrayType, resultID);
646 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
647 uint32_t pointeeTypeID = 0;
648 spirv::StructType pointeeStruct =
649 dyn_cast<spirv::StructType>(ptrType.getPointeeType());
652 serializationCtx.count(pointeeStruct.
getIdentifier()) != 0) {
657 SmallVector<uint32_t, 2> forwardPtrOperands;
658 forwardPtrOperands.push_back(resultID);
659 forwardPtrOperands.push_back(
660 static_cast<uint32_t
>(ptrType.getStorageClass()));
663 spirv::Opcode::OpTypeForwardPointer,
675 deferSerialization =
true;
679 recursiveStructInfos[structType].push_back(
680 {resultID, ptrType.getStorageClass()});
682 if (
failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
687 typeEnum = spirv::Opcode::OpTypePointer;
688 operands.push_back(
static_cast<uint32_t
>(ptrType.getStorageClass()));
689 operands.push_back(pointeeTypeID);
694 if (isInterfaceStructPtrType(ptrType)) {
695 auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
696 if (!structType.hasDecoration(spirv::Decoration::Block))
697 if (
failed(emitDecoration(getTypeID(pointeeStruct),
698 spirv::Decoration::Block)))
699 return emitError(loc,
"cannot decorate ")
700 << pointeeStruct <<
" with Block decoration";
706 if (
auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
707 uint32_t elementTypeID = 0;
708 if (
failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
709 elementTypeID, serializationCtx))) {
712 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
713 operands.push_back(elementTypeID);
714 return processTypeDecoration(loc, runtimeArrayType, resultID);
717 if (
auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
718 typeEnum = spirv::Opcode::OpTypeSampledImage;
719 uint32_t imageTypeID = 0;
721 processType(loc, sampledImageType.getImageType(), imageTypeID))) {
724 operands.push_back(imageTypeID);
728 if (
auto structType = dyn_cast<spirv::StructType>(type)) {
729 if (structType.isIdentified()) {
730 if (
failed(processName(resultID, structType.getIdentifier())))
732 serializationCtx.insert(structType.getIdentifier());
735 bool hasOffset = structType.hasOffset();
736 for (
auto elementIndex :
737 llvm::seq<uint32_t>(0, structType.getNumElements())) {
738 uint32_t elementTypeID = 0;
739 if (
failed(processTypeImpl(loc, structType.getElementType(elementIndex),
740 elementTypeID, serializationCtx))) {
743 operands.push_back(elementTypeID);
745 auto intType = IntegerType::get(structType.getContext(), 32);
747 spirv::StructType::MemberDecorationInfo offsetDecoration{
748 elementIndex, spirv::Decoration::Offset,
749 IntegerAttr::get(intType,
750 structType.getMemberOffset(elementIndex))};
751 if (
failed(processMemberDecoration(resultID, offsetDecoration))) {
752 return emitError(loc,
"cannot decorate ")
753 << elementIndex <<
"-th member of " << structType
754 <<
" with its offset";
758 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
759 structType.getMemberDecorations(memberDecorations);
761 for (
auto &memberDecoration : memberDecorations) {
762 if (
failed(processMemberDecoration(resultID, memberDecoration))) {
763 return emitError(loc,
"cannot decorate ")
764 <<
static_cast<uint32_t
>(memberDecoration.
memberIndex)
765 <<
"-th member of " << structType <<
" with "
766 << stringifyDecoration(memberDecoration.
decoration);
770 SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations;
771 structType.getStructDecorations(structDecorations);
773 for (spirv::StructType::StructDecorationInfo &structDecoration :
775 if (
failed(processDecorationAttr(loc, resultID,
776 structDecoration.decoration,
777 structDecoration.decorationValue))) {
778 return emitError(loc,
"cannot decorate struct ")
779 << structType <<
" with "
780 << stringifyDecoration(structDecoration.decoration);
784 typeEnum = spirv::Opcode::OpTypeStruct;
786 if (structType.isIdentified())
787 serializationCtx.remove(structType.getIdentifier());
792 if (
auto cooperativeMatrixType =
793 dyn_cast<spirv::CooperativeMatrixType>(type)) {
794 uint32_t elementTypeID = 0;
795 if (
failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
796 elementTypeID, serializationCtx))) {
799 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
800 auto getConstantOp = [&](uint32_t id) {
801 auto attr = IntegerAttr::get(IntegerType::get(type.
getContext(), 32),
id);
802 return prepareConstantInt(loc, attr);
805 operands, elementTypeID,
806 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getScope())),
807 getConstantOp(cooperativeMatrixType.getRows()),
808 getConstantOp(cooperativeMatrixType.getColumns()),
809 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getUse())));
813 if (
auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
814 uint32_t elementTypeID = 0;
815 if (
failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
816 serializationCtx))) {
819 typeEnum = spirv::Opcode::OpTypeMatrix;
820 llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
824 if (
auto tensorArmType = llvm::dyn_cast<TensorArmType>(type)) {
825 uint32_t elementTypeID = 0;
827 uint32_t shapeID = 0;
829 if (
failed(processTypeImpl(loc, tensorArmType.getElementType(),
830 elementTypeID, serializationCtx))) {
833 if (tensorArmType.hasRank()) {
834 ArrayRef<int64_t> dims = tensorArmType.getShape();
836 rankID = prepareConstantInt(loc, mlirBuilder.getI32IntegerAttr(rank));
841 bool shaped = llvm::all_of(dims, [](
const auto &dim) {
return dim > 0; });
842 if (rank > 0 && shaped) {
843 auto I32Type = IntegerType::get(type.
getContext(), 32);
846 SmallVector<uint64_t, 1> index(rank);
847 shapeID = prepareDenseElementsConstant(
849 mlirBuilder.getI32TensorAttr(SmallVector<int32_t>(dims)), 0,
852 shapeID = prepareArrayConstant(
854 mlirBuilder.getI32ArrayAttr(SmallVector<int32_t>(dims)));
861 typeEnum = spirv::Opcode::OpTypeTensorARM;
862 operands.push_back(elementTypeID);
865 operands.push_back(rankID);
868 operands.push_back(shapeID);
873 return emitError(loc,
"unhandled type in serialization: ") << type;
877Serializer::prepareFunctionType(Location loc, FunctionType type,
878 spirv::Opcode &typeEnum,
879 SmallVectorImpl<uint32_t> &operands) {
880 typeEnum = spirv::Opcode::OpTypeFunction;
881 assert(type.getNumResults() <= 1 &&
882 "serialization supports only a single return value");
883 uint32_t resultID = 0;
885 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
889 operands.push_back(resultID);
890 for (
auto &res : type.getInputs()) {
891 uint32_t argTypeID = 0;
892 if (
failed(processType(loc, res, argTypeID))) {
895 operands.push_back(argTypeID);
901Serializer::prepareGraphType(Location loc, GraphType type,
902 spirv::Opcode &typeEnum,
903 SmallVectorImpl<uint32_t> &operands) {
904 typeEnum = spirv::Opcode::OpTypeGraphARM;
905 assert(type.getNumResults() >= 1 &&
906 "serialization requires at least a return value");
908 operands.push_back(type.getNumInputs());
910 for (Type argType : type.getInputs()) {
911 uint32_t argTypeID = 0;
912 if (
failed(processType(loc, argType, argTypeID)))
914 operands.push_back(argTypeID);
917 for (Type resType : type.getResults()) {
918 uint32_t resTypeID = 0;
919 if (
failed(processType(loc, resType, resTypeID)))
921 operands.push_back(resTypeID);
931uint32_t Serializer::prepareConstant(Location loc, Type constType,
932 Attribute valueAttr) {
933 if (
auto id = prepareConstantScalar(loc, valueAttr)) {
940 if (
auto id = getConstantID(valueAttr)) {
945 if (
failed(processType(loc, constType, typeID))) {
949 uint32_t resultID = 0;
950 if (
auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
951 int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
952 SmallVector<uint64_t, 4> index(rank);
953 resultID = prepareDenseElementsConstant(loc, constType, attr,
955 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
956 resultID = prepareArrayConstant(loc, constType, arrayAttr);
960 emitError(loc,
"cannot serialize attribute: ") << valueAttr;
964 constIDMap[valueAttr] = resultID;
968uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
971 if (
failed(processType(loc, constType, typeID))) {
975 uint32_t resultID = getNextID();
976 SmallVector<uint32_t, 4> operands = {typeID, resultID};
977 operands.reserve(attr.size() + 2);
978 auto elementType = cast<spirv::ArrayType>(constType).getElementType();
979 for (Attribute elementAttr : attr) {
980 if (
auto elementID = prepareConstant(loc, elementType, elementAttr)) {
981 operands.push_back(elementID);
986 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
995Serializer::prepareDenseElementsConstant(Location loc, Type constType,
996 DenseElementsAttr valueAttr,
int dim,
997 MutableArrayRef<uint64_t> index) {
998 auto shapedType = dyn_cast<ShapedType>(valueAttr.
getType());
999 assert(dim <= shapedType.getRank());
1000 if (shapedType.getRank() == dim) {
1001 if (
auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
1002 return attr.getType().getElementType().isInteger(1)
1003 ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
1004 : prepareConstantInt(loc,
1005 attr.getValues<IntegerAttr>()[index]);
1007 if (
auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
1008 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
1013 uint32_t typeID = 0;
1014 if (
failed(processType(loc, constType, typeID))) {
1018 int64_t numberOfConstituents = shapedType.getDimSize(dim);
1019 uint32_t resultID = getNextID();
1020 SmallVector<uint32_t, 4> operands = {typeID, resultID};
1021 auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
1022 if (
auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
1023 ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front();
1024 if (!innerShape.empty())
1032 if (isa<spirv::CooperativeMatrixType>(constType)) {
1036 "cannot serialize a non-splat value for a cooperative matrix type");
1041 operands.reserve(3);
1044 if (
auto elementID = prepareDenseElementsConstant(
1045 loc, elementType, valueAttr, shapedType.getRank(), index)) {
1046 operands.push_back(elementID);
1050 }
else if (isa<spirv::TensorArmType>(constType) &&
isZeroValue(valueAttr)) {
1052 {typeID, resultID});
1055 operands.reserve(numberOfConstituents + 2);
1056 for (
int i = 0; i < numberOfConstituents; ++i) {
1058 if (
auto elementID = prepareDenseElementsConstant(
1059 loc, elementType, valueAttr, dim + 1, index)) {
1060 operands.push_back(elementID);
1066 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
1072uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
1074 if (
auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
1075 return prepareConstantFp(loc, floatAttr, isSpec);
1077 if (
auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
1078 return prepareConstantBool(loc, boolAttr, isSpec);
1080 if (
auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
1081 return prepareConstantInt(loc, intAttr, isSpec);
1087uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
1091 if (
auto id = getConstantID(boolAttr)) {
1097 uint32_t typeID = 0;
1098 if (
failed(processType(loc, cast<IntegerAttr>(boolAttr).
getType(), typeID))) {
1102 auto resultID = getNextID();
1104 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
1105 : spirv::Opcode::OpConstantTrue)
1106 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
1107 : spirv::Opcode::OpConstantFalse);
1111 constIDMap[boolAttr] = resultID;
1116uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
1120 if (
auto id = getConstantID(intAttr)) {
1126 uint32_t typeID = 0;
1127 if (
failed(processType(loc, intAttr.getType(), typeID))) {
1131 auto resultID = getNextID();
1132 APInt value = intAttr.getValue();
1133 unsigned bitwidth = value.getBitWidth();
1134 bool isSigned = intAttr.getType().isSignedInteger();
1136 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1149 word =
static_cast<int32_t
>(value.getSExtValue());
1151 word =
static_cast<uint32_t
>(value.getZExtValue());
1163 words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
1165 words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
1168 {typeID, resultID, words.word1, words.word2});
1171 std::string valueStr;
1172 llvm::raw_string_ostream rss(valueStr);
1173 value.print(rss,
false);
1176 << bitwidth <<
"-bit integer literal: " << valueStr;
1182 constIDMap[intAttr] = resultID;
1187uint32_t Serializer::prepareGraphConstantId(Location loc, Type graphConstType,
1188 IntegerAttr intAttr) {
1190 if (uint32_t
id = getGraphConstantARMId(intAttr)) {
1195 uint32_t typeID = 0;
1196 if (
failed(processType(loc, graphConstType, typeID))) {
1200 uint32_t resultID = getNextID();
1201 APInt value = intAttr.getValue();
1202 unsigned bitwidth = value.getBitWidth();
1203 if (bitwidth > 32) {
1204 emitError(loc,
"Too wide attribute for OpGraphConstantARM: ")
1205 << bitwidth <<
" bits";
1208 bool isSigned = value.isSignedIntN(bitwidth);
1212 word =
static_cast<int32_t
>(value.getSExtValue());
1214 word =
static_cast<uint32_t
>(value.getZExtValue());
1217 {typeID, resultID, word});
1218 graphConstIDMap[intAttr] = resultID;
1222uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
1226 if (
auto id = getConstantID(floatAttr)) {
1232 uint32_t typeID = 0;
1233 if (
failed(processType(loc, floatAttr.getType(), typeID))) {
1237 auto resultID = getNextID();
1238 APFloat value = floatAttr.getValue();
1239 const llvm::fltSemantics *semantics = &value.getSemantics();
1242 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1244 if (semantics == &APFloat::IEEEsingle()) {
1245 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1247 }
else if (semantics == &APFloat::IEEEdouble()) {
1251 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1253 {typeID, resultID, words.word1, words.word2});
1254 }
else if (semantics == &APFloat::IEEEhalf() ||
1255 semantics == &APFloat::BFloat()) {
1257 static_cast<uint32_t
>(value.bitcastToAPInt().getZExtValue());
1260 std::string valueStr;
1261 llvm::raw_string_ostream rss(valueStr);
1265 << floatAttr.getType() <<
"-typed float literal: " << valueStr;
1270 constIDMap[floatAttr] = resultID;
1279 if (
auto typedAttr = dyn_cast<TypedAttr>(attr)) {
1280 return typedAttr.getType();
1283 if (
auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
1290uint32_t Serializer::prepareConstantCompositeReplicate(
Location loc,
1293 std::pair<Attribute, Type> valueTypePair{valueAttr, resultType};
1294 if (uint32_t
id = getConstantCompositeReplicateID(valueTypePair)) {
1298 uint32_t typeID = 0;
1299 if (
failed(processType(loc, resultType, typeID))) {
1307 auto compositeType = dyn_cast<CompositeType>(resultType);
1310 Type elementType = compositeType.getElementType(0);
1312 uint32_t constandID;
1313 if (elementType == valueType) {
1314 constandID = prepareConstant(loc, elementType, valueAttr);
1316 constandID = prepareConstantCompositeReplicate(loc, elementType, valueAttr);
1319 uint32_t resultID = getNextID();
1320 if (dyn_cast<spirv::TensorArmType>(resultType) &&
isZeroValue(valueAttr)) {
1322 {typeID, resultID});
1325 spirv::Opcode::OpConstantCompositeReplicateEXT,
1326 {typeID, resultID, constandID});
1329 constCompositeReplicateIDMap[valueTypePair] = resultID;
1337uint32_t Serializer::getOrCreateBlockID(
Block *block) {
1338 if (uint32_t
id = getBlockID(block))
1340 return blockIDMap[block] = getNextID();
1344void Serializer::printBlock(
Block *block, raw_ostream &os) {
1345 os <<
"block " << block <<
" (id = ";
1346 if (uint32_t
id = getBlockID(block))
1355Serializer::processBlock(
Block *block,
bool omitLabel,
1357 LLVM_DEBUG(llvm::dbgs() <<
"processing block " << block <<
":\n");
1358 LLVM_DEBUG(block->
print(llvm::dbgs()));
1359 LLVM_DEBUG(llvm::dbgs() <<
'\n');
1361 uint32_t blockID = getOrCreateBlockID(block);
1362 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1369 if (
failed(emitPhiForBlockArguments(block)))
1379 llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1382 emitMerge =
nullptr;
1385 uint32_t blockID = getNextID();
1391 for (Operation &op : llvm::drop_end(*block)) {
1392 if (
failed(processOperation(&op)))
1400 if (
failed(processOperation(&block->
back())))
// Emits one phi instruction per argument of `block`.
//
// The visible code first gathers, for every predecessor, the operand range its
// terminator forwards to `block`, then builds the phi operand words for each
// block argument and records the phi <id> as that argument's value <id>.
// NOTE(review): several lines of this function are missing from this excerpt;
// the comments below only describe what the visible lines show.
1406LogicalResult Serializer::emitPhiForBlockArguments(
Block *block) {
1412 LLVM_DEBUG(llvm::dbgs() <<
"emitting phi instructions..\n");
// Each entry pairs a predecessor block with the operands that predecessor's
// terminator passes to `block`.
1419 SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
1421 auto *terminator = mlirPredecessor->getTerminator();
1422 LLVM_DEBUG(llvm::dbgs() <<
" mlir predecessor ");
1423 LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1424 LLVM_DEBUG(llvm::dbgs() <<
" terminator: " << *terminator <<
"\n");
// NOTE(review): `spirvPredecessor` is computed on a line not shown here; it is
// the block recorded as the phi's incoming block, and may differ from
// `mlirPredecessor` — confirm against the full source.
1433 LLVM_DEBUG(llvm::dbgs() <<
" spirv predecessor ");
1434 LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
// Unconditional branch: all of its operands are forwarded to `block`.
1435 if (
auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1436 predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1437 }
else if (
auto branchCondOp =
1438 dyn_cast<spirv::BranchConditionalOp>(terminator)) {
// Conditional branch: take the operand list of whichever edge targets
// `block`; the false-target case is checked with an assert.
1439 std::optional<OperandRange> blockOperands;
1440 if (branchCondOp.getTrueTarget() == block) {
1441 blockOperands = branchCondOp.getTrueTargetOperands();
1443 assert(branchCondOp.getFalseTarget() == block);
1444 blockOperands = branchCondOp.getFalseTargetOperands();
1447 assert(!blockOperands->empty() &&
1448 "expected non-empty block operand range");
1449 predecessors.emplace_back(spirvPredecessor, *blockOperands);
// Any other terminator kind is reported as an error on the terminator.
1451 return terminator->emitError(
"unimplemented terminator for Phi creation");
1454 llvm::dbgs() <<
" block arguments:\n";
1455 for (Value v : predecessors.back().second)
1456 llvm::dbgs() <<
" " << v <<
"\n";
// Build one phi per block argument index.
1461 for (
auto argIndex : llvm::seq<unsigned>(0, block->
getNumArguments())) {
1465 uint32_t phiTypeID = 0;
1468 uint32_t phiID = getNextID();
1470 LLVM_DEBUG(llvm::dbgs() <<
"[phi] for block argument #" << argIndex <<
' '
1471 << arg <<
" (id = " << phiID <<
")\n");
// Phi operand words: the type <id> and phi <id> come first, followed by one
// (value <id>, predecessor block <id>) pair per predecessor.
1474 SmallVector<uint32_t, 8> phiArgs;
1475 phiArgs.push_back(phiTypeID);
1476 phiArgs.push_back(phiID);
1478 for (
auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
// The value this predecessor forwards for the current argument.
1479 Value value = predecessors[predIndex].second[argIndex];
1480 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1481 LLVM_DEBUG(llvm::dbgs() <<
"[phi] use predecessor (id = " << predBlockId
1482 <<
") value " << value <<
' ');
1484 uint32_t valueId = getValueID(value);
// "need to fix" path: an entry derived from the current `functionBody` size
// is recorded under `value` in `deferredPhiValues`.
// NOTE(review): the guarding condition and the rest of the offset expression
// are not shown; presumably this is the word position to patch once `value`
// has been assigned an <id> — TODO confirm.
1488 LLVM_DEBUG(llvm::dbgs() <<
"(need to fix)\n");
1489 deferredPhiValues[value].push_back(functionBody.size() + 1 +
1492 LLVM_DEBUG(llvm::dbgs() <<
"(id = " << valueId <<
")\n");
1494 phiArgs.push_back(valueId);
1496 phiArgs.push_back(predBlockId);
// Later uses of the block argument resolve to the phi's result <id>.
1500 valueIDMap[arg] = phiID;
// Encodes an instruction belonging to the extended instruction set named
// `extensionSetName`.
//
// `operands` must hold at least two words (an error is emitted on `op`
// otherwise). The final operand list is the first two words of `operands`,
// then the set <id>, then `extensionOpcode`, then the remaining words.
// NOTE(review): some lines of this function are missing from this excerpt.
1510LogicalResult Serializer::encodeExtensionInstruction(
1511 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1512 ArrayRef<uint32_t> operands) {
// Reference into the name -> <id> map, so the assignment below updates the map
// entry itself.
1514 auto &setID = extendedInstSetIDMap[extensionSetName];
// NOTE(review): the condition guarding this assignment is not shown;
// presumably a fresh <id> and the import operands are only created the first
// time a set is seen — confirm.
1516 setID = getNextID();
1517 SmallVector<uint32_t, 16> importOperands;
1518 importOperands.push_back(setID);
// The first two operand words are mandatory; the diagnostic calls them the
// "result encoding".
1526 if (operands.size() < 2) {
1527 return op->
emitError(
"extended instructions must have a result encoding");
// Splice the set <id> and the extension opcode in after the first two words.
1529 SmallVector<uint32_t, 8> extInstOperands;
1530 extInstOperands.reserve(operands.size() + 2);
1531 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1532 extInstOperands.push_back(setID);
1533 extInstOperands.push_back(extensionOpcode);
1534 extInstOperands.append(std::next(operands.begin(), 2), operands.end());
// Serializes a single operation by dispatching on its concrete op type.
//
// Each op type listed in a `.Case` below is forwarded to its dedicated
// `process*` member function. The trailing handler, which takes a generic
// `Operation *`, forwards to `dispatchToAutogenSerialization`.
// NOTE(review): the switch header and a few closing lines are missing from
// this excerpt.
1540LogicalResult Serializer::processOperation(Operation *opInst) {
1541 LLVM_DEBUG(llvm::dbgs() <<
"[op] '" << opInst->
getName() <<
"'\n");
// Ops with hand-written serialization logic.
1546 .Case([&](spirv::AddressOfOp op) {
return processAddressOfOp(op); })
1547 .Case([&](spirv::BranchOp op) {
return processBranchOp(op); })
1548 .Case([&](spirv::BranchConditionalOp op) {
1549 return processBranchConditionalOp(op);
1551 .Case([&](spirv::ConstantOp op) {
return processConstantOp(op); })
1552 .Case([&](spirv::EXTConstantCompositeReplicateOp op) {
1553 return processConstantCompositeReplicateOp(op);
1555 .Case([&](spirv::FuncOp op) {
return processFuncOp(op); })
1556 .Case([&](spirv::GraphARMOp op) {
return processGraphARMOp(op); })
1557 .Case([&](spirv::GraphEntryPointARMOp op) {
1558 return processGraphEntryPointARMOp(op);
1560 .Case([&](spirv::GraphOutputsARMOp op) {
1561 return processGraphOutputsARMOp(op);
1563 .Case([&](spirv::GlobalVariableOp op) {
1564 return processGlobalVariableOp(op);
1566 .Case([&](spirv::GraphConstantARMOp op) {
1567 return processGraphConstantARMOp(op);
1569 .Case([&](spirv::LoopOp op) {
return processLoopOp(op); })
1570 .Case([&](spirv::ReferenceOfOp op) {
return processReferenceOfOp(op); })
1571 .Case([&](spirv::SelectionOp op) {
return processSelectionOp(op); })
1572 .Case([&](spirv::SpecConstantOp op) {
return processSpecConstantOp(op); })
1573 .Case([&](spirv::SpecConstantCompositeOp op) {
1574 return processSpecConstantCompositeOp(op);
1576 .Case([&](spirv::EXTSpecConstantCompositeReplicateOp op) {
1577 return processSpecConstantCompositeReplicateOp(op);
1579 .Case([&](spirv::SpecConstantOperationOp op) {
1580 return processSpecConstantOperationOp(op);
1582 .Case([&](spirv::UndefOp op) {
return processUndefOp(op); })
1583 .Case([&](spirv::VariableOp op) {
return processVariableOp(op); })
// Handler for ops not matched above.
1588 [&](Operation *op) {
return dispatchToAutogenSerialization(op); });
// Serializes `op` from its generic structure: collects the operand words,
// emits debug line info, encodes the instruction, and processes decorations.
//
// `extInstSet` selects the encoding path: when it is non-empty the
// instruction goes through `encodeExtensionInstruction`.
// NOTE(review): several lines (including the last parameter and most of the
// conditions) are missing from this excerpt.
1591LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1592 StringRef extInstSet,
1594 SmallVector<uint32_t, 4> operands;
1595 Location loc = op->
getLoc();
// Result words come first: the result type <id>, then a freshly allocated
// result <id>, which is also recorded as the <id> of result 0.
// NOTE(review): the guard around this section is not shown; presumably it
// only runs for ops that produce a result — confirm.
1597 uint32_t resultID = 0;
1599 uint32_t resultTypeID = 0;
1602 operands.push_back(resultTypeID);
1604 resultID = getNextID();
1605 operands.push_back(resultID);
1606 valueIDMap[op->
getResult(0)] = resultID;
// Followed by the <id> of an operand value (the enclosing loop is not shown).
1610 operands.push_back(getValueID(operand));
// Debug line information is written into the function body stream.
1612 if (
failed(emitDebugLine(functionBody, loc)))
// An empty set name takes the branch whose body is not shown here; otherwise
// the op is encoded as an extended instruction.
1615 if (extInstSet.empty()) {
1619 if (
failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
// Decorations are processed against the result <id>.
1625 if (
failed(processDecoration(loc, resultID, attr)))
// Emits a decoration of kind `decoration` with extra words `params`; the
// parameter words are appended to the `decorations` buffer.
// NOTE(review): part of the `append_values` call is missing from this
// excerpt, so where `target` and `wordCount` are used is not visible here.
1633LogicalResult Serializer::emitDecoration(uint32_t
target,
1634 spirv::Decoration decoration,
1635 ArrayRef<uint32_t> params) {
// Three fixed words plus one word per parameter.
// NOTE(review): presumably the fixed words are the opcode word, the target
// <id> and the decoration enum value — confirm against the full source.
1636 uint32_t wordCount = 3 + params.size();
1637 llvm::append_values(
1640 static_cast<uint32_t
>(decoration));
// Parameter words follow the fixed words.
1641 llvm::append_range(decorations, params);
// Emits debug line information for a location, using the file <id>, line and
// column of a `FileLineColLoc`.
// NOTE(review): several lines (including the second parameter and the bodies
// of the guards) are missing from this excerpt.
1645LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
// Guard on the `emitDebugInfo` serialization option (body not shown;
// presumably an early successful return — confirm).
1647 if (!options.emitDebugInfo)
// The "last instruction was a merge" flag is cleared once observed.
1650 if (lastProcessedWasMergeInst) {
1651 lastProcessedWasMergeInst =
false;
// `dyn_cast` result: the line/column data below is read from `fileLoc`.
1655 auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1658 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
static bool isZeroValue(Attribute attr)
static void moveFuncDeclarationsToTop(spirv::ModuleOp moduleOp)
Move all functions declaration before functions definitions.
static bool isZeroValue(Value val)
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Return the context this attribute belongs to.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< pred_iterator > getPredecessors()
OpListType & getOperations()
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
llvm::iplist< Operation > OpListType
This is the list of operations in the block.
bool getValue() const
Return the boolean value of this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
unsigned getArrayStride() const
Returns the array stride in bytes.
static ArrayType get(Type elementType, unsigned elementCount)
unsigned getArrayStride() const
Returns the array stride in bytes.
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
LogicalResult serialize()
Serializes the remembered SPIR-V module.
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static Type getValueType(Attribute attr)
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPIR-V module header to header with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
static LogicalResult processDecorationList(Location loc, Decoration decoration, Attribute attrList, StringRef attrName, EmitF emitter)
static std::string getDecorationName(StringRef attrName)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
llvm::TypeSwitch< T, ResultT > TypeSwitch
llvm::function_ref< Fn > function_ref
Attribute decorationValue