21#include "llvm/ADT/STLExtras.h"
22#include "llvm/ADT/Sequence.h"
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/ADT/bit.h"
26#include "llvm/Support/Debug.h"
30#define DEBUG_TYPE "spirv-serialization"
37 if (
auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38 return selectionOp.getMergeBlock();
39 if (
auto loopOp = dyn_cast<spirv::LoopOp>(op))
40 return loopOp.getMergeBlock();
51 if (
auto loopOp = dyn_cast<spirv::LoopOp>(block->
getParentOp())) {
55 while ((op = op->getPrevNode()) !=
nullptr)
74 if (
auto floatAttr = dyn_cast<FloatAttr>(attr)) {
75 return floatAttr.getValue().isZero();
77 if (
auto boolAttr = dyn_cast<BoolAttr>(attr)) {
78 return !boolAttr.getValue();
80 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
81 return intAttr.getValue().isZero();
83 if (
auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) {
86 if (
auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
102 for (
Operation &op : llvm::drop_begin(ops))
103 if (
auto funcOp = dyn_cast<spirv::FuncOp>(op))
104 if (funcOp.getBody().empty())
115 uint32_t wordCount = 1 + operands.size();
117 binary.append(operands.begin(), operands.end());
122 : module(module), mlirBuilder(module.
getContext()), options(options) {}
125 LLVM_DEBUG(llvm::dbgs() <<
"+++ starting serialization +++\n");
127 if (failed(module.verifyInvariants()))
132 if (failed(processExtension())) {
135 processMemoryModel();
142 for (
auto &op : *module.getBody()) {
143 if (failed(processOperation(&op))) {
148 LLVM_DEBUG(llvm::dbgs() <<
"+++ completed serialization +++\n");
154 extensions.size() + extendedSets.size() +
155 memoryModel.size() + entryPoints.size() +
156 executionModes.size() + decorations.size() +
157 typesGlobalValues.size() + functions.size() + graphs.size();
160 binary.reserve(moduleSize);
164 binary.append(capabilities.begin(), capabilities.end());
165 binary.append(extensions.begin(), extensions.end());
166 binary.append(extendedSets.begin(), extendedSets.end());
167 binary.append(memoryModel.begin(), memoryModel.end());
168 binary.append(entryPoints.begin(), entryPoints.end());
169 binary.append(executionModes.begin(), executionModes.end());
170 binary.append(debug.begin(), debug.end());
171 binary.append(names.begin(), names.end());
172 binary.append(decorations.begin(), decorations.end());
173 binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
174 binary.append(functions.begin(), functions.end());
175 binary.append(graphs.begin(), graphs.end());
176 binary.append(graphsDebugInfo.begin(), graphsDebugInfo.end());
181 os <<
"\n= Value <id> Map =\n\n";
182 for (
auto valueIDPair : valueIDMap) {
183 Value val = valueIDPair.first;
184 os <<
" " << val <<
" "
185 <<
"id = " << valueIDPair.second <<
' ';
187 os <<
"from op '" << op->getName() <<
"'";
188 }
else if (
auto arg = dyn_cast<BlockArgument>(val)) {
189 Block *block = arg.getOwner();
190 os <<
"from argument of block " << block <<
' ';
202uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
203 auto funcID = funcIDMap.lookup(fnName);
205 funcID = getNextID();
206 funcIDMap[fnName] = funcID;
211void Serializer::processCapability() {
212 for (
auto cap : module.getVceTriple()->getCapabilities())
214 {
static_cast<uint32_t
>(cap)});
217void Serializer::addLongCompositesCapability() {
218 if (longCompositesEmitted)
220 longCompositesEmitted =
true;
221 auto vceTriple =
module.getVceTriple();
222 if (!llvm::is_contained(vceTriple->getCapabilities(),
223 spirv::Capability::LongCompositesINTEL))
225 capabilities, spirv::Opcode::OpCapability,
226 {
static_cast<uint32_t
>(spirv::Capability::LongCompositesINTEL)});
227 if (!llvm::is_contained(vceTriple->getExtensions(),
228 spirv::Extension::SPV_INTEL_long_composites)) {
229 SmallVector<uint32_t, 8> extName;
232 spirv::stringifyExtension(spirv::Extension::SPV_INTEL_long_composites));
237void Serializer::encodeInstructionWithContinuationInto(
238 SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
239 ArrayRef<uint32_t> operands) {
245 std::optional<spirv::Opcode> continuationOp =
247 assert(continuationOp &&
"op is not a splittable composite/struct opcode");
251 for (ArrayRef<uint32_t> rest = operands.drop_front(chunk); !rest.empty();
252 rest = rest.drop_front(std::min<size_t>(rest.size(), chunk))) {
256 addLongCompositesCapability();
259void Serializer::processDebugInfo() {
260 if (!options.emitDebugInfo)
262 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
263 auto fileName = fileLoc ? fileLoc.getFilename().strref() :
"<unknown>";
264 fileID = getNextID();
265 SmallVector<uint32_t, 16> operands;
266 operands.push_back(fileID);
272LogicalResult Serializer::processExtension() {
273 llvm::SmallVector<uint32_t, 16> extName;
274 llvm::SmallSet<Extension, 4> deducedExts(
275 llvm::from_range, module.getVceTriple()->getExtensions());
276 auto nonSemanticInfoExt = spirv::Extension::SPV_KHR_non_semantic_info;
277 if (options.emitDebugInfo && !deducedExts.contains(nonSemanticInfoExt)) {
279 if (!is_contained(targetEnvAttr.getExtensions(), nonSemanticInfoExt))
280 return module.emitError(
281 "SPV_KHR_non_semantic_info extension not available");
282 deducedExts.insert(nonSemanticInfoExt);
284 for (spirv::Extension ext : deducedExts) {
292void Serializer::processMemoryModel() {
293 StringAttr memoryModelName =
module.getMemoryModelAttrName();
294 auto mm =
static_cast<uint32_t
>(
295 module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
298 StringAttr addressingModelName =
module.getAddressingModelAttrName();
299 auto am =
static_cast<uint32_t
>(
300 module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
309 if (attrName ==
"fp_fast_math_mode")
310 return "FPFastMathMode";
312 if (attrName ==
"fp_rounding_mode")
313 return "FPRoundingMode";
315 if (attrName ==
"cache_control_load_intel")
316 return "CacheControlLoadINTEL";
317 if (attrName ==
"cache_control_store_intel")
318 return "CacheControlStoreINTEL";
320 return llvm::convertToCamelFromSnakeCase(attrName,
true);
323template <
typename AttrTy,
typename EmitF>
326 StringRef attrName, EmitF emitter) {
327 auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
329 return emitError(loc,
"expecting array attribute of ")
330 << attrName <<
" for " << stringifyDecoration(decoration);
332 if (arrayAttr.empty()) {
333 return emitError(loc,
"expecting non-empty array attribute of ")
334 << attrName <<
" for " << stringifyDecoration(decoration);
336 for (
Attribute attr : arrayAttr.getValue()) {
337 auto cacheControlAttr = dyn_cast<AttrTy>(attr);
338 if (!cacheControlAttr) {
339 return emitError(loc,
"expecting array attribute of ")
340 << attrName <<
" for " << stringifyDecoration(decoration);
344 if (failed(emitter(cacheControlAttr)))
350LogicalResult Serializer::processDecorationAttr(
Location loc, uint32_t resultID,
351 Decoration decoration,
354 switch (decoration) {
355 case spirv::Decoration::LinkageAttributes: {
358 auto linkageAttr = dyn_cast<spirv::LinkageAttributesAttr>(attr);
359 auto linkageName = linkageAttr.getLinkageName();
360 auto linkageType = linkageAttr.getLinkageType().getValue();
364 args.push_back(
static_cast<uint32_t
>(linkageType));
367 case spirv::Decoration::FPFastMathMode:
368 if (
auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
369 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
372 return emitError(loc,
"expected FPFastMathModeAttr attribute for ")
373 << stringifyDecoration(decoration);
374 case spirv::Decoration::FPRoundingMode:
375 if (
auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
376 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
379 return emitError(loc,
"expected FPRoundingModeAttr attribute for ")
380 << stringifyDecoration(decoration);
381 case spirv::Decoration::Binding:
382 case spirv::Decoration::DescriptorSet:
383 case spirv::Decoration::Location:
384 case spirv::Decoration::Index:
385 case spirv::Decoration::Offset:
386 case spirv::Decoration::XfbBuffer:
387 case spirv::Decoration::XfbStride:
388 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
389 args.push_back(intAttr.getValue().getZExtValue());
392 return emitError(loc,
"expected integer attribute for ")
393 << stringifyDecoration(decoration);
394 case spirv::Decoration::BuiltIn:
395 if (
auto strAttr = dyn_cast<StringAttr>(attr)) {
396 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
398 args.push_back(
static_cast<uint32_t
>(*enumVal));
402 << stringifyDecoration(decoration) <<
" decoration attribute "
403 << strAttr.getValue();
405 return emitError(loc,
"expected string attribute for ")
406 << stringifyDecoration(decoration);
407 case spirv::Decoration::Aliased:
408 case spirv::Decoration::AliasedPointer:
409 case spirv::Decoration::Flat:
410 case spirv::Decoration::NonReadable:
411 case spirv::Decoration::NonWritable:
412 case spirv::Decoration::NoPerspective:
413 case spirv::Decoration::NoSignedWrap:
414 case spirv::Decoration::NoUnsignedWrap:
415 case spirv::Decoration::RelaxedPrecision:
416 case spirv::Decoration::Restrict:
417 case spirv::Decoration::RestrictPointer:
418 case spirv::Decoration::NoContraction:
419 case spirv::Decoration::Constant:
420 case spirv::Decoration::Block:
421 case spirv::Decoration::Invariant:
422 case spirv::Decoration::Patch:
423 case spirv::Decoration::Coherent:
426 if (isa<UnitAttr, DecorationAttr>(attr))
429 "expected unit attribute or decoration attribute for ")
430 << stringifyDecoration(decoration);
431 case spirv::Decoration::CacheControlLoadINTEL:
433 loc, decoration, attr,
"CacheControlLoadINTEL",
434 [&](CacheControlLoadINTELAttr attr) {
435 unsigned cacheLevel = attr.getCacheLevel();
436 LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
437 return emitDecoration(
438 resultID, decoration,
439 {cacheLevel,
static_cast<uint32_t
>(loadCacheControl)});
441 case spirv::Decoration::CacheControlStoreINTEL:
443 loc, decoration, attr,
"CacheControlStoreINTEL",
444 [&](CacheControlStoreINTELAttr attr) {
445 unsigned cacheLevel = attr.getCacheLevel();
446 StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
447 return emitDecoration(
448 resultID, decoration,
449 {cacheLevel,
static_cast<uint32_t
>(storeCacheControl)});
451 case spirv::Decoration::AlignmentId:
452 case spirv::Decoration::MaxByteOffsetId:
453 case spirv::Decoration::CounterBuffer: {
454 auto symRef = dyn_cast<FlatSymbolRefAttr>(attr);
456 return emitError(loc,
"expected symbol reference for ")
457 << stringifyDecoration(decoration);
458 StringRef symName = symRef.getValue();
459 uint32_t operandID = getVariableID(symName);
461 operandID = getSpecConstID(symName);
463 return emitError(loc,
"could not find <id> for symbol '")
464 << symName <<
"' referenced by "
465 << stringifyDecoration(decoration);
466 return emitDecorationId(resultID, decoration, {operandID});
469 return emitError(loc,
"unhandled decoration ")
470 << stringifyDecoration(decoration);
472 return emitDecoration(resultID, decoration, args);
475LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
476 NamedAttribute attr) {
477 StringRef attrName = attr.
getName().strref();
479 std::optional<Decoration> decoration =
480 spirv::symbolizeDecoration(decorationName);
483 loc,
"non-argument attributes expected to have snake-case-ified "
484 "decoration name, unhandled attribute with name : ")
487 return processDecorationAttr(loc, resultID, *decoration, attr.
getValue());
490LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
491 assert(!name.empty() &&
"unexpected empty string for OpName");
492 if (!options.emitSymbolName)
495 SmallVector<uint32_t, 4> nameOperands;
496 nameOperands.push_back(resultID);
503LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
507 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
513LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
517 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
522LogicalResult Serializer::processMemberDecoration(
527 static_cast<uint32_t
>(memberDecoration.
decoration)});
543bool Serializer::isInterfaceStructPtrType(Type type)
const {
544 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
545 switch (ptrType.getStorageClass()) {
546 case spirv::StorageClass::PhysicalStorageBuffer:
547 case spirv::StorageClass::PushConstant:
548 case spirv::StorageClass::StorageBuffer:
549 case spirv::StorageClass::Uniform:
550 return isa<spirv::StructType>(ptrType.getPointeeType());
558LogicalResult Serializer::processType(Location loc, Type type,
563 return processTypeImpl(loc, type, typeID, serializationCtx);
567Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
579 IntegerType::SignednessSemantics::Signless);
582 typeID = getTypeID(type);
586 typeID = getNextID();
587 SmallVector<uint32_t, 4> operands;
589 operands.push_back(typeID);
590 auto typeEnum = spirv::Opcode::OpTypeVoid;
591 bool deferSerialization =
false;
593 if ((isa<FunctionType>(type) &&
594 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
596 (isa<GraphType>(type) &&
598 prepareGraphType(loc, cast<GraphType>(type), typeEnum, operands))) ||
599 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
600 deferSerialization, serializationCtx))) {
601 if (deferSerialization)
604 typeIDMap[type] = typeID;
606 if (typeEnum == spirv::Opcode::OpTypeStruct)
607 encodeInstructionWithContinuationInto(typesGlobalValues, typeEnum,
612 if (recursiveStructInfos.count(type) != 0) {
615 for (
auto &ptrInfo : recursiveStructInfos[type]) {
618 SmallVector<uint32_t, 4> ptrOperands;
619 ptrOperands.push_back(ptrInfo.pointerTypeID);
620 ptrOperands.push_back(
static_cast<uint32_t
>(ptrInfo.storageClass));
621 ptrOperands.push_back(typeIDMap[type]);
627 recursiveStructInfos[type].clear();
633 return emitError(loc,
"failed to process type: ") << type;
636LogicalResult Serializer::prepareBasicType(
637 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
638 SmallVectorImpl<uint32_t> &operands,
bool &deferSerialization,
640 deferSerialization =
false;
642 if (isVoidType(type)) {
643 typeEnum = spirv::Opcode::OpTypeVoid;
647 if (
auto intType = dyn_cast<IntegerType>(type)) {
648 if (intType.getWidth() == 1) {
649 typeEnum = spirv::Opcode::OpTypeBool;
653 typeEnum = spirv::Opcode::OpTypeInt;
654 operands.push_back(intType.getWidth());
659 operands.push_back(intType.isSigned() ? 1 : 0);
663 if (
auto floatType = dyn_cast<FloatType>(type)) {
664 typeEnum = spirv::Opcode::OpTypeFloat;
665 operands.push_back(floatType.getWidth());
666 if (floatType.isBF16()) {
667 operands.push_back(
static_cast<uint32_t
>(spirv::FPEncoding::BFloat16KHR));
669 if (floatType.isF8E4M3FN()) {
671 static_cast<uint32_t
>(spirv::FPEncoding::Float8E4M3EXT));
673 if (floatType.isF8E5M2()) {
675 static_cast<uint32_t
>(spirv::FPEncoding::Float8E5M2EXT));
681 if (
auto vectorType = dyn_cast<VectorType>(type)) {
682 uint32_t elementTypeID = 0;
683 if (
failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
684 serializationCtx))) {
687 typeEnum = spirv::Opcode::OpTypeVector;
688 operands.push_back(elementTypeID);
689 operands.push_back(vectorType.getNumElements());
693 if (
auto imageType = dyn_cast<spirv::ImageType>(type)) {
694 typeEnum = spirv::Opcode::OpTypeImage;
695 uint32_t sampledTypeID = 0;
696 if (
failed(processType(loc, imageType.getElementType(), sampledTypeID)))
699 llvm::append_values(operands, sampledTypeID,
700 static_cast<uint32_t
>(imageType.getDim()),
701 static_cast<uint32_t
>(imageType.getDepthInfo()),
702 static_cast<uint32_t
>(imageType.getArrayedInfo()),
703 static_cast<uint32_t
>(imageType.getSamplingInfo()),
704 static_cast<uint32_t
>(imageType.getSamplerUseInfo()),
705 static_cast<uint32_t
>(imageType.getImageFormat()));
709 if (
auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
710 typeEnum = spirv::Opcode::OpTypeArray;
711 uint32_t elementTypeID = 0;
712 if (
failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
713 serializationCtx))) {
716 operands.push_back(elementTypeID);
717 if (
auto elementCountID = prepareConstantInt(
718 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
719 operands.push_back(elementCountID);
721 return processTypeDecoration(loc, arrayType, resultID);
724 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
725 uint32_t pointeeTypeID = 0;
726 spirv::StructType pointeeStruct =
727 dyn_cast<spirv::StructType>(ptrType.getPointeeType());
730 serializationCtx.count(pointeeStruct.
getIdentifier()) != 0) {
735 SmallVector<uint32_t, 2> forwardPtrOperands;
736 forwardPtrOperands.push_back(resultID);
737 forwardPtrOperands.push_back(
738 static_cast<uint32_t
>(ptrType.getStorageClass()));
741 spirv::Opcode::OpTypeForwardPointer,
753 deferSerialization =
true;
757 recursiveStructInfos[structType].push_back(
758 {resultID, ptrType.getStorageClass()});
760 if (
failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
765 typeEnum = spirv::Opcode::OpTypePointer;
766 operands.push_back(
static_cast<uint32_t
>(ptrType.getStorageClass()));
767 operands.push_back(pointeeTypeID);
772 if (isInterfaceStructPtrType(ptrType)) {
773 auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
774 if (!structType.hasDecoration(spirv::Decoration::Block))
775 if (
failed(emitDecoration(getTypeID(pointeeStruct),
776 spirv::Decoration::Block)))
777 return emitError(loc,
"cannot decorate ")
778 << pointeeStruct <<
" with Block decoration";
784 if (
auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
785 uint32_t elementTypeID = 0;
786 if (
failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
787 elementTypeID, serializationCtx))) {
790 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
791 operands.push_back(elementTypeID);
792 return processTypeDecoration(loc, runtimeArrayType, resultID);
795 if (isa<spirv::SamplerType>(type)) {
796 typeEnum = spirv::Opcode::OpTypeSampler;
800 if (isa<spirv::NamedBarrierType>(type)) {
801 typeEnum = spirv::Opcode::OpTypeNamedBarrier;
805 if (
auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
806 typeEnum = spirv::Opcode::OpTypeSampledImage;
807 uint32_t imageTypeID = 0;
809 processType(loc, sampledImageType.getImageType(), imageTypeID))) {
812 operands.push_back(imageTypeID);
816 if (
auto structType = dyn_cast<spirv::StructType>(type)) {
817 if (structType.isIdentified()) {
818 if (
failed(processName(resultID, structType.getIdentifier())))
820 serializationCtx.insert(structType.getIdentifier());
823 bool hasOffset = structType.hasOffset();
824 for (
auto elementIndex :
825 llvm::seq<uint32_t>(0, structType.getNumElements())) {
826 uint32_t elementTypeID = 0;
827 if (
failed(processTypeImpl(loc, structType.getElementType(elementIndex),
828 elementTypeID, serializationCtx))) {
831 operands.push_back(elementTypeID);
833 auto intType = IntegerType::get(structType.getContext(), 32);
835 spirv::StructType::MemberDecorationInfo offsetDecoration{
836 elementIndex, spirv::Decoration::Offset,
837 IntegerAttr::get(intType,
838 structType.getMemberOffset(elementIndex))};
839 if (
failed(processMemberDecoration(resultID, offsetDecoration))) {
840 return emitError(loc,
"cannot decorate ")
841 << elementIndex <<
"-th member of " << structType
842 <<
" with its offset";
846 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
847 structType.getMemberDecorations(memberDecorations);
849 for (
auto &memberDecoration : memberDecorations) {
850 if (
failed(processMemberDecoration(resultID, memberDecoration))) {
851 return emitError(loc,
"cannot decorate ")
852 <<
static_cast<uint32_t
>(memberDecoration.
memberIndex)
853 <<
"-th member of " << structType <<
" with "
854 << stringifyDecoration(memberDecoration.
decoration);
858 SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations;
859 structType.getStructDecorations(structDecorations);
861 for (spirv::StructType::StructDecorationInfo &structDecoration :
863 if (
failed(processDecorationAttr(loc, resultID,
864 structDecoration.decoration,
865 structDecoration.decorationValue))) {
866 return emitError(loc,
"cannot decorate struct ")
867 << structType <<
" with "
868 << stringifyDecoration(structDecoration.decoration);
872 typeEnum = spirv::Opcode::OpTypeStruct;
874 if (structType.isIdentified())
875 serializationCtx.remove(structType.getIdentifier());
880 if (
auto cooperativeMatrixType =
881 dyn_cast<spirv::CooperativeMatrixType>(type)) {
882 uint32_t elementTypeID = 0;
883 if (
failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
884 elementTypeID, serializationCtx))) {
887 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
888 auto getConstantOp = [&](uint32_t id) {
889 auto attr = IntegerAttr::get(IntegerType::get(type.
getContext(), 32),
id);
890 return prepareConstantInt(loc, attr);
893 operands, elementTypeID,
894 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getScope())),
895 getConstantOp(cooperativeMatrixType.getRows()),
896 getConstantOp(cooperativeMatrixType.getColumns()),
897 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getUse())));
901 if (
auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
902 uint32_t elementTypeID = 0;
903 if (
failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
904 serializationCtx))) {
907 typeEnum = spirv::Opcode::OpTypeMatrix;
908 llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
912 if (
auto tensorArmType = dyn_cast<TensorArmType>(type)) {
913 uint32_t elementTypeID = 0;
915 uint32_t shapeID = 0;
917 if (
failed(processTypeImpl(loc, tensorArmType.getElementType(),
918 elementTypeID, serializationCtx))) {
921 if (tensorArmType.hasRank()) {
922 ArrayRef<int64_t> dims = tensorArmType.getShape();
924 rankID = prepareConstantInt(loc, mlirBuilder.getI32IntegerAttr(rank));
929 bool shaped = llvm::all_of(dims, [](
const auto &dim) {
return dim > 0; });
930 if (rank > 0 && shaped) {
931 auto I32Type = IntegerType::get(type.
getContext(), 32);
934 SmallVector<uint64_t, 1> index(rank);
935 shapeID = prepareDenseElementsConstant(
937 mlirBuilder.getI32TensorAttr(SmallVector<int32_t>(dims)), 0,
940 shapeID = prepareArrayConstant(
942 mlirBuilder.getI32ArrayAttr(SmallVector<int32_t>(dims)));
949 typeEnum = spirv::Opcode::OpTypeTensorARM;
950 operands.push_back(elementTypeID);
953 operands.push_back(rankID);
956 operands.push_back(shapeID);
961 return emitError(loc,
"unhandled type in serialization: ") << type;
965Serializer::prepareFunctionType(Location loc, FunctionType type,
966 spirv::Opcode &typeEnum,
967 SmallVectorImpl<uint32_t> &operands) {
968 typeEnum = spirv::Opcode::OpTypeFunction;
969 assert(type.getNumResults() <= 1 &&
970 "serialization supports only a single return value");
971 uint32_t resultID = 0;
973 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
977 operands.push_back(resultID);
978 for (
auto &res : type.getInputs()) {
979 uint32_t argTypeID = 0;
980 if (
failed(processType(loc, res, argTypeID))) {
983 operands.push_back(argTypeID);
989Serializer::prepareGraphType(Location loc, GraphType type,
990 spirv::Opcode &typeEnum,
991 SmallVectorImpl<uint32_t> &operands) {
992 typeEnum = spirv::Opcode::OpTypeGraphARM;
993 assert(type.getNumResults() >= 1 &&
994 "serialization requires at least a return value");
996 operands.push_back(type.getNumInputs());
998 for (Type argType : type.getInputs()) {
999 uint32_t argTypeID = 0;
1000 if (
failed(processType(loc, argType, argTypeID)))
1002 operands.push_back(argTypeID);
1005 for (Type resType : type.getResults()) {
1006 uint32_t resTypeID = 0;
1007 if (
failed(processType(loc, resType, resTypeID)))
1009 operands.push_back(resTypeID);
1019uint32_t Serializer::prepareConstant(Location loc, Type constType,
1020 Attribute valueAttr) {
1021 if (
auto id = prepareConstantScalar(loc, valueAttr)) {
1028 if (
auto id = getConstantID(valueAttr)) {
1032 uint32_t typeID = 0;
1033 if (
failed(processType(loc, constType, typeID))) {
1037 uint32_t resultID = 0;
1038 if (
auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
1039 int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
1040 SmallVector<uint64_t, 4> index(rank);
1041 resultID = prepareDenseElementsConstant(loc, constType, attr,
1043 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1044 resultID = prepareArrayConstant(loc, constType, arrayAttr);
1047 if (resultID == 0) {
1048 emitError(loc,
"cannot serialize attribute: ") << valueAttr;
1052 constIDMap[valueAttr] = resultID;
1056uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
1058 uint32_t typeID = 0;
1059 if (
failed(processType(loc, constType, typeID))) {
1063 uint32_t resultID = getNextID();
1064 SmallVector<uint32_t, 4> operands = {typeID, resultID};
1065 operands.reserve(attr.size() + 2);
1066 spirv::CompositeType compositeType = cast<spirv::CompositeType>(constType);
1067 for (
auto [idx, elementAttr] : llvm::enumerate(attr)) {
1068 if (uint32_t elementID = prepareConstant(
1070 operands.push_back(elementID);
1075 encodeInstructionWithContinuationInto(
1076 typesGlobalValues, spirv::Opcode::OpConstantComposite, operands);
1084Serializer::prepareDenseElementsConstant(Location loc, Type constType,
1085 DenseElementsAttr valueAttr,
int dim,
1086 MutableArrayRef<uint64_t> index) {
1087 auto shapedType = dyn_cast<ShapedType>(valueAttr.
getType());
1088 assert(dim <= shapedType.getRank());
1089 if (shapedType.getRank() == dim) {
1090 if (
auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
1091 return attr.getType().getElementType().isInteger(1)
1092 ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
1093 : prepareConstantInt(loc,
1094 attr.getValues<IntegerAttr>()[index]);
1096 if (
auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
1097 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
1102 uint32_t typeID = 0;
1103 if (
failed(processType(loc, constType, typeID))) {
1107 int64_t numberOfConstituents = shapedType.getDimSize(dim);
1108 uint32_t resultID = getNextID();
1109 SmallVector<uint32_t, 4> operands = {typeID, resultID};
1110 auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
1111 if (
auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
1112 ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front();
1113 if (!innerShape.empty())
1121 if (isa<spirv::CooperativeMatrixType>(constType)) {
1125 "cannot serialize a non-splat value for a cooperative matrix type");
1130 operands.reserve(3);
1133 if (
auto elementID = prepareDenseElementsConstant(
1134 loc, elementType, valueAttr, shapedType.getRank(), index)) {
1135 operands.push_back(elementID);
1139 }
else if (isa<spirv::TensorArmType>(constType) &&
isZeroValue(valueAttr)) {
1141 {typeID, resultID});
1144 operands.reserve(numberOfConstituents + 2);
1145 for (
int i = 0; i < numberOfConstituents; ++i) {
1147 if (
auto elementID = prepareDenseElementsConstant(
1148 loc, elementType, valueAttr, dim + 1, index)) {
1149 operands.push_back(elementID);
1155 encodeInstructionWithContinuationInto(
1156 typesGlobalValues, spirv::Opcode::OpConstantComposite, operands);
1161uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
1163 if (
auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
1164 return prepareConstantFp(loc, floatAttr, isSpec);
1166 if (
auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
1167 return prepareConstantBool(loc, boolAttr, isSpec);
1169 if (
auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
1170 return prepareConstantInt(loc, intAttr, isSpec);
1176uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
1180 if (
auto id = getConstantID(boolAttr)) {
1186 uint32_t typeID = 0;
1187 if (
failed(processType(loc, cast<IntegerAttr>(boolAttr).
getType(), typeID))) {
1191 auto resultID = getNextID();
1193 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
1194 : spirv::Opcode::OpConstantTrue)
1195 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
1196 : spirv::Opcode::OpConstantFalse);
1200 constIDMap[boolAttr] = resultID;
1205uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
1209 if (
auto id = getConstantID(intAttr)) {
1215 uint32_t typeID = 0;
1216 if (
failed(processType(loc, intAttr.getType(), typeID))) {
1220 auto resultID = getNextID();
1221 APInt value = intAttr.getValue();
1222 unsigned bitwidth = value.getBitWidth();
1223 bool isSigned = intAttr.getType().isSignedInteger();
1225 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1238 word =
static_cast<int32_t
>(value.getSExtValue());
1240 word =
static_cast<uint32_t
>(value.getZExtValue());
1252 words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
1254 words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
1257 {typeID, resultID, words.word1, words.word2});
1260 std::string valueStr;
1261 llvm::raw_string_ostream rss(valueStr);
1262 value.print(rss,
false);
1265 << bitwidth <<
"-bit integer literal: " << valueStr;
1271 constIDMap[intAttr] = resultID;
1276uint32_t Serializer::prepareGraphConstantId(Location loc, Type graphConstType,
1277 IntegerAttr intAttr) {
1279 if (uint32_t
id = getGraphConstantARMId(intAttr)) {
1284 uint32_t typeID = 0;
1285 if (
failed(processType(loc, graphConstType, typeID))) {
1289 uint32_t resultID = getNextID();
1290 APInt value = intAttr.getValue();
1291 unsigned bitwidth = value.getBitWidth();
1292 if (bitwidth > 32) {
1293 emitError(loc,
"Too wide attribute for OpGraphConstantARM: ")
1294 << bitwidth <<
" bits";
1297 bool isSigned = value.isSignedIntN(bitwidth);
1301 word =
static_cast<int32_t
>(value.getSExtValue());
1303 word =
static_cast<uint32_t
>(value.getZExtValue());
1306 {typeID, resultID, word});
1307 graphConstIDMap[intAttr] = resultID;
1311uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
1315 if (
auto id = getConstantID(floatAttr)) {
1321 uint32_t typeID = 0;
1322 if (
failed(processType(loc, floatAttr.getType(), typeID))) {
1326 auto resultID = getNextID();
1327 APFloat value = floatAttr.getValue();
1328 const llvm::fltSemantics *semantics = &value.getSemantics();
1331 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1333 if (semantics == &APFloat::IEEEsingle()) {
1334 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1336 }
else if (semantics == &APFloat::IEEEdouble()) {
1340 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1342 {typeID, resultID, words.word1, words.word2});
1343 }
else if (llvm::is_contained({&APFloat::IEEEhalf(), &APFloat::BFloat(),
1344 &APFloat::Float8E4M3FN(),
1345 &APFloat::Float8E5M2()},
1348 static_cast<uint32_t
>(value.bitcastToAPInt().getZExtValue());
1351 std::string valueStr;
1352 llvm::raw_string_ostream rss(valueStr);
1356 << floatAttr.getType() <<
"-typed float literal: " << valueStr;
1361 constIDMap[floatAttr] = resultID;
1370 if (
auto typedAttr = dyn_cast<TypedAttr>(attr)) {
1371 return typedAttr.getType();
1374 if (
auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
1381uint32_t Serializer::prepareConstantCompositeReplicate(
Location loc,
1384 std::pair<Attribute, Type> valueTypePair{valueAttr, resultType};
1385 if (uint32_t
id = getConstantCompositeReplicateID(valueTypePair)) {
1389 uint32_t typeID = 0;
1390 if (
failed(processType(loc, resultType, typeID))) {
1398 auto compositeType = dyn_cast<CompositeType>(resultType);
1403 uint32_t constandID;
1404 if (elementType == valueType) {
1405 constandID = prepareConstant(loc, elementType, valueAttr);
1407 constandID = prepareConstantCompositeReplicate(loc, elementType, valueAttr);
1410 uint32_t resultID = getNextID();
1411 if (dyn_cast<spirv::TensorArmType>(resultType) &&
isZeroValue(valueAttr)) {
1413 {typeID, resultID});
1416 spirv::Opcode::OpConstantCompositeReplicateEXT,
1417 {typeID, resultID, constandID});
1420 constCompositeReplicateIDMap[valueTypePair] = resultID;
1428uint32_t Serializer::getOrCreateBlockID(
Block *block) {
1429 if (uint32_t
id = getBlockID(block))
1431 return blockIDMap[block] = getNextID();
1435void Serializer::printBlock(
Block *block, raw_ostream &os) {
1436 os <<
"block " << block <<
" (id = ";
1437 if (uint32_t
id = getBlockID(block))
/// Serializes `block`. It obtains (or creates) the block's <id>, emits phi
/// instructions for the block arguments, serializes every non-terminator
/// operation in order, and finally serializes the terminator.
///
/// NOTE(review): `omitLabel` presumably suppresses the block's label
/// instruction. `emitMerge` appears to be a one-shot callback, since it is
/// reset to nullptr after use. Confirm both.
1446Serializer::processBlock(
Block *block,
bool omitLabel,
1448 LLVM_DEBUG(llvm::dbgs() <<
"processing block " << block <<
":\n");
1449 LLVM_DEBUG(block->
print(llvm::dbgs()));
1450 LLVM_DEBUG(llvm::dbgs() <<
'\n');
1452 uint32_t blockID = getOrCreateBlockID(block);
1453 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1460 if (
failed(emitPhiForBlockArguments(block)))
// When the block contains a structured control-flow op (LoopOp or
// SelectionOp), a fresh block <id> is allocated below. Presumably this splits
// the block so the structured op starts a new SPIR-V block; confirm.
1470 llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1473 emitMerge =
nullptr;
1476 uint32_t blockID = getNextID();
// Serialize every operation except the terminator, in order.
1482 for (Operation &op : llvm::drop_end(*block)) {
1483 if (
failed(processOperation(&op)))
// The terminator is serialized last, separately from the loop above.
1491 if (
failed(processOperation(&block->
back())))
/// Emits one phi instruction for each argument of `block`.
///
/// First, for every MLIR predecessor, it collects the operands that the
/// predecessor's terminator forwards to `block`. Supported terminators are
/// spirv.Branch, spirv.BranchConditional and spirv.Switch; any other
/// terminator is reported as an error.
///
/// The SPIR-V parent block recorded for each incoming value
/// (`spirvPredecessor`) may differ from the MLIR predecessor. It is presumably
/// computed via getPhiIncomingBlock; confirm.
///
/// Each phi's operand list is [result type <id>, phi <id>], followed by one
/// (incoming value <id>, predecessor block <id>) pair per predecessor.
/// Incoming values that have no <id> yet are recorded in deferredPhiValues
/// together with a position in functionBody, so the slot can presumably be
/// patched once the <id> is known.
///
/// Finally, each block argument is mapped to its phi's result <id> in
/// valueIDMap.
1497LogicalResult Serializer::emitPhiForBlockArguments(
Block *block) {
1503 LLVM_DEBUG(llvm::dbgs() <<
"emitting phi instructions..\n");
1510 SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
1512 auto *terminator = mlirPredecessor->getTerminator();
1513 LLVM_DEBUG(llvm::dbgs() <<
" mlir predecessor ");
1514 LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1515 LLVM_DEBUG(llvm::dbgs() <<
" terminator: " << *terminator <<
"\n");
1524 LLVM_DEBUG(llvm::dbgs() <<
" spirv predecessor ");
1525 LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
// Collect the operands this predecessor's terminator forwards to `block`.
1526 if (
auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1527 predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1528 }
else if (
auto branchCondOp =
1529 dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1530 std::optional<OperandRange> blockOperands;
// `block` must be the true or the false target; take the matching operands.
1531 if (branchCondOp.getTrueTarget() == block) {
1532 blockOperands = branchCondOp.getTrueTargetOperands();
1534 assert(branchCondOp.getFalseTarget() == block);
1535 blockOperands = branchCondOp.getFalseTargetOperands();
1537 assert(!blockOperands->empty() &&
1538 "expected non-empty block operand range");
1539 predecessors.emplace_back(spirvPredecessor, *blockOperands);
1540 }
else if (
auto switchOp = dyn_cast<spirv::SwitchOp>(terminator)) {
1541 std::optional<OperandRange> blockOperands;
// `block` is either the default target or one of the case targets; the
// position among the case targets selects the operand list.
1542 if (block == switchOp.getDefaultTarget()) {
1543 blockOperands = switchOp.getDefaultOperands();
1545 SuccessorRange targets = switchOp.getTargets();
1546 auto it = llvm::find(targets, block);
1547 assert(it != targets.end());
1548 size_t index = std::distance(targets.begin(), it);
1549 blockOperands = switchOp.getTargetOperands(index);
1551 assert(!blockOperands->empty() &&
1552 "expected non-empty block operand range");
1553 predecessors.emplace_back(spirvPredecessor, *blockOperands);
1555 return terminator->emitError(
"unimplemented terminator for Phi creation");
1558 llvm::dbgs() <<
" block arguments:\n";
1559 for (Value v : predecessors.back().second)
1560 llvm::dbgs() <<
" " << v <<
"\n";
// Emit one phi per block argument.
1565 for (
auto argIndex : llvm::seq<unsigned>(0, block->
getNumArguments())) {
1569 uint32_t phiTypeID = 0;
1572 uint32_t phiID = getNextID();
1574 LLVM_DEBUG(llvm::dbgs() <<
"[phi] for block argument #" << argIndex <<
' '
1575 << arg <<
" (id = " << phiID <<
")\n");
// Operand layout: result type, result <id>, then (value, parent block) pairs.
1578 SmallVector<uint32_t, 8> phiArgs;
1579 phiArgs.push_back(phiTypeID);
1580 phiArgs.push_back(phiID);
1582 for (
auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1583 Value value = predecessors[predIndex].second[argIndex];
1584 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1585 LLVM_DEBUG(llvm::dbgs() <<
"[phi] use predecessor (id = " << predBlockId
1586 <<
") value " << value <<
' ');
1588 uint32_t valueId = getValueID(value);
1592 LLVM_DEBUG(llvm::dbgs() <<
"(need to fix)\n");
// The value has no <id> yet. Remember a position in functionBody for it so
// the placeholder can presumably be patched once the <id> is assigned.
1593 deferredPhiValues[value].push_back(functionBody.size() + 1 +
1596 LLVM_DEBUG(llvm::dbgs() <<
"(id = " << valueId <<
")\n");
1598 phiArgs.push_back(valueId);
1600 phiArgs.push_back(predBlockId);
// Uses of the block argument now refer to the phi's result <id>.
1604 valueIDMap[arg] = phiID;
/// Encodes an instruction of the extended instruction set `extensionSetName`
/// into `binary`.
///
/// The first time a set is used, it is assigned a fresh <id>, cached in
/// extendedInstSetIDMap, and an import for it is built (presumably
/// OpExtInstImport carrying the set name; confirm).
///
/// `operands` must start with the result type <id> and the result <id>;
/// otherwise an error is emitted on `op`. The encoded operand list is
/// [result type, result <id>, set <id>, `extensionOpcode`, remaining operands].
///
/// NOTE(review): the set import is handled before the operand-count check, so
/// an import may be emitted even when this call then fails.
1614LogicalResult Serializer::encodeExtensionInstruction(
1615 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1616 ArrayRef<uint32_t> operands, SmallVectorImpl<uint32_t> &binary) {
// `setID` is a reference into the map (0 for a set not seen before), so
// assigning it below also caches the newly allocated <id>.
1618 auto &setID = extendedInstSetIDMap[extensionSetName];
1620 setID = getNextID();
1621 SmallVector<uint32_t, 16> importOperands;
1622 importOperands.push_back(setID);
1630 if (operands.size() < 2) {
1631 return op->
emitError(
"extended instructions must have a result encoding");
// Splice the set <id> and the extended opcode in after the result type and
// result <id>.
1633 SmallVector<uint32_t, 8> extInstOperands;
1634 extInstOperands.reserve(operands.size() + 2);
1635 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1636 extInstOperands.push_back(setID);
1637 extInstOperands.push_back(extensionOpcode);
1638 extInstOperands.append(std::next(operands.begin(), 2), operands.end());
/// Convenience overload that encodes the extended instruction into
/// functionBody. For instructions of the TOSA extended set (`extTosa`), it
/// also passes `op` to updateTosaOpsMap, which only records anything when
/// debug info emission is enabled.
1643LogicalResult Serializer::encodeExtensionInstruction(
1644 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1645 ArrayRef<uint32_t> operands) {
1646 if (
failed(encodeExtensionInstruction(op, extensionSetName, extensionOpcode,
1647 operands, functionBody)))
1650 if (extensionSetName ==
extTosa)
1651 updateTosaOpsMap(op);
/// Serializes a single operation by dispatching on its concrete op type to
/// the matching hand-written process* handler. Ops that are not matched by any
/// case below are forwarded to dispatchToAutogenSerialization.
1656LogicalResult Serializer::processOperation(Operation *opInst) {
1657 LLVM_DEBUG(llvm::dbgs() <<
"[op] '" << opInst->
getName() <<
"'\n");
1662 .Case([&](spirv::AddressOfOp op) {
return processAddressOfOp(op); })
1663 .Case([&](spirv::BranchOp op) {
return processBranchOp(op); })
1664 .Case([&](spirv::BranchConditionalOp op) {
1665 return processBranchConditionalOp(op);
1667 .Case([&](spirv::ConstantOp op) {
return processConstantOp(op); })
1668 .Case([&](spirv::CompositeConstructOp op) {
1669 return processCompositeConstructOp(op);
1671 .Case([&](spirv::EXTConstantCompositeReplicateOp op) {
1672 return processConstantCompositeReplicateOp(op);
1674 .Case([&](spirv::FuncOp op) {
return processFuncOp(op); })
1675 .Case([&](spirv::GraphARMOp op) {
return processGraphARMOp(op); })
1676 .Case([&](spirv::GraphEntryPointARMOp op) {
1677 return processGraphEntryPointARMOp(op);
1679 .Case([&](spirv::GraphOutputsARMOp op) {
1680 return processGraphOutputsARMOp(op);
1682 .Case([&](spirv::GlobalVariableOp op) {
1683 return processGlobalVariableOp(op);
1685 .Case([&](spirv::GraphConstantARMOp op) {
1686 return processGraphConstantARMOp(op);
1688 .Case([&](spirv::LoopOp op) {
return processLoopOp(op); })
1689 .Case([&](spirv::ReferenceOfOp op) {
return processReferenceOfOp(op); })
1690 .Case([&](spirv::SelectionOp op) {
return processSelectionOp(op); })
1691 .Case([&](spirv::SpecConstantOp op) {
return processSpecConstantOp(op); })
1692 .Case([&](spirv::SpecConstantCompositeOp op) {
1693 return processSpecConstantCompositeOp(op);
1695 .Case([&](spirv::EXTSpecConstantCompositeReplicateOp op) {
1696 return processSpecConstantCompositeReplicateOp(op);
1698 .Case([&](spirv::SpecConstantOperationOp op) {
1699 return processSpecConstantOperationOp(op);
1701 .Case([&](spirv::SwitchOp op) {
return processSwitchOp(op); })
1702 .Case([&](spirv::UndefOp op) {
return processUndefOp(op); })
1703 .Case([&](spirv::VariableOp op) {
return processVariableOp(op); })
// Fallback: every other op goes through the auto-generated serializer.
1708 [&](Operation *op) {
return dispatchToAutogenSerialization(op); });
/// Serializes spirv.CompositeConstruct as OpCompositeConstruct with the
/// operand list [result type <id>, result <id>, constituent <id>s...].
///
/// Every constituent must already have an <id>; this is asserted. The
/// instruction is emitted through encodeInstructionWithContinuationInto,
/// presumably so that an operand list exceeding the SPIR-V word-count limit is
/// split into continuation instructions (SPV_INTEL_long_composites); confirm.
/// Attributes on the op are then processed as decorations on the result <id>.
1712Serializer::processCompositeConstructOp(spirv::CompositeConstructOp op) {
1713 Location loc = op.getLoc();
1715 uint32_t resultTypeID = 0;
1716 if (
failed(processType(loc, op.getType(), resultTypeID)))
1719 uint32_t resultID = getNextID();
1720 valueIDMap[op.getResult()] = resultID;
1722 SmallVector<uint32_t, 8> operands;
1723 operands.reserve(2 + op.getConstituents().size());
1724 operands.push_back(resultTypeID);
1725 operands.push_back(resultID);
1726 for (Value constituent : op.getConstituents()) {
1727 uint32_t
id = getValueID(constituent);
1728 assert(
id &&
"use before def!");
1729 operands.push_back(
id);
1732 if (
failed(emitDebugLine(functionBody, loc)))
1735 encodeInstructionWithContinuationInto(
1736 functionBody, spirv::Opcode::OpCompositeConstruct, operands);
// Serialize the op's attributes as decorations on the result <id>.
1738 for (
auto attr : op->getAttrs()) {
1739 if (
failed(processDecoration(loc, resultID, attr)))
/// Generic serializer for an op with no special grammar attributes.
///
/// The operand list is [result type <id>, result <id>] followed by the <id>
/// of each operand. The result pair is presumably present only when the op
/// has a result; confirm. The first result is mapped to a freshly allocated
/// result <id>.
///
/// With an empty `extInstSet`, the op is encoded as a core instruction
/// (presumably with `opcode` directly). Otherwise it is encoded as an
/// instruction of that extended set via encodeExtensionInstruction.
/// Attributes are then processed as decorations on the result <id>.
1746LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1747 StringRef extInstSet,
1749 SmallVector<uint32_t, 4> operands;
1750 Location loc = op->
getLoc();
1752 uint32_t resultID = 0;
1754 uint32_t resultTypeID = 0;
1757 operands.push_back(resultTypeID);
1759 resultID = getNextID();
1760 operands.push_back(resultID);
1761 valueIDMap[op->
getResult(0)] = resultID;
1765 operands.push_back(getValueID(operand));
1769 if (
failed(emitDebugLine(functionBody, loc)))
1772 if (extInstSet.empty()) {
1776 if (
failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1782 if (
failed(processDecoration(loc, resultID, attr)))
/// Records `op` in tosaOpsMap, keyed first by the <id> of its parent
/// GraphARMOp and then by the op's location.
///
/// Nothing is recorded in these cases:
///   - debug info emission is disabled;
///   - the parent is not a GraphARMOp;
///   - the parent graph has no <id> yet (getFunctionID yields 0).
1790void Serializer::updateTosaOpsMap(Operation *op) {
1791 if (!options.emitDebugInfo)
1794 if (
auto graphOp = dyn_cast<spirv::GraphARMOp>(op->
getParentOp())) {
1795 if (uint32_t graphID = getFunctionID(graphOp.getName()))
1796 tosaOpsMap[graphID][op->
getLoc()].insert(op);
/// Appends a decoration instruction (presumably OpDecorate) for `target` to
/// the decorations section.
///
/// The words are: the word-count-prefixed opcode, `target`, the
/// `decoration` enum value, then the literal `params`. That gives
/// wordCount = 3 + params.size().
1800LogicalResult Serializer::emitDecoration(uint32_t
target,
1801 spirv::Decoration decoration,
1802 ArrayRef<uint32_t> params) {
1803 uint32_t wordCount = 3 + params.size();
1804 llvm::append_values(
1807 static_cast<uint32_t
>(decoration));
1808 llvm::append_range(decorations, params);
/// Like emitDecoration, but the trailing operands are <id>s rather than
/// literals (presumably encoded as OpDecorateId).
///
/// The words are: the word-count-prefixed opcode, `target`, the
/// `decoration` enum value, then `operandIds`. That gives
/// wordCount = 3 + operandIds.size().
1812LogicalResult Serializer::emitDecorationId(uint32_t
target,
1813 spirv::Decoration decoration,
1814 ArrayRef<uint32_t> operandIds) {
1815 uint32_t wordCount = 3 + operandIds.size();
1816 llvm::append_values(
1819 static_cast<uint32_t
>(decoration));
1820 llvm::append_range(decorations, operandIds);
/// Emits a line-info instruction (presumably OpLine) for `loc` into `binary`
/// when debug info emission is enabled.
///
/// The operands are fileID, line and column. Only FileLineColLoc locations
/// carry this information.
///
/// When the previously processed instruction was a merge instruction, the
/// flag is reset and no line info appears to be emitted. Presumably this
/// keeps anything from being inserted between a merge instruction and the
/// following terminator; confirm.
1824LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1826 if (!options.emitDebugInfo)
1829 if (lastProcessedWasMergeInst) {
1830 lastProcessedWasMergeInst =
false;
1834 auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1837 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
static bool isZeroValue(Attribute attr)
static void moveFuncDeclarationsToTop(spirv::ModuleOp moduleOp)
Move all functions declaration before functions definitions.
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Return the context this attribute belongs to.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< pred_iterator > getPredecessors()
OpListType & getOperations()
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
llvm::iplist< Operation > OpListType
This is the list of operations in the block.
bool getValue() const
Return the boolean value of this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
unsigned getArrayStride() const
Returns the array stride in bytes.
static ArrayType get(Type elementType, unsigned elementCount)
Type getElementType(unsigned) const
unsigned getArrayStride() const
Returns the array stride in bytes.
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
LogicalResult serialize()
Serializes the remembered SPIR-V module.
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static Type getValueType(Attribute attr)
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::optional< spirv::Opcode > getContinuationOpcode(spirv::Opcode parent)
Returns the SPV_INTEL_long_composites continuation opcode that may follow parent, or std::nullopt if ...
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
constexpr uint32_t kMaxWordCount
Max number of words https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_universal_limits.
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPIR-V module header to `header` with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
constexpr llvm::StringLiteral extTosa
Extension set name for TOSA ops.
static LogicalResult processDecorationList(Location loc, Decoration decoration, Attribute attrList, StringRef attrName, EmitF emitter)
static std::string getDecorationName(StringRef attrName)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
llvm::TypeSwitch< T, ResultT > TypeSwitch
llvm::function_ref< Fn > function_ref
Attribute decorationValue