21#include "llvm/ADT/STLExtras.h"
22#include "llvm/ADT/Sequence.h"
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/ADT/bit.h"
26#include "llvm/Support/Debug.h"
30#define DEBUG_TYPE "spirv-serialization"
37 if (
auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38 return selectionOp.getMergeBlock();
39 if (
auto loopOp = dyn_cast<spirv::LoopOp>(op))
40 return loopOp.getMergeBlock();
51 if (
auto loopOp = dyn_cast<spirv::LoopOp>(block->
getParentOp())) {
55 while ((op = op->getPrevNode()) !=
nullptr)
74 if (
auto floatAttr = dyn_cast<FloatAttr>(attr)) {
75 return floatAttr.getValue().isZero();
77 if (
auto boolAttr = dyn_cast<BoolAttr>(attr)) {
78 return !boolAttr.getValue();
80 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
81 return intAttr.getValue().isZero();
83 if (
auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) {
86 if (
auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
102 for (
Operation &op : llvm::drop_begin(ops))
103 if (
auto funcOp = dyn_cast<spirv::FuncOp>(op))
104 if (funcOp.getBody().empty())
115 uint32_t wordCount = 1 + operands.size();
117 binary.append(operands.begin(), operands.end());
122 : module(module), mlirBuilder(module.
getContext()), options(options) {}
125 LLVM_DEBUG(llvm::dbgs() <<
"+++ starting serialization +++\n");
127 if (failed(module.verifyInvariants()))
132 if (failed(processExtension())) {
135 processMemoryModel();
142 for (
auto &op : *module.getBody()) {
143 if (failed(processOperation(&op))) {
148 LLVM_DEBUG(llvm::dbgs() <<
"+++ completed serialization +++\n");
154 extensions.size() + extendedSets.size() +
155 memoryModel.size() + entryPoints.size() +
156 executionModes.size() + decorations.size() +
157 typesGlobalValues.size() + functions.size() + graphs.size();
160 binary.reserve(moduleSize);
164 binary.append(capabilities.begin(), capabilities.end());
165 binary.append(extensions.begin(), extensions.end());
166 binary.append(extendedSets.begin(), extendedSets.end());
167 binary.append(memoryModel.begin(), memoryModel.end());
168 binary.append(entryPoints.begin(), entryPoints.end());
169 binary.append(executionModes.begin(), executionModes.end());
170 binary.append(debug.begin(), debug.end());
171 binary.append(names.begin(), names.end());
172 binary.append(decorations.begin(), decorations.end());
173 binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
174 binary.append(functions.begin(), functions.end());
175 binary.append(graphs.begin(), graphs.end());
180 os <<
"\n= Value <id> Map =\n\n";
181 for (
auto valueIDPair : valueIDMap) {
182 Value val = valueIDPair.first;
183 os <<
" " << val <<
" "
184 <<
"id = " << valueIDPair.second <<
' ';
186 os <<
"from op '" << op->getName() <<
"'";
187 }
else if (
auto arg = dyn_cast<BlockArgument>(val)) {
188 Block *block = arg.getOwner();
189 os <<
"from argument of block " << block <<
' ';
/// Looks up the <id> recorded for the function named `fnName` in `funcIDMap`.
/// The visible fallback path takes a fresh <id> from getNextID() and stores it
/// in the map under `fnName`.
/// NOTE(review): the condition guarding the fallback is not visible here —
/// presumably it only runs when the lookup produced no <id>; confirm.
201uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
202 auto funcID = funcIDMap.lookup(fnName);
204 funcID = getNextID();
205 funcIDMap[fnName] = funcID;
/// Walks every capability listed in the module's VCE triple; each one is cast
/// to uint32_t and passed as the single operand of an instruction.
/// NOTE(review): the call that consumes the operand list is not visible here —
/// presumably an OpCapability encode; confirm.
210void Serializer::processCapability() {
211 for (
auto cap : module.getVceTriple()->getCapabilities())
213 {
static_cast<uint32_t
>(cap)});
/// Adds the LongCompositesINTEL capability and the SPV_INTEL_long_composites
/// extension name when the module's VCE triple does not already list them.
/// `longCompositesEmitted` is tested on entry and then set, so the work below
/// is tied to that flag.
216void Serializer::addLongCompositesCapability() {
217 if (longCompositesEmitted)
219 longCompositesEmitted =
true;
220 auto vceTriple =
module.getVceTriple();
// Capability: only handled here when the VCE triple does not contain it; the
// operands target the `capabilities` section with OpCapability.
221 if (!llvm::is_contained(vceTriple->getCapabilities(),
222 spirv::Capability::LongCompositesINTEL))
224 capabilities, spirv::Opcode::OpCapability,
225 {
static_cast<uint32_t
>(spirv::Capability::LongCompositesINTEL)});
// Extension: likewise only handled when the VCE triple does not contain it;
// the extension's string form is gathered into `extName`.
226 if (!llvm::is_contained(vceTriple->getExtensions(),
227 spirv::Extension::SPV_INTEL_long_composites)) {
228 SmallVector<uint32_t, 8> extName;
231 spirv::stringifyExtension(spirv::Extension::SPV_INTEL_long_composites));
/// Encodes instruction `op` with `operands` into `binary`, with support for
/// operand lists that have to be split across continuation instructions.
///
/// \param binary   word buffer that receives the encoding.
/// \param op       opcode to encode; asserted below to have a continuation
///                 opcode whenever the splitting path is reached.
/// \param operands operand words of the instruction.
236void Serializer::encodeInstructionWithContinuationInto(
237 SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
238 ArrayRef<uint32_t> operands) {
244 std::optional<spirv::Opcode> continuationOp =
246 assert(continuationOp &&
"op is not a splittable composite/struct opcode");
// Walk the operands that follow the first `chunk` words, advancing by at most
// `chunk` words per iteration until none remain.
250 for (ArrayRef<uint32_t> rest = operands.drop_front(chunk); !rest.empty();
251 rest = rest.drop_front(std::min<size_t>(rest.size(), chunk))) {
// Splitting relies on the long-composites capability/extension being declared.
255 addLongCompositesCapability();
/// Starts the debug-info prologue. The work is gated on
/// `options.emitDebugInfo`. The source file name comes from the module's
/// location when that is a FileLineColLoc and is "<unknown>" otherwise; a fresh
/// <id> is stored in `fileID` and becomes the first operand word.
258void Serializer::processDebugInfo() {
259 if (!options.emitDebugInfo)
261 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
262 auto fileName = fileLoc ? fileLoc.getFilename().strref() :
"<unknown>";
263 fileID = getNextID();
264 SmallVector<uint32_t, 16> operands;
265 operands.push_back(fileID);
/// Processes the extensions to declare for the module. The set starts from the
/// extensions in the module's VCE triple. When debug info is requested and
/// SPV_KHR_non_semantic_info is not already in the set, it is added — unless
/// the target environment does not list it, in which case a module error is
/// returned.
271LogicalResult Serializer::processExtension() {
272 llvm::SmallVector<uint32_t, 16> extName;
273 llvm::SmallSet<Extension, 4> deducedExts(
274 llvm::from_range, module.getVceTriple()->getExtensions());
275 auto nonSemanticInfoExt = spirv::Extension::SPV_KHR_non_semantic_info;
276 if (options.emitDebugInfo && !deducedExts.contains(nonSemanticInfoExt)) {
278 if (!is_contained(targetEnvAttr.getExtensions(), nonSemanticInfoExt))
279 return module.emitError(
280 "SPV_KHR_non_semantic_info extension not available");
281 deducedExts.insert(nonSemanticInfoExt);
// Every extension in the final set is visited in turn.
283 for (spirv::Extension ext : deducedExts) {
/// Reads the module's memory-model and addressing-model attributes (looked up
/// by the attribute names the module op reports) and converts each to a
/// uint32_t word (`mm` and `am`).
/// NOTE(review): the instruction that consumes `mm`/`am` is not visible here.
291void Serializer::processMemoryModel() {
292 StringAttr memoryModelName =
module.getMemoryModelAttrName();
293 auto mm =
static_cast<uint32_t
>(
294 module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
297 StringAttr addressingModelName =
module.getAddressingModelAttrName();
298 auto am =
static_cast<uint32_t
>(
299 module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
308 if (attrName ==
"fp_fast_math_mode")
309 return "FPFastMathMode";
311 if (attrName ==
"fp_rounding_mode")
312 return "FPRoundingMode";
314 if (attrName ==
"cache_control_load_intel")
315 return "CacheControlLoadINTEL";
316 if (attrName ==
"cache_control_store_intel")
317 return "CacheControlStoreINTEL";
319 return llvm::convertToCamelFromSnakeCase(attrName,
true);
/// Helper for decorations whose value is an array of `AttrTy` attributes.
/// `attrList` must be a non-empty ArrayAttr whose elements are all `AttrTy`;
/// each violation is reported through emitError at `loc`, naming `attrName`
/// and the decoration. `emitter` is invoked on every element and its failure
/// is checked.
322template <
typename AttrTy,
typename EmitF>
325 StringRef attrName, EmitF emitter) {
326 auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
328 return emitError(loc,
"expecting array attribute of ")
329 << attrName <<
" for " << stringifyDecoration(decoration);
331 if (arrayAttr.empty()) {
332 return emitError(loc,
"expecting non-empty array attribute of ")
333 << attrName <<
" for " << stringifyDecoration(decoration);
335 for (
Attribute attr : arrayAttr.getValue()) {
// Despite the local's name, this is a cast to the generic `AttrTy`.
336 auto cacheControlAttr = dyn_cast<AttrTy>(attr);
337 if (!cacheControlAttr) {
338 return emitError(loc,
"expecting array attribute of ")
339 << attrName <<
" for " << stringifyDecoration(decoration);
343 if (failed(emitter(cacheControlAttr)))
/// Emits decoration `decoration` on `resultID`, deriving the operand words
/// from `attr` according to the decoration kind. An attribute of the wrong
/// kind is reported through emitError at `loc`. The visible tail either
/// reports "unhandled decoration" or forwards the collected `args` to
/// emitDecoration.
349LogicalResult Serializer::processDecorationAttr(
Location loc, uint32_t resultID,
350 Decoration decoration,
353 switch (decoration) {
354 case spirv::Decoration::LinkageAttributes: {
// NOTE(review): the dyn_cast result is dereferenced on the next line and no
// null check is visible in between — confirm `attr` is guaranteed to be a
// LinkageAttributesAttr here (or use cast<>).
357 auto linkageAttr = dyn_cast<spirv::LinkageAttributesAttr>(attr);
358 auto linkageName = linkageAttr.getLinkageName();
359 auto linkageType = linkageAttr.getLinkageType().getValue();
363 args.push_back(
static_cast<uint32_t
>(linkageType));
// Enum-valued decorations: the enum value becomes the single operand word.
366 case spirv::Decoration::FPFastMathMode:
367 if (
auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
368 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
371 return emitError(loc,
"expected FPFastMathModeAttr attribute for ")
372 << stringifyDecoration(decoration);
373 case spirv::Decoration::FPRoundingMode:
374 if (
auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
375 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
378 return emitError(loc,
"expected FPRoundingModeAttr attribute for ")
379 << stringifyDecoration(decoration);
// Integer-valued decorations: the zero-extended integer is the operand word.
380 case spirv::Decoration::Binding:
381 case spirv::Decoration::DescriptorSet:
382 case spirv::Decoration::Location:
383 case spirv::Decoration::Index:
384 case spirv::Decoration::Offset:
385 case spirv::Decoration::XfbBuffer:
386 case spirv::Decoration::XfbStride:
387 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
388 args.push_back(intAttr.getValue().getZExtValue());
391 return emitError(loc,
"expected integer attribute for ")
392 << stringifyDecoration(decoration);
// BuiltIn: the string attribute is symbolized into the BuiltIn enum.
393 case spirv::Decoration::BuiltIn:
394 if (
auto strAttr = dyn_cast<StringAttr>(attr)) {
395 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
397 args.push_back(
static_cast<uint32_t
>(*enumVal));
401 << stringifyDecoration(decoration) <<
" decoration attribute "
402 << strAttr.getValue();
404 return emitError(loc,
"expected string attribute for ")
405 << stringifyDecoration(decoration);
// Decorations that accept a UnitAttr or DecorationAttr as their value.
406 case spirv::Decoration::Aliased:
407 case spirv::Decoration::AliasedPointer:
408 case spirv::Decoration::Flat:
409 case spirv::Decoration::NonReadable:
410 case spirv::Decoration::NonWritable:
411 case spirv::Decoration::NoPerspective:
412 case spirv::Decoration::NoSignedWrap:
413 case spirv::Decoration::NoUnsignedWrap:
414 case spirv::Decoration::RelaxedPrecision:
415 case spirv::Decoration::Restrict:
416 case spirv::Decoration::RestrictPointer:
417 case spirv::Decoration::NoContraction:
418 case spirv::Decoration::Constant:
419 case spirv::Decoration::Block:
420 case spirv::Decoration::Invariant:
421 case spirv::Decoration::Patch:
422 case spirv::Decoration::Coherent:
425 if (isa<UnitAttr, DecorationAttr>(attr))
428 "expected unit attribute or decoration attribute for ")
429 << stringifyDecoration(decoration);
// Cache-control decorations: one decoration is emitted per list element, with
// the cache level and the cache-control enum as operands.
430 case spirv::Decoration::CacheControlLoadINTEL:
432 loc, decoration, attr,
"CacheControlLoadINTEL",
433 [&](CacheControlLoadINTELAttr attr) {
434 unsigned cacheLevel = attr.getCacheLevel();
435 LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
436 return emitDecoration(
437 resultID, decoration,
438 {cacheLevel,
static_cast<uint32_t
>(loadCacheControl)});
440 case spirv::Decoration::CacheControlStoreINTEL:
442 loc, decoration, attr,
"CacheControlStoreINTEL",
443 [&](CacheControlStoreINTELAttr attr) {
444 unsigned cacheLevel = attr.getCacheLevel();
445 StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
446 return emitDecoration(
447 resultID, decoration,
448 {cacheLevel,
static_cast<uint32_t
>(storeCacheControl)});
// <id>-operand decorations: the symbol reference is resolved through
// getVariableID and getSpecConstID, and the resulting <id> is emitted with
// emitDecorationId.
450 case spirv::Decoration::AlignmentId:
451 case spirv::Decoration::MaxByteOffsetId:
452 case spirv::Decoration::CounterBuffer: {
453 auto symRef = dyn_cast<FlatSymbolRefAttr>(attr);
455 return emitError(loc,
"expected symbol reference for ")
456 << stringifyDecoration(decoration);
457 StringRef symName = symRef.getValue();
458 uint32_t operandID = getVariableID(symName);
460 operandID = getSpecConstID(symName);
462 return emitError(loc,
"could not find <id> for symbol '")
463 << symName <<
"' referenced by "
464 << stringifyDecoration(decoration);
465 return emitDecorationId(resultID, decoration, {operandID});
468 return emitError(loc,
"unhandled decoration ")
469 << stringifyDecoration(decoration);
471 return emitDecoration(resultID, decoration, args);
/// Maps the named attribute `attr` to a SPIR-V decoration and emits it on
/// `resultID` through processDecorationAttr. The decoration is obtained by
/// symbolizing `decorationName`; an unrecognized name is reported at `loc`
/// with the snake-case diagnostic below.
474LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
475 NamedAttribute attr) {
476 StringRef attrName = attr.
getName().strref();
478 std::optional<Decoration> decoration =
479 spirv::symbolizeDecoration(decorationName);
482 loc,
"non-argument attributes expected to have snake-case-ified "
483 "decoration name, unhandled attribute with name : ")
486 return processDecorationAttr(loc, resultID, *decoration, attr.
getValue());
/// Records `name` as the debug name of `resultID`. `name` must be non-empty
/// (asserted). The emission is gated on `options.emitSymbolName`; the operand
/// list starts with `resultID`.
489LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
490 assert(!name.empty() &&
"unexpected empty string for OpName");
491 if (!options.emitSymbolName)
494 SmallVector<uint32_t, 4> nameOperands;
495 nameOperands.push_back(resultID);
502LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
// Array types are decorated with ArrayStride, using `stride` as the operand.
506 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
512LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
// Runtime arrays are decorated with ArrayStride, using `stride` as the operand.
516 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
/// Handles one struct member decoration. In the visible part, the decoration
/// enum stored in `memberDecoration` is cast to uint32_t and used as an
/// instruction operand.
521LogicalResult Serializer::processMemberDecoration(
526 static_cast<uint32_t
>(memberDecoration.
decoration)});
/// Tests whether `type` is a pointer in the PhysicalStorageBuffer,
/// PushConstant, StorageBuffer or Uniform storage class whose pointee is a
/// spirv::StructType.
542bool Serializer::isInterfaceStructPtrType(Type type)
const {
543 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
544 switch (ptrType.getStorageClass()) {
545 case spirv::StorageClass::PhysicalStorageBuffer:
546 case spirv::StorageClass::PushConstant:
547 case spirv::StorageClass::StorageBuffer:
548 case spirv::StorageClass::Uniform:
549 return isa<spirv::StructType>(ptrType.getPointeeType());
/// Entry point for serializing `type`: forwards to processTypeImpl together
/// with a `serializationCtx`, and reports the type's <id> through `typeID`.
557LogicalResult Serializer::processType(Location loc, Type type,
562 return processTypeImpl(loc, type, typeID, serializationCtx);
566Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
578 IntegerType::SignednessSemantics::Signless);
// An <id> already recorded for `type` is read back first; otherwise a fresh
// <id> is drawn and becomes the first operand word of the type instruction.
581 typeID = getTypeID(type);
585 typeID = getNextID();
586 SmallVector<uint32_t, 4> operands;
588 operands.push_back(typeID);
589 auto typeEnum = spirv::Opcode::OpTypeVoid;
590 bool deferSerialization =
false;
// Function and graph types have dedicated preparers; every other type goes
// through prepareBasicType, which may also request deferred serialization.
592 if ((isa<FunctionType>(type) &&
593 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
595 (isa<GraphType>(type) &&
597 prepareGraphType(loc, cast<GraphType>(type), typeEnum, operands))) ||
598 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
599 deferSerialization, serializationCtx))) {
600 if (deferSerialization)
603 typeIDMap[type] = typeID;
// Struct types are encoded through the continuation-aware encoder.
605 if (typeEnum == spirv::Opcode::OpTypeStruct)
606 encodeInstructionWithContinuationInto(typesGlobalValues, typeEnum,
// Pointers recorded in `recursiveStructInfos` for this type are completed now
// that the type has an <id>; the pending list is cleared afterwards.
611 if (recursiveStructInfos.count(type) != 0) {
614 for (
auto &ptrInfo : recursiveStructInfos[type]) {
617 SmallVector<uint32_t, 4> ptrOperands;
618 ptrOperands.push_back(ptrInfo.pointerTypeID);
619 ptrOperands.push_back(
static_cast<uint32_t
>(ptrInfo.storageClass));
620 ptrOperands.push_back(typeIDMap[type]);
626 recursiveStructInfos[type].clear();
632 return emitError(loc,
"failed to process type: ") << type;
/// Prepares the opcode and operand words for every type kind other than
/// function and graph types. On success `typeEnum` holds the OpType* opcode
/// and `operands` has the type-specific operand words appended.
/// `deferSerialization` is reset to false on entry and set to true only on the
/// recursive-struct pointer path. Types that match no branch end in the
/// "unhandled type in serialization" error.
635LogicalResult Serializer::prepareBasicType(
636 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
637 SmallVectorImpl<uint32_t> &operands,
bool &deferSerialization,
639 deferSerialization =
false;
641 if (isVoidType(type)) {
642 typeEnum = spirv::Opcode::OpTypeVoid;
// Integers: width 1 maps to OpTypeBool; otherwise OpTypeInt with the width
// and a signedness word (1 when the type is signed, else 0).
646 if (
auto intType = dyn_cast<IntegerType>(type)) {
647 if (intType.getWidth() == 1) {
648 typeEnum = spirv::Opcode::OpTypeBool;
652 typeEnum = spirv::Opcode::OpTypeInt;
653 operands.push_back(intType.getWidth());
658 operands.push_back(intType.isSigned() ? 1 : 0);
// Floats: OpTypeFloat with the width; BF16 and the two FP8 formats append an
// explicit FPEncoding operand.
662 if (
auto floatType = dyn_cast<FloatType>(type)) {
663 typeEnum = spirv::Opcode::OpTypeFloat;
664 operands.push_back(floatType.getWidth());
665 if (floatType.isBF16()) {
666 operands.push_back(
static_cast<uint32_t
>(spirv::FPEncoding::BFloat16KHR));
668 if (floatType.isF8E4M3FN()) {
670 static_cast<uint32_t
>(spirv::FPEncoding::Float8E4M3EXT));
672 if (floatType.isF8E5M2()) {
674 static_cast<uint32_t
>(spirv::FPEncoding::Float8E5M2EXT));
// Vectors: element type <id> followed by the element count.
680 if (
auto vectorType = dyn_cast<VectorType>(type)) {
681 uint32_t elementTypeID = 0;
682 if (
failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
683 serializationCtx))) {
686 typeEnum = spirv::Opcode::OpTypeVector;
687 operands.push_back(elementTypeID);
688 operands.push_back(vectorType.getNumElements());
// Images: sampled type <id> followed by dim, depth, arrayed, sampling,
// sampler-use and format words, in that order.
692 if (
auto imageType = dyn_cast<spirv::ImageType>(type)) {
693 typeEnum = spirv::Opcode::OpTypeImage;
694 uint32_t sampledTypeID = 0;
695 if (
failed(processType(loc, imageType.getElementType(), sampledTypeID)))
698 llvm::append_values(operands, sampledTypeID,
699 static_cast<uint32_t
>(imageType.getDim()),
700 static_cast<uint32_t
>(imageType.getDepthInfo()),
701 static_cast<uint32_t
>(imageType.getArrayedInfo()),
702 static_cast<uint32_t
>(imageType.getSamplingInfo()),
703 static_cast<uint32_t
>(imageType.getSamplerUseInfo()),
704 static_cast<uint32_t
>(imageType.getImageFormat()));
// Arrays: element type <id> plus the <id> of an i32 constant holding the
// element count; type decorations are processed before returning.
708 if (
auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
709 typeEnum = spirv::Opcode::OpTypeArray;
710 uint32_t elementTypeID = 0;
711 if (
failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
712 serializationCtx))) {
715 operands.push_back(elementTypeID);
716 if (
auto elementCountID = prepareConstantInt(
717 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
718 operands.push_back(elementCountID);
720 return processTypeDecoration(loc, arrayType, resultID);
// Pointers. When the pointee is a struct whose identifier is currently in
// `serializationCtx`, a forward pointer is emitted, serialization of this
// pointer is deferred, and the pointer is recorded in `recursiveStructInfos`.
723 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
724 uint32_t pointeeTypeID = 0;
725 spirv::StructType pointeeStruct =
726 dyn_cast<spirv::StructType>(ptrType.getPointeeType());
729 serializationCtx.count(pointeeStruct.
getIdentifier()) != 0) {
734 SmallVector<uint32_t, 2> forwardPtrOperands;
735 forwardPtrOperands.push_back(resultID);
736 forwardPtrOperands.push_back(
737 static_cast<uint32_t
>(ptrType.getStorageClass()));
740 spirv::Opcode::OpTypeForwardPointer,
752 deferSerialization =
true;
756 recursiveStructInfos[structType].push_back(
757 {resultID, ptrType.getStorageClass()});
759 if (
failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
764 typeEnum = spirv::Opcode::OpTypePointer;
765 operands.push_back(
static_cast<uint32_t
>(ptrType.getStorageClass()));
766 operands.push_back(pointeeTypeID);
// Interface struct pointers get a Block decoration on the pointee struct when
// the struct does not already carry one.
771 if (isInterfaceStructPtrType(ptrType)) {
772 auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
773 if (!structType.hasDecoration(spirv::Decoration::Block))
774 if (
failed(emitDecoration(getTypeID(pointeeStruct),
775 spirv::Decoration::Block)))
776 return emitError(loc,
"cannot decorate ")
777 << pointeeStruct <<
" with Block decoration";
// Runtime arrays: element type <id> only, then type decorations.
783 if (
auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
784 uint32_t elementTypeID = 0;
785 if (
failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
786 elementTypeID, serializationCtx))) {
789 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
790 operands.push_back(elementTypeID);
791 return processTypeDecoration(loc, runtimeArrayType, resultID);
794 if (isa<spirv::SamplerType>(type)) {
795 typeEnum = spirv::Opcode::OpTypeSampler;
799 if (isa<spirv::NamedBarrierType>(type)) {
800 typeEnum = spirv::Opcode::OpTypeNamedBarrier;
804 if (
auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
805 typeEnum = spirv::Opcode::OpTypeSampledImage;
806 uint32_t imageTypeID = 0;
808 processType(loc, sampledImageType.getImageType(), imageTypeID))) {
811 operands.push_back(imageTypeID);
// Structs. An identified struct is named and its identifier is kept in
// `serializationCtx` while its members are processed (removed again below).
815 if (
auto structType = dyn_cast<spirv::StructType>(type)) {
816 if (structType.isIdentified()) {
817 if (
failed(processName(resultID, structType.getIdentifier())))
819 serializationCtx.insert(structType.getIdentifier());
822 bool hasOffset = structType.hasOffset();
823 for (
auto elementIndex :
824 llvm::seq<uint32_t>(0, structType.getNumElements())) {
825 uint32_t elementTypeID = 0;
826 if (
failed(processTypeImpl(loc, structType.getElementType(elementIndex),
827 elementTypeID, serializationCtx))) {
830 operands.push_back(elementTypeID);
// Member offsets are emitted as Offset member decorations with an i32 value.
832 auto intType = IntegerType::get(structType.getContext(), 32);
834 spirv::StructType::MemberDecorationInfo offsetDecoration{
835 elementIndex, spirv::Decoration::Offset,
836 IntegerAttr::get(intType,
837 structType.getMemberOffset(elementIndex))};
838 if (
failed(processMemberDecoration(resultID, offsetDecoration))) {
839 return emitError(loc,
"cannot decorate ")
840 << elementIndex <<
"-th member of " << structType
841 <<
" with its offset";
845 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
846 structType.getMemberDecorations(memberDecorations);
848 for (
auto &memberDecoration : memberDecorations) {
849 if (
failed(processMemberDecoration(resultID, memberDecoration))) {
850 return emitError(loc,
"cannot decorate ")
851 <<
static_cast<uint32_t
>(memberDecoration.
memberIndex)
852 <<
"-th member of " << structType <<
" with "
853 << stringifyDecoration(memberDecoration.
decoration);
857 SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations;
858 structType.getStructDecorations(structDecorations);
860 for (spirv::StructType::StructDecorationInfo &structDecoration :
862 if (
failed(processDecorationAttr(loc, resultID,
863 structDecoration.decoration,
864 structDecoration.decorationValue))) {
865 return emitError(loc,
"cannot decorate struct ")
866 << structType <<
" with "
867 << stringifyDecoration(structDecoration.decoration);
871 typeEnum = spirv::Opcode::OpTypeStruct;
873 if (structType.isIdentified())
874 serializationCtx.remove(structType.getIdentifier());
// Cooperative matrices: element type <id> followed by the <id>s of i32
// constants for scope, rows, columns and use.
879 if (
auto cooperativeMatrixType =
880 dyn_cast<spirv::CooperativeMatrixType>(type)) {
881 uint32_t elementTypeID = 0;
882 if (
failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
883 elementTypeID, serializationCtx))) {
886 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
887 auto getConstantOp = [&](uint32_t id) {
888 auto attr = IntegerAttr::get(IntegerType::get(type.
getContext(), 32),
id);
889 return prepareConstantInt(loc, attr);
892 operands, elementTypeID,
893 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getScope())),
894 getConstantOp(cooperativeMatrixType.getRows()),
895 getConstantOp(cooperativeMatrixType.getColumns()),
896 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getUse())));
// Matrices: column type <id> followed by the column count.
900 if (
auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
901 uint32_t elementTypeID = 0;
902 if (
failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
903 serializationCtx))) {
906 typeEnum = spirv::Opcode::OpTypeMatrix;
907 llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
// ARM tensors: element type <id>; for ranked tensors a rank constant, and for
// rank > 0 with all-positive dims a shape constant built from the dims.
911 if (
auto tensorArmType = dyn_cast<TensorArmType>(type)) {
912 uint32_t elementTypeID = 0;
914 uint32_t shapeID = 0;
916 if (
failed(processTypeImpl(loc, tensorArmType.getElementType(),
917 elementTypeID, serializationCtx))) {
920 if (tensorArmType.hasRank()) {
921 ArrayRef<int64_t> dims = tensorArmType.getShape();
923 rankID = prepareConstantInt(loc, mlirBuilder.getI32IntegerAttr(rank));
928 bool shaped = llvm::all_of(dims, [](
const auto &dim) {
return dim > 0; });
929 if (rank > 0 && shaped) {
930 auto I32Type = IntegerType::get(type.
getContext(), 32);
933 SmallVector<uint64_t, 1> index(rank);
934 shapeID = prepareDenseElementsConstant(
936 mlirBuilder.getI32TensorAttr(SmallVector<int32_t>(dims)), 0,
939 shapeID = prepareArrayConstant(
941 mlirBuilder.getI32ArrayAttr(SmallVector<int32_t>(dims)));
948 typeEnum = spirv::Opcode::OpTypeTensorARM;
949 operands.push_back(elementTypeID);
952 operands.push_back(rankID);
955 operands.push_back(shapeID);
960 return emitError(loc,
"unhandled type in serialization: ") << type;
964Serializer::prepareFunctionType(Location loc, FunctionType type,
965 spirv::Opcode &typeEnum,
966 SmallVectorImpl<uint32_t> &operands) {
// Prepares OpTypeFunction: the result type <id> comes first (the void type
// stands in when the function has no result), followed by one <id> per input
// type. At most one result is supported (asserted).
967 typeEnum = spirv::Opcode::OpTypeFunction;
968 assert(type.getNumResults() <= 1 &&
969 "serialization supports only a single return value");
970 uint32_t resultID = 0;
972 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
976 operands.push_back(resultID);
977 for (
auto &res : type.getInputs()) {
978 uint32_t argTypeID = 0;
979 if (
failed(processType(loc, res, argTypeID))) {
982 operands.push_back(argTypeID);
988Serializer::prepareGraphType(Location loc, GraphType type,
989 spirv::Opcode &typeEnum,
990 SmallVectorImpl<uint32_t> &operands) {
// Prepares OpTypeGraphARM: the input count, then one <id> per input type,
// then one <id> per result type. At least one result is required (asserted).
991 typeEnum = spirv::Opcode::OpTypeGraphARM;
992 assert(type.getNumResults() >= 1 &&
993 "serialization requires at least a return value");
995 operands.push_back(type.getNumInputs());
997 for (Type argType : type.getInputs()) {
998 uint32_t argTypeID = 0;
999 if (
failed(processType(loc, argType, argTypeID)))
1001 operands.push_back(argTypeID);
1004 for (Type resType : type.getResults()) {
1005 uint32_t resTypeID = 0;
1006 if (
failed(processType(loc, resType, resTypeID)))
1008 operands.push_back(resTypeID);
/// Prepares a constant of type `constType` with value `valueAttr`. Scalars are
/// tried first, then the cache of already-emitted constants. Otherwise the
/// type is processed and the value is built from a DenseElementsAttr or an
/// ArrayAttr. A result <id> of 0 is treated as failure and reported at `loc`;
/// a successful <id> is cached in `constIDMap`.
1018uint32_t Serializer::prepareConstant(Location loc, Type constType,
1019 Attribute valueAttr) {
1020 if (
auto id = prepareConstantScalar(loc, valueAttr)) {
1027 if (
auto id = getConstantID(valueAttr)) {
1031 uint32_t typeID = 0;
1032 if (
failed(processType(loc, constType, typeID))) {
1036 uint32_t resultID = 0;
1037 if (
auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
// NOTE(review): the dyn_cast result is used directly with no null check —
// presumably the attribute's type is always a ShapedType; cast<> would state
// that intent.
1038 int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
1039 SmallVector<uint64_t, 4> index(rank);
1040 resultID = prepareDenseElementsConstant(loc, constType, attr,
1042 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1043 resultID = prepareArrayConstant(loc, constType, arrayAttr);
1046 if (resultID == 0) {
1047 emitError(loc,
"cannot serialize attribute: ") << valueAttr;
1051 constIDMap[valueAttr] = resultID;
/// Prepares an OpConstantComposite for an array-typed constant. `constType`
/// is cast to spirv::ArrayType to obtain the element type. Each element of
/// `attr` is prepared as a constant of that type and its <id> is appended
/// after the type and result <id>s. The instruction is encoded through the
/// continuation-aware encoder into `typesGlobalValues`.
1055uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
1057 uint32_t typeID = 0;
1058 if (
failed(processType(loc, constType, typeID))) {
1062 uint32_t resultID = getNextID();
1063 SmallVector<uint32_t, 4> operands = {typeID, resultID};
1064 operands.reserve(attr.size() + 2);
1065 auto elementType = cast<spirv::ArrayType>(constType).getElementType();
1066 for (Attribute elementAttr : attr) {
1067 if (
auto elementID = prepareConstant(loc, elementType, elementAttr)) {
1068 operands.push_back(elementID);
1073 encodeInstructionWithContinuationInto(
1074 typesGlobalValues, spirv::Opcode::OpConstantComposite, operands);
1082Serializer::prepareDenseElementsConstant(Location loc, Type constType,
1083 DenseElementsAttr valueAttr,
int dim,
1084 MutableArrayRef<uint64_t> index) {
// Recursive builder for dense-elements constants: `dim` is the dimension
// being expanded and `index` addresses the current element.
// NOTE(review): `shapedType` comes from a dyn_cast and is used without a
// visible null check — presumably always a ShapedType; confirm.
1085 auto shapedType = dyn_cast<ShapedType>(valueAttr.
getType());
1086 assert(dim <= shapedType.getRank());
// Base case (dim == rank): emit the scalar at `index` — bool for i1 element
// types, int for other integer element types, float for FP elements.
1087 if (shapedType.getRank() == dim) {
1088 if (
auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
1089 return attr.getType().getElementType().isInteger(1)
1090 ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
1091 : prepareConstantInt(loc,
1092 attr.getValues<IntegerAttr>()[index]);
1094 if (
auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
1095 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
1100 uint32_t typeID = 0;
1101 if (
failed(processType(loc, constType, typeID))) {
1105 int64_t numberOfConstituents = shapedType.getDimSize(dim);
1106 uint32_t resultID = getNextID();
1107 SmallVector<uint32_t, 4> operands = {typeID, resultID};
1108 auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
// ARM tensors: the inner shape is the tensor shape without its leading dim.
1109 if (
auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
1110 ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front();
1111 if (!innerShape.empty())
// Cooperative matrices only accept splat values (see the diagnostic below).
1119 if (isa<spirv::CooperativeMatrixType>(constType)) {
1123 "cannot serialize a non-splat value for a cooperative matrix type");
1128 operands.reserve(3);
1131 if (
auto elementID = prepareDenseElementsConstant(
1132 loc, elementType, valueAttr, shapedType.getRank(), index)) {
1133 operands.push_back(elementID);
1137 }
else if (isa<spirv::TensorArmType>(constType) &&
isZeroValue(valueAttr)) {
1139 {typeID, resultID});
// General case: recurse into dimension dim + 1 once per constituent and
// collect the resulting <id>s as OpConstantComposite operands.
1142 operands.reserve(numberOfConstituents + 2);
1143 for (
int i = 0; i < numberOfConstituents; ++i) {
1145 if (
auto elementID = prepareDenseElementsConstant(
1146 loc, elementType, valueAttr, dim + 1, index)) {
1147 operands.push_back(elementID);
1153 encodeInstructionWithContinuationInto(
1154 typesGlobalValues, spirv::Opcode::OpConstantComposite, operands);
/// Dispatches a scalar constant on the attribute kind: FloatAttr goes to
/// prepareConstantFp, BoolAttr to prepareConstantBool, and IntegerAttr to
/// prepareConstantInt. `isSpec` is forwarded unchanged. BoolAttr is tested
/// before IntegerAttr.
1159uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
1161 if (
auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
1162 return prepareConstantFp(loc, floatAttr, isSpec);
1164 if (
auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
1165 return prepareConstantBool(loc, boolAttr, isSpec);
1167 if (
auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
1168 return prepareConstantInt(loc, intAttr, isSpec);
/// Prepares a boolean constant. A previously cached <id> for `boolAttr` is
/// consulted first. Otherwise the attribute's type is processed, a fresh <id>
/// is drawn, and the opcode is chosen from the four
/// Op[Spec]Constant{True,False} variants. The new <id> is cached in
/// `constIDMap`.
1174uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
1178 if (
auto id = getConstantID(boolAttr)) {
1184 uint32_t typeID = 0;
// The BoolAttr is viewed as an IntegerAttr to reach its type.
1185 if (
failed(processType(loc, cast<IntegerAttr>(boolAttr).
getType(), typeID))) {
1189 auto resultID = getNextID();
1191 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
1192 : spirv::Opcode::OpConstantTrue)
1193 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
1194 : spirv::Opcode::OpConstantFalse);
1198 constIDMap[boolAttr] = resultID;
/// Prepares an integer constant (OpSpecConstant when `isSpec`, otherwise
/// OpConstant). A previously cached <id> for `intAttr` is consulted first.
/// The literal is encoded either as one word or as two words
/// (`word1`, `word2`), using sign or zero extension. Unsupported widths are
/// reported with the "-bit integer literal" diagnostic. The new <id> is cached
/// in `constIDMap`.
1203uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
1207 if (
auto id = getConstantID(intAttr)) {
1213 uint32_t typeID = 0;
1214 if (
failed(processType(loc, intAttr.getType(), typeID))) {
1218 auto resultID = getNextID();
1219 APInt value = intAttr.getValue();
1220 unsigned bitwidth = value.getBitWidth();
1221 bool isSigned = intAttr.getType().isSignedInteger();
1223 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
// Single-word literal: sign- or zero-extended.
1236 word =
static_cast<int32_t
>(value.getSExtValue());
1238 word =
static_cast<uint32_t
>(value.getZExtValue());
// Two-word literal: the 64-bit extension is bit_cast into a DoubleWord.
1250 words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
1252 words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
1255 {typeID, resultID, words.word1, words.word2});
// Error path: render the value as unsigned text for the diagnostic.
1258 std::string valueStr;
1259 llvm::raw_string_ostream rss(valueStr);
1260 value.print(rss,
false);
1263 << bitwidth <<
"-bit integer literal: " << valueStr;
1269 constIDMap[intAttr] = resultID;
/// Prepares a graph constant of type `graphConstType` identified by
/// `intAttr`. A previously cached <id> is consulted first. Values wider than
/// 32 bits are rejected with an error at `loc`. The value is otherwise encoded
/// as a single word after the type and result <id>s, and the new <id> is
/// cached in `graphConstIDMap`.
1274uint32_t Serializer::prepareGraphConstantId(Location loc, Type graphConstType,
1275 IntegerAttr intAttr) {
1277 if (uint32_t
id = getGraphConstantARMId(intAttr)) {
1282 uint32_t typeID = 0;
1283 if (
failed(processType(loc, graphConstType, typeID))) {
1287 uint32_t resultID = getNextID();
1288 APInt value = intAttr.getValue();
1289 unsigned bitwidth = value.getBitWidth();
1290 if (bitwidth > 32) {
1291 emitError(loc,
"Too wide attribute for OpGraphConstantARM: ")
1292 << bitwidth <<
" bits";
// NOTE(review): isSignedIntN(bitwidth) asks whether the value fits in
// `bitwidth` signed bits, which looks always true for a value of exactly that
// width — confirm whether the attribute type's signedness was intended (the
// integer-constant path uses intAttr.getType().isSignedInteger()).
1295 bool isSigned = value.isSignedIntN(bitwidth);
1299 word =
static_cast<int32_t
>(value.getSExtValue());
1301 word =
static_cast<uint32_t
>(value.getZExtValue());
1304 {typeID, resultID, word});
1305 graphConstIDMap[intAttr] = resultID;
/// Prepares a floating-point constant (OpSpecConstant when `isSpec`,
/// otherwise OpConstant). A previously cached <id> for `floatAttr` is
/// consulted first. The encoding depends on the value's float semantics:
/// IEEE single is one word, IEEE double is two words, and half / bfloat /
/// the two FP8 formats take the raw bits zero-extended into one word. Other
/// semantics are reported with the "-typed float literal" diagnostic. The new
/// <id> is cached in `constIDMap`.
1309uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
1313 if (
auto id = getConstantID(floatAttr)) {
1319 uint32_t typeID = 0;
1320 if (
failed(processType(loc, floatAttr.getType(), typeID))) {
1324 auto resultID = getNextID();
1325 APFloat value = floatAttr.getValue();
1326 const llvm::fltSemantics *semantics = &value.getSemantics();
1329 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1331 if (semantics == &APFloat::IEEEsingle()) {
1332 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1334 }
else if (semantics == &APFloat::IEEEdouble()) {
1338 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1340 {typeID, resultID, words.word1, words.word2});
1341 }
else if (llvm::is_contained({&APFloat::IEEEhalf(), &APFloat::BFloat(),
1342 &APFloat::Float8E4M3FN(),
1343 &APFloat::Float8E5M2()},
1346 static_cast<uint32_t
>(value.bitcastToAPInt().getZExtValue());
1349 std::string valueStr;
1350 llvm::raw_string_ostream rss(valueStr);
1354 << floatAttr.getType() <<
"-typed float literal: " << valueStr;
1359 constIDMap[floatAttr] = resultID;
1368 if (
auto typedAttr = dyn_cast<TypedAttr>(attr)) {
1369 return typedAttr.getType();
1372 if (
auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
/// Prepares a composite constant of `resultType` in which every constituent
/// replicates `valueAttr`. Results are cached per (value, type) pair in
/// `constCompositeReplicateIDMap`. The constituent is built with
/// prepareConstant when the composite's element type equals the value's type,
/// and by recursing on the element type otherwise.
1379uint32_t Serializer::prepareConstantCompositeReplicate(
Location loc,
1382 std::pair<Attribute, Type> valueTypePair{valueAttr, resultType};
1383 if (uint32_t
id = getConstantCompositeReplicateID(valueTypePair)) {
1387 uint32_t typeID = 0;
1388 if (
failed(processType(loc, resultType, typeID))) {
1396 auto compositeType = dyn_cast<CompositeType>(resultType);
1399 Type elementType = compositeType.getElementType(0);
// `constandID` is declared uninitialized; both visible branches assign it.
1401 uint32_t constandID;
1402 if (elementType == valueType) {
1403 constandID = prepareConstant(loc, elementType, valueAttr);
1405 constandID = prepareConstantCompositeReplicate(loc, elementType, valueAttr);
1408 uint32_t resultID = getNextID();
// Zero-valued ARM tensors take a separate encoding with only the type and
// result <id>s; everything else uses OpConstantCompositeReplicateEXT.
1409 if (dyn_cast<spirv::TensorArmType>(resultType) &&
isZeroValue(valueAttr)) {
1411 {typeID, resultID});
1414 spirv::Opcode::OpConstantCompositeReplicateEXT,
1415 {typeID, resultID, constandID});
1418 constCompositeReplicateIDMap[valueTypePair] = resultID;
/// Looks up the <id> for `block`; when getBlockID yields 0, a fresh <id> is
/// drawn from getNextID(), recorded in `blockIDMap`, and returned.
1426uint32_t Serializer::getOrCreateBlockID(
Block *block) {
1427 if (uint32_t
id = getBlockID(block))
1429 return blockIDMap[block] = getNextID();
/// Writes a short description of `block` to `os`: the block pointer followed
/// by an "(id = ..." part that depends on whether getBlockID returns a
/// non-zero <id>.
1433void Serializer::printBlock(
Block *block, raw_ostream &os) {
1434 os <<
"block " << block <<
" (id = ";
1435 if (uint32_t
id = getBlockID(block))
1444Serializer::processBlock(
Block *block,
bool omitLabel,
// Serializes one block: debug-prints it, obtains its <id>, emits phi
// instructions for its block arguments, then processes every operation
// except the last, and the last operation separately.
1446 LLVM_DEBUG(llvm::dbgs() <<
"processing block " << block <<
":\n");
1447 LLVM_DEBUG(block->
print(llvm::dbgs()));
1448 LLVM_DEBUG(llvm::dbgs() <<
'\n');
1450 uint32_t blockID = getOrCreateBlockID(block);
1451 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1458 if (
failed(emitPhiForBlockArguments(block)))
// A check for LoopOp/SelectionOp operations; on this path `emitMerge` is
// cleared and a fresh block <id> is drawn.
1468 llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1471 emitMerge =
nullptr;
1474 uint32_t blockID = getNextID();
1480 for (Operation &op : llvm::drop_end(*block)) {
1481 if (
failed(processOperation(&op)))
1489 if (
failed(processOperation(&block->
back())))
1495LogicalResult Serializer::emitPhiForBlockArguments(
Block *block) {
1501 LLVM_DEBUG(llvm::dbgs() <<
"emitting phi instructions..\n");
1508 SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
1510 auto *terminator = mlirPredecessor->getTerminator();
1511 LLVM_DEBUG(llvm::dbgs() <<
" mlir predecessor ");
1512 LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1513 LLVM_DEBUG(llvm::dbgs() <<
" terminator: " << *terminator <<
"\n");
1522 LLVM_DEBUG(llvm::dbgs() <<
" spirv predecessor ");
1523 LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1524 if (
auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1525 predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1526 }
else if (
auto branchCondOp =
1527 dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1528 std::optional<OperandRange> blockOperands;
1529 if (branchCondOp.getTrueTarget() == block) {
1530 blockOperands = branchCondOp.getTrueTargetOperands();
1532 assert(branchCondOp.getFalseTarget() == block);
1533 blockOperands = branchCondOp.getFalseTargetOperands();
1535 assert(!blockOperands->empty() &&
1536 "expected non-empty block operand range");
1537 predecessors.emplace_back(spirvPredecessor, *blockOperands);
1538 }
else if (
auto switchOp = dyn_cast<spirv::SwitchOp>(terminator)) {
1539 std::optional<OperandRange> blockOperands;
1540 if (block == switchOp.getDefaultTarget()) {
1541 blockOperands = switchOp.getDefaultOperands();
1543 SuccessorRange targets = switchOp.getTargets();
1544 auto it = llvm::find(targets, block);
1545 assert(it != targets.end());
1546 size_t index = std::distance(targets.begin(), it);
1547 blockOperands = switchOp.getTargetOperands(index);
1549 assert(!blockOperands->empty() &&
1550 "expected non-empty block operand range");
1551 predecessors.emplace_back(spirvPredecessor, *blockOperands);
1553 return terminator->emitError(
"unimplemented terminator for Phi creation");
1556 llvm::dbgs() <<
" block arguments:\n";
1557 for (Value v : predecessors.back().second)
1558 llvm::dbgs() <<
" " << v <<
"\n";
1563 for (
auto argIndex : llvm::seq<unsigned>(0, block->
getNumArguments())) {
1567 uint32_t phiTypeID = 0;
1570 uint32_t phiID = getNextID();
1572 LLVM_DEBUG(llvm::dbgs() <<
"[phi] for block argument #" << argIndex <<
' '
1573 << arg <<
" (id = " << phiID <<
")\n");
1576 SmallVector<uint32_t, 8> phiArgs;
1577 phiArgs.push_back(phiTypeID);
1578 phiArgs.push_back(phiID);
1580 for (
auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1581 Value value = predecessors[predIndex].second[argIndex];
1582 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1583 LLVM_DEBUG(llvm::dbgs() <<
"[phi] use predecessor (id = " << predBlockId
1584 <<
") value " << value <<
' ');
1586 uint32_t valueId = getValueID(value);
1590 LLVM_DEBUG(llvm::dbgs() <<
"(need to fix)\n");
1591 deferredPhiValues[value].push_back(functionBody.size() + 1 +
1594 LLVM_DEBUG(llvm::dbgs() <<
"(id = " << valueId <<
")\n");
1596 phiArgs.push_back(valueId);
1598 phiArgs.push_back(predBlockId);
1602 valueIDMap[arg] = phiID;
1612LogicalResult Serializer::encodeExtensionInstruction(
1613 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1614 ArrayRef<uint32_t> operands) {
1616 auto &setID = extendedInstSetIDMap[extensionSetName];
1618 setID = getNextID();
1619 SmallVector<uint32_t, 16> importOperands;
1620 importOperands.push_back(setID);
1628 if (operands.size() < 2) {
1629 return op->
emitError(
"extended instructions must have a result encoding");
1631 SmallVector<uint32_t, 8> extInstOperands;
1632 extInstOperands.reserve(operands.size() + 2);
1633 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1634 extInstOperands.push_back(setID);
1635 extInstOperands.push_back(extensionOpcode);
1636 extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1642LogicalResult Serializer::processOperation(Operation *opInst) {
1643 LLVM_DEBUG(llvm::dbgs() <<
"[op] '" << opInst->
getName() <<
"'\n");
1648 .Case([&](spirv::AddressOfOp op) {
return processAddressOfOp(op); })
1649 .Case([&](spirv::BranchOp op) {
return processBranchOp(op); })
1650 .Case([&](spirv::BranchConditionalOp op) {
1651 return processBranchConditionalOp(op);
1653 .Case([&](spirv::ConstantOp op) {
return processConstantOp(op); })
1654 .Case([&](spirv::CompositeConstructOp op) {
1655 return processCompositeConstructOp(op);
1657 .Case([&](spirv::EXTConstantCompositeReplicateOp op) {
1658 return processConstantCompositeReplicateOp(op);
1660 .Case([&](spirv::FuncOp op) {
return processFuncOp(op); })
1661 .Case([&](spirv::GraphARMOp op) {
return processGraphARMOp(op); })
1662 .Case([&](spirv::GraphEntryPointARMOp op) {
1663 return processGraphEntryPointARMOp(op);
1665 .Case([&](spirv::GraphOutputsARMOp op) {
1666 return processGraphOutputsARMOp(op);
1668 .Case([&](spirv::GlobalVariableOp op) {
1669 return processGlobalVariableOp(op);
1671 .Case([&](spirv::GraphConstantARMOp op) {
1672 return processGraphConstantARMOp(op);
1674 .Case([&](spirv::LoopOp op) {
return processLoopOp(op); })
1675 .Case([&](spirv::ReferenceOfOp op) {
return processReferenceOfOp(op); })
1676 .Case([&](spirv::SelectionOp op) {
return processSelectionOp(op); })
1677 .Case([&](spirv::SpecConstantOp op) {
return processSpecConstantOp(op); })
1678 .Case([&](spirv::SpecConstantCompositeOp op) {
1679 return processSpecConstantCompositeOp(op);
1681 .Case([&](spirv::EXTSpecConstantCompositeReplicateOp op) {
1682 return processSpecConstantCompositeReplicateOp(op);
1684 .Case([&](spirv::SpecConstantOperationOp op) {
1685 return processSpecConstantOperationOp(op);
1687 .Case([&](spirv::SwitchOp op) {
return processSwitchOp(op); })
1688 .Case([&](spirv::UndefOp op) {
return processUndefOp(op); })
1689 .Case([&](spirv::VariableOp op) {
return processVariableOp(op); })
1694 [&](Operation *op) {
return dispatchToAutogenSerialization(op); });
1698Serializer::processCompositeConstructOp(spirv::CompositeConstructOp op) {
1699 Location loc = op.getLoc();
1701 uint32_t resultTypeID = 0;
1702 if (
failed(processType(loc, op.getType(), resultTypeID)))
1705 uint32_t resultID = getNextID();
1706 valueIDMap[op.getResult()] = resultID;
1708 SmallVector<uint32_t, 8> operands;
1709 operands.reserve(2 + op.getConstituents().size());
1710 operands.push_back(resultTypeID);
1711 operands.push_back(resultID);
1712 for (Value constituent : op.getConstituents()) {
1713 uint32_t
id = getValueID(constituent);
1714 assert(
id &&
"use before def!");
1715 operands.push_back(
id);
1718 if (
failed(emitDebugLine(functionBody, loc)))
1721 encodeInstructionWithContinuationInto(
1722 functionBody, spirv::Opcode::OpCompositeConstruct, operands);
1724 for (
auto attr : op->getAttrs()) {
1725 if (
failed(processDecoration(loc, resultID, attr)))
1732LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1733 StringRef extInstSet,
1735 SmallVector<uint32_t, 4> operands;
1736 Location loc = op->
getLoc();
1738 uint32_t resultID = 0;
1740 uint32_t resultTypeID = 0;
1743 operands.push_back(resultTypeID);
1745 resultID = getNextID();
1746 operands.push_back(resultID);
1747 valueIDMap[op->
getResult(0)] = resultID;
1751 operands.push_back(getValueID(operand));
1753 if (
failed(emitDebugLine(functionBody, loc)))
1756 if (extInstSet.empty()) {
1760 if (
failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1766 if (
failed(processDecoration(loc, resultID, attr)))
1774LogicalResult Serializer::emitDecoration(uint32_t
target,
1775 spirv::Decoration decoration,
1776 ArrayRef<uint32_t> params) {
1777 uint32_t wordCount = 3 + params.size();
1778 llvm::append_values(
1781 static_cast<uint32_t
>(decoration));
1782 llvm::append_range(decorations, params);
1786LogicalResult Serializer::emitDecorationId(uint32_t
target,
1787 spirv::Decoration decoration,
1788 ArrayRef<uint32_t> operandIds) {
1789 uint32_t wordCount = 3 + operandIds.size();
1790 llvm::append_values(
1793 static_cast<uint32_t
>(decoration));
1794 llvm::append_range(decorations, operandIds);
1798LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1800 if (!options.emitDebugInfo)
1803 if (lastProcessedWasMergeInst) {
1804 lastProcessedWasMergeInst =
false;
1808 auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1811 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
static bool isZeroValue(Attribute attr)
static void moveFuncDeclarationsToTop(spirv::ModuleOp moduleOp)
Moves all function declarations before function definitions.
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Return the context this attribute belongs to.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< pred_iterator > getPredecessors()
OpListType & getOperations()
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
llvm::iplist< Operation > OpListType
This is the list of operations in the block.
bool getValue() const
Return the boolean value of this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e. if all element values are the same.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
unsigned getArrayStride() const
Returns the array stride in bytes.
static ArrayType get(Type elementType, unsigned elementCount)
unsigned getArrayStride() const
Returns the array stride in bytes.
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
LogicalResult serialize()
Serializes the remembered SPIR-V module.
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static Type getValueType(Attribute attr)
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes a SPIR-V literal string into the given binary vector.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::optional< spirv::Opcode > getContinuationOpcode(spirv::Opcode parent)
Returns the SPV_INTEL_long_composites continuation opcode that may follow parent, or std::nullopt if ...
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for a SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes a SPIR-V instruction with the given opcode and operands into the given binary vector.
constexpr uint32_t kMaxWordCount
Max number of words per instruction; see https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_universal_limits.
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPIR-V module header to header with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
static LogicalResult processDecorationList(Location loc, Decoration decoration, Attribute attrList, StringRef attrName, EmitF emitter)
static std::string getDecorationName(StringRef attrName)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
llvm::TypeSwitch< T, ResultT > TypeSwitch
llvm::function_ref< Fn > function_ref
Attribute decorationValue