23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/Sequence.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/bit.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/SaveAndRestore.h"
30#include "llvm/Support/raw_ostream.h"
35#define DEBUG_TYPE "spirv-deserialization"
44 isa_and_nonnull<spirv::FuncOp>(block->
getParentOp());
54 : binary(binary), context(context), unknownLoc(UnknownLoc::
get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion()),
options(
options)
63LogicalResult spirv::Deserializer::deserialize() {
67 <<
"//+++---------- start deserialization ----------+++//\n";
70 if (
failed(processHeader()))
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
74 ArrayRef<uint32_t> operands;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
89 for (
auto &deferred : deferredInstructions) {
97 LLVM_DEBUG(logger.startLine()
98 <<
"//+++-------- completed deserialization --------+++//\n");
102OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
103 return std::move(module);
110OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111 OpBuilder builder(context);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
// Validates the SPIR-V module header (a 5-word header starting with the magic
// number) and decodes header word 1 into `version`, emitting an error for
// unsupported major/minor versions.
117LogicalResult spirv::Deserializer::processHeader() {
120 "SPIR-V binary module must have a 5-word header");
123 return emitError(unknownLoc,
"incorrect magic number");
// Header word 1 packs the version number: bits 16-23 hold the major version
// and bits 8-15 the minor version. Shifting left discards the higher bits and
// shifting right by 24 discards the lower ones (assumes `binary` holds 32-bit
// unsigned words).
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
// Only major version 1 is handled; each supported minor version is mapped to
// the matching spirv::Version::V_1_<minor> enumerator via MIN_VERSION_CASE.
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130#define MIN_VERSION_CASE(v) \
132 version = spirv::Version::V_1_##v; \
142#undef MIN_VERSION_CASE
144 return emitError(unknownLoc,
"unsupported SPIR-V minor version: ")
148 return emitError(unknownLoc,
"unsupported SPIR-V major version: ")
// Handles OpCapability: expects exactly one operand (the capability enum
// value), rejects values that do not symbolize to a known spirv::Capability,
// and records the capability in `capabilities`.
158spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
159 if (operands.size() != 1)
160 return emitError(unknownLoc,
"OpCapability must have one parameter");
162 auto cap = spirv::symbolizeCapability(operands[0]);
164 return emitError(unknownLoc,
"unknown capability: ") << operands[0];
166 capabilities.insert(*cap);
170LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
174 "OpExtension must have a literal string for the extension name");
177 unsigned wordIndex = 0;
179 if (wordIndex != words.size())
181 "unexpected trailing words in OpExtension instruction");
182 auto ext = spirv::symbolizeExtension(extName);
184 return emitError(unknownLoc,
"unknown extension: ") << extName;
186 extensions.insert(*ext);
191spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
192 if (words.size() < 2) {
194 "OpExtInstImport must have a result <id> and a literal "
195 "string for the extended instruction set name");
198 unsigned wordIndex = 1;
200 if (wordIndex != words.size()) {
202 "unexpected trailing words in OpExtInstImport");
207void spirv::Deserializer::attachVCETriple() {
209 spirv::ModuleOp::getVCETripleAttrName(),
211 extensions.getArrayRef(), context));
// Handles OpMemoryModel: expects exactly two operands, the addressing model
// (first) and the memory model (second), and stores each as an attribute on
// the module op under the module's addressing/memory model attribute names.
// NOTE(review): in the visible lines the raw words are static_cast to the enums
// without a symbolize* check (unlike processCapability), so out-of-range values
// do not appear to be rejected here - confirm against the full source.
215spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
216 if (operands.size() != 2)
217 return emitError(unknownLoc,
"OpMemoryModel must have two operands");
220 module->getAddressingModelAttrName(),
221 opBuilder.getAttr<spirv::AddressingModelAttr>(
222 static_cast<spirv::AddressingModel
>(operands.front())));
224 (*module)->setAttr(module->getMemoryModelAttrName(),
225 opBuilder.getAttr<spirv::MemoryModelAttr>(
226 static_cast<spirv::MemoryModel
>(operands.back())));
231template <
typename AttrTy,
typename EnumAttrTy,
typename EnumTy>
235 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
236 if (words.size() != 4) {
237 return emitError(loc,
"OpDecorate with ")
238 << decorationName <<
" needs a cache control integer literal and a "
239 << cacheControlKind <<
" cache control literal";
241 unsigned cacheLevel = words[2];
242 auto cacheControlAttr =
static_cast<EnumTy
>(words[3]);
243 auto value = opBuilder.
getAttr<AttrTy>(cacheLevel, cacheControlAttr);
246 dyn_cast_or_null<ArrayAttr>(decorations[words[0]].
get(symbol)))
247 llvm::append_range(attrs, attrList);
248 attrs.push_back(value);
249 decorations[words[0]].set(symbol, opBuilder.
getArrayAttr(attrs));
253LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
257 if (words.size() < 2) {
259 unknownLoc,
"OpDecorate must have at least result <id> and Decoration");
261 auto decorationName =
262 stringifyDecoration(
static_cast<spirv::Decoration
>(words[1]));
263 if (decorationName.empty()) {
264 return emitError(unknownLoc,
"invalid Decoration code : ") << words[1];
266 auto symbol = getSymbolDecoration(decorationName);
267 switch (
static_cast<spirv::Decoration
>(words[1])) {
268 case spirv::Decoration::FPFastMathMode:
269 if (words.size() != 3) {
270 return emitError(unknownLoc,
"OpDecorate with ")
271 << decorationName <<
" needs a single integer literal";
273 decorations[words[0]].set(
274 symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
275 static_cast<FPFastMathMode
>(words[2])));
277 case spirv::Decoration::FPRoundingMode:
278 if (words.size() != 3) {
279 return emitError(unknownLoc,
"OpDecorate with ")
280 << decorationName <<
" needs a single integer literal";
282 decorations[words[0]].set(
283 symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
284 static_cast<FPRoundingMode
>(words[2])));
286 case spirv::Decoration::DescriptorSet:
287 case spirv::Decoration::Binding:
288 case spirv::Decoration::Location:
289 case spirv::Decoration::SpecId:
290 case spirv::Decoration::Index:
291 case spirv::Decoration::Offset:
292 case spirv::Decoration::XfbBuffer:
293 case spirv::Decoration::XfbStride:
294 if (words.size() != 3) {
295 return emitError(unknownLoc,
"OpDecorate with ")
296 << decorationName <<
" needs a single integer literal";
298 decorations[words[0]].set(
299 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
301 case spirv::Decoration::BuiltIn:
302 if (words.size() != 3) {
303 return emitError(unknownLoc,
"OpDecorate with ")
304 << decorationName <<
" needs a single integer literal";
306 decorations[words[0]].set(
307 symbol, opBuilder.getStringAttr(
308 stringifyBuiltIn(
static_cast<spirv::BuiltIn
>(words[2]))));
310 case spirv::Decoration::ArrayStride:
311 if (words.size() != 3) {
312 return emitError(unknownLoc,
"OpDecorate with ")
313 << decorationName <<
" needs a single integer literal";
315 typeDecorations[words[0]] = words[2];
317 case spirv::Decoration::LinkageAttributes: {
318 if (words.size() < 4) {
319 return emitError(unknownLoc,
"OpDecorate with ")
321 <<
" needs at least 1 string and 1 integer literal";
329 unsigned wordIndex = 2;
331 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
332 static_cast<::mlir::spirv::LinkageType
>(words[wordIndex++]));
333 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
334 StringAttr::get(context, linkageName), linkageTypeAttr);
335 decorations[words[0]].set(symbol, dyn_cast<Attribute>(linkageAttr));
338 case spirv::Decoration::Aliased:
339 case spirv::Decoration::AliasedPointer:
340 case spirv::Decoration::Block:
341 case spirv::Decoration::BufferBlock:
342 case spirv::Decoration::Flat:
343 case spirv::Decoration::NonReadable:
344 case spirv::Decoration::NonWritable:
345 case spirv::Decoration::NoPerspective:
346 case spirv::Decoration::NoSignedWrap:
347 case spirv::Decoration::NoUnsignedWrap:
348 case spirv::Decoration::RelaxedPrecision:
349 case spirv::Decoration::Restrict:
350 case spirv::Decoration::RestrictPointer:
351 case spirv::Decoration::NoContraction:
352 case spirv::Decoration::Constant:
353 case spirv::Decoration::Invariant:
354 case spirv::Decoration::Patch:
355 case spirv::Decoration::Coherent:
356 if (words.size() != 2) {
357 return emitError(unknownLoc,
"OpDecorate with ")
358 << decorationName <<
" needs a single target <id>";
360 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
362 case spirv::Decoration::CacheControlLoadINTEL: {
364 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
365 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
371 case spirv::Decoration::CacheControlStoreINTEL: {
373 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
374 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
381 return emitError(unknownLoc,
"unhandled Decoration : '") << decorationName;
// Handles OpMemberDecorate. Operand layout: words[0] is the struct type <id>,
// words[1] the member index, words[2] the decoration, followed by optional
// decoration operands. The operands are recorded in memberDecorationMap keyed
// by (struct <id>, member index, decoration) so the struct type can consume
// them when it is built.
387spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
389 if (words.size() < 3) {
391 "OpMemberDecorate must have at least 3 operands");
// NOTE(review): unlike processDecoration, the decoration word is not checked
// for validity before the cast.
394 auto decoration =
static_cast<spirv::Decoration
>(words[2]);
// Offset is the only decoration whose operand count is validated here: it must
// carry exactly one operand (the member offset).
395 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
397 " missing offset specification in OpMemberDecorate with "
398 "Offset decoration");
// `words.slice(3)` is a non-owning view into the instruction words.
// NOTE(review): if memberDecorationMap stores ArrayRef values, this relies on
// the underlying binary outliving the map - confirm.
400 ArrayRef<uint32_t> decorationOperands;
401 if (words.size() > 3) {
402 decorationOperands = words.slice(3);
404 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
408LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
409 if (words.size() < 3) {
410 return emitError(unknownLoc,
"OpMemberName must have at least 3 operands");
412 unsigned wordIndex = 2;
414 if (wordIndex != words.size()) {
416 "unexpected trailing words in OpMemberName instruction");
418 memberNameMap[words[0]][words[1]] = name;
424 if (!decorations.contains(argID)) {
425 argAttrs[argIndex] = DictionaryAttr::get(context, {});
429 spirv::DecorationAttr foundDecorationAttr;
431 for (
auto decoration :
432 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
433 spirv::Decoration::AliasedPointer,
434 spirv::Decoration::RestrictPointer}) {
436 if (decAttr.getName() !=
440 if (foundDecorationAttr)
442 "more than one Aliased/Restrict decorations for "
443 "function argument with result <id> ")
446 foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
451 spirv::Decoration::RelaxedPrecision))) {
456 if (foundDecorationAttr)
457 return emitError(unknownLoc,
"already found a decoration for function "
458 "argument with result <id> ")
461 foundDecorationAttr = spirv::DecorationAttr::get(
462 context, spirv::Decoration::RelaxedPrecision);
466 if (!foundDecorationAttr)
467 return emitError(unknownLoc,
"unimplemented decoration support for "
468 "function argument with result <id> ")
471 NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
472 foundDecorationAttr);
473 argAttrs[argIndex] = DictionaryAttr::get(context, attr);
480 return emitError(unknownLoc,
"found function inside function");
484 if (operands.size() != 4) {
485 return emitError(unknownLoc,
"OpFunction must have 4 parameters");
489 return emitError(unknownLoc,
"undefined result type from <id> ")
493 uint32_t fnID = operands[1];
494 if (funcMap.count(fnID)) {
495 return emitError(unknownLoc,
"duplicate function definition/declaration");
498 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
500 return emitError(unknownLoc,
"unknown Function Control: ") << operands[2];
504 if (!fnType || !isa<FunctionType>(fnType)) {
505 return emitError(unknownLoc,
"unknown function type from <id> ")
508 auto functionType = cast<FunctionType>(fnType);
510 if ((
isVoidType(resultType) && functionType.getNumResults() != 0) ||
511 (functionType.getNumResults() == 1 &&
512 functionType.getResult(0) != resultType)) {
513 return emitError(unknownLoc,
"mismatch in function type ")
514 << functionType <<
" and return type " << resultType <<
" specified";
518 auto funcOp = spirv::FuncOp::create(opBuilder, unknownLoc, fnName,
519 functionType, fnControl.value());
521 if (decorations.count(fnID)) {
522 for (
auto attr : decorations[fnID].getAttrs()) {
523 funcOp->setAttr(attr.getName(), attr.getValue());
526 curFunction = funcMap[fnID] = funcOp;
527 auto *entryBlock = funcOp.addEntryBlock();
530 <<
"//===-------------------------------------------===//\n";
531 logger.startLine() <<
"[fn] name: " << fnName <<
"\n";
532 logger.startLine() <<
"[fn] type: " << fnType <<
"\n";
533 logger.startLine() <<
"[fn] ID: " << fnID <<
"\n";
534 logger.startLine() <<
"[fn] entry block: " << entryBlock <<
"\n";
539 argAttrs.resize(functionType.getNumInputs());
542 if (functionType.getNumInputs()) {
543 for (
size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
544 auto argType = functionType.getInput(i);
545 spirv::Opcode opcode = spirv::Opcode::OpNop;
548 spirv::Opcode::OpFunctionParameter))) {
551 if (opcode != spirv::Opcode::OpFunctionParameter) {
554 "missing OpFunctionParameter instruction for argument ")
557 if (operands.size() != 2) {
560 "expected result type and result <id> for OpFunctionParameter");
562 auto argDefinedType =
getType(operands[0]);
563 if (!argDefinedType || argDefinedType != argType) {
565 "mismatch in argument type between function type "
567 << functionType <<
" and argument type definition "
568 << argDefinedType <<
" at argument " << i;
571 return emitError(unknownLoc,
"duplicate definition of result <id> ")
578 auto argValue = funcOp.getArgument(i);
579 valueMap[operands[1]] = argValue;
583 if (llvm::any_of(argAttrs, [](
Attribute attr) {
584 auto argAttr = cast<DictionaryAttr>(attr);
585 return !argAttr.empty();
587 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
592 auto linkageAttr = funcOp.getLinkageAttributes();
593 auto hasImportLinkage =
594 linkageAttr && (linkageAttr.value().getLinkageType().
getValue() ==
595 spirv::LinkageType::Import);
596 if (hasImportLinkage)
603 spirv::Opcode opcode = spirv::Opcode::OpNop;
612 spirv::Opcode::OpFunctionEnd))) {
615 if (opcode == spirv::Opcode::OpFunctionEnd) {
618 if (opcode != spirv::Opcode::OpLabel) {
619 return emitError(unknownLoc,
"a basic block must start with OpLabel");
621 if (instOperands.size() != 1) {
622 return emitError(unknownLoc,
"OpLabel should only have result <id>");
624 blockMap[instOperands[0]] = entryBlock;
632 spirv::Opcode::OpFunctionEnd)) &&
633 opcode != spirv::Opcode::OpFunctionEnd) {
638 if (opcode != spirv::Opcode::OpFunctionEnd) {
648 if (!operands.empty()) {
649 return emitError(unknownLoc,
"unexpected operands for OpFunctionEnd");
660 curFunction = std::nullopt;
665 <<
"//===-------------------------------------------===//\n";
672 if (operands.size() < 2) {
674 "missing graph defintion in OpGraphEntryPointARM");
677 unsigned wordIndex = 0;
678 uint32_t graphID = operands[wordIndex++];
679 if (!graphMap.contains(graphID)) {
681 "missing graph definition/declaration with id ")
685 spirv::GraphARMOp graphARM = graphMap[graphID];
687 graphARM.setSymName(name);
688 graphARM.setEntryPoint(
true);
691 for (
int64_t size = operands.size(); wordIndex < size; ++wordIndex) {
693 interface.push_back(SymbolRefAttr::get(arg.getOperation()));
695 return emitError(unknownLoc,
"undefined result <id> ")
696 << operands[wordIndex] <<
" while decoding OpGraphEntryPoint";
702 opBuilder.setInsertionPoint(graphARM);
703 spirv::GraphEntryPointARMOp::create(
704 opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
705 opBuilder.getArrayAttr(interface));
713 return emitError(unknownLoc,
"found graph inside graph");
716 if (operands.size() < 2) {
717 return emitError(unknownLoc,
"OpGraphARM must have at least 2 parameters");
721 if (!type || !isa<GraphType>(type)) {
722 return emitError(unknownLoc,
"unknown graph type from <id> ")
725 auto graphType = cast<GraphType>(type);
726 if (graphType.getNumResults() <= 0) {
727 return emitError(unknownLoc,
"expected at least one result");
730 uint32_t graphID = operands[1];
731 if (graphMap.count(graphID)) {
732 return emitError(unknownLoc,
"duplicate graph definition/declaration");
737 spirv::GraphARMOp::create(opBuilder, unknownLoc, graphName, graphType);
738 curGraph = graphMap[graphID] = graphOp;
739 Block *entryBlock = graphOp.addEntryBlock();
742 <<
"//===-------------------------------------------===//\n";
743 logger.startLine() <<
"[graph] name: " << graphName <<
"\n";
744 logger.startLine() <<
"[graph] type: " << graphType <<
"\n";
745 logger.startLine() <<
"[graph] ID: " << graphID <<
"\n";
746 logger.startLine() <<
"[graph] entry block: " << entryBlock <<
"\n";
751 for (
auto [
index, argType] : llvm::enumerate(graphType.getInputs())) {
752 spirv::Opcode opcode;
755 spirv::Opcode::OpGraphInputARM))) {
758 if (operands.size() != 3) {
759 return emitError(unknownLoc,
"expected result type, result <id> and "
760 "input index for OpGraphInputARM");
764 if (!argDefinedType) {
765 return emitError(unknownLoc,
"unknown operand type <id> ") << operands[0];
768 if (argDefinedType != argType) {
770 "mismatch in argument type between graph type "
772 << graphType <<
" and argument type definition " << argDefinedType
773 <<
" at argument " <<
index;
776 return emitError(unknownLoc,
"duplicate definition of result <id> ")
781 if (!inputIndexAttr) {
783 "unable to read inputIndex value from constant op ")
786 BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
787 valueMap[operands[1]] = argValue;
790 graphOutputs.resize(graphType.getNumResults());
796 blockMap[graphID] = entryBlock;
803 spirv::Opcode opcode;
813 }
while (opcode != spirv::Opcode::OpGraphEndARM);
820 if (operands.size() != 2) {
823 "expected value id and output index for OpGraphSetOutputARM");
826 uint32_t
id = operands[0];
829 return emitError(unknownLoc,
"could not find result <id> ") << id;
833 if (!outputIndexAttr) {
835 "unable to read outputIndex value from constant op ")
838 graphOutputs[outputIndexAttr.getInt()] = value;
845 spirv::GraphOutputsARMOp::create(opBuilder, unknownLoc, graphOutputs);
848 if (!operands.empty()) {
849 return emitError(unknownLoc,
"unexpected operands for OpGraphEndARM");
853 curGraph = std::nullopt;
854 graphOutputs.clear();
859 <<
"//===-------------------------------------------===//\n";
864std::optional<std::pair<Attribute, Type>>
866 auto constIt = constantMap.find(
id);
867 if (constIt == constantMap.end())
869 return constIt->getSecond();
872std::optional<std::pair<Attribute, Type>>
874 if (
auto it = constantCompositeReplicateMap.find(
id);
875 it != constantCompositeReplicateMap.end())
880std::optional<spirv::SpecConstOperationMaterializationInfo>
882 auto constIt = specConstOperationMap.find(
id);
883 if (constIt == specConstOperationMap.end())
885 return constIt->getSecond();
889 auto funcName = nameMap.lookup(
id).str();
890 if (funcName.empty()) {
891 funcName =
"spirv_fn_" + std::to_string(
id);
897 std::string graphName = nameMap.lookup(
id).str();
898 if (graphName.empty()) {
899 graphName =
"spirv_graph_" + std::to_string(
id);
905 auto constName = nameMap.lookup(
id).str();
906 if (constName.empty()) {
907 constName =
"spirv_spec_const_" + std::to_string(
id);
914 TypedAttr defaultValue) {
916 auto op = spirv::SpecConstantOp::create(opBuilder, unknownLoc, symName,
918 if (decorations.count(resultID)) {
919 for (
auto attr : decorations[resultID].getAttrs())
920 op->setAttr(attr.getName(), attr.getValue());
922 specConstMap[resultID] = op;
926std::optional<spirv::GraphConstantARMOpMaterializationInfo>
928 auto graphConstIt = graphConstantMap.find(
id);
929 if (graphConstIt == graphConstantMap.end())
931 return graphConstIt->getSecond();
936 unsigned wordIndex = 0;
937 if (operands.size() < 3) {
940 "OpVariable needs at least 3 operands, type, <id> and storage class");
944 auto type =
getType(operands[wordIndex]);
946 return emitError(unknownLoc,
"unknown result type <id> : ")
947 << operands[wordIndex];
949 auto ptrType = dyn_cast<spirv::PointerType>(type);
952 "expected a result type <id> to be a spirv.ptr, found : ")
958 auto variableID = operands[wordIndex];
959 auto variableName = nameMap.lookup(variableID).str();
960 if (variableName.empty()) {
961 variableName =
"spirv_var_" + std::to_string(variableID);
966 auto storageClass =
static_cast<spirv::StorageClass
>(operands[wordIndex]);
967 if (ptrType.getStorageClass() != storageClass) {
968 return emitError(unknownLoc,
"mismatch in storage class of pointer type ")
969 << type <<
" and that specified in OpVariable instruction : "
970 << stringifyStorageClass(storageClass);
977 if (wordIndex < operands.size()) {
987 return emitError(unknownLoc,
"unknown <id> ")
988 << operands[wordIndex] <<
"used as initializer";
990 initializer = SymbolRefAttr::get(op);
993 if (wordIndex != operands.size()) {
995 "found more operands than expected when deserializing "
996 "OpVariable instruction, only ")
997 << wordIndex <<
" of " << operands.size() <<
" processed";
1000 auto varOp = spirv::GlobalVariableOp::create(
1001 opBuilder, loc, TypeAttr::get(type),
1002 opBuilder.getStringAttr(variableName), initializer);
1005 if (decorations.count(variableID)) {
1006 for (
auto attr : decorations[variableID].getAttrs())
1007 varOp->setAttr(attr.getName(), attr.getValue());
1009 globalVariableMap[variableID] = varOp;
1018 return dyn_cast<IntegerAttr>(constInfo->first);
1022 if (operands.size() < 2) {
1023 return emitError(unknownLoc,
"OpName needs at least 2 operands");
1025 if (!nameMap.lookup(operands[0]).empty()) {
1026 return emitError(unknownLoc,
"duplicate name found for result <id> ")
1029 unsigned wordIndex = 1;
1031 if (wordIndex != operands.size()) {
1033 "unexpected trailing words in OpName instruction");
1035 nameMap[operands[0]] = name;
1045 if (operands.empty()) {
1046 return emitError(unknownLoc,
"type instruction with opcode ")
1047 << spirv::stringifyOpcode(opcode) <<
" needs at least one <id>";
1052 if (typeMap.count(operands[0])) {
1053 return emitError(unknownLoc,
"duplicate definition for result <id> ")
1058 case spirv::Opcode::OpTypeVoid:
1059 if (operands.size() != 1)
1060 return emitError(unknownLoc,
"OpTypeVoid must have no parameters");
1061 typeMap[operands[0]] = opBuilder.getNoneType();
1063 case spirv::Opcode::OpTypeBool:
1064 if (operands.size() != 1)
1065 return emitError(unknownLoc,
"OpTypeBool must have no parameters");
1066 typeMap[operands[0]] = opBuilder.getI1Type();
1068 case spirv::Opcode::OpTypeInt: {
1069 if (operands.size() != 3)
1071 unknownLoc,
"OpTypeInt must have bitwidth and signedness parameters");
1080 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
1081 : IntegerType::SignednessSemantics::Signless;
1082 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
1084 case spirv::Opcode::OpTypeFloat: {
1085 if (operands.size() != 2 && operands.size() != 3)
1087 "OpTypeFloat expects either 2 operands (type, bitwidth) "
1088 "or 3 operands (type, bitwidth, encoding), but got ")
1090 uint32_t bitWidth = operands[1];
1093 if (operands.size() == 2) {
1096 floatTy = opBuilder.getF16Type();
1099 floatTy = opBuilder.getF32Type();
1102 floatTy = opBuilder.getF64Type();
1105 return emitError(unknownLoc,
"unsupported OpTypeFloat bitwidth: ")
1110 if (operands.size() == 3) {
1111 if (spirv::FPEncoding(operands[2]) == spirv::FPEncoding::BFloat16KHR &&
1113 floatTy = opBuilder.getBF16Type();
1114 else if (spirv::FPEncoding(operands[2]) ==
1115 spirv::FPEncoding::Float8E4M3EXT &&
1117 floatTy = opBuilder.getF8E4M3FNType();
1118 else if (spirv::FPEncoding(operands[2]) ==
1119 spirv::FPEncoding::Float8E5M2EXT &&
1121 floatTy = opBuilder.getF8E5M2Type();
1123 return emitError(unknownLoc,
"unsupported OpTypeFloat FP encoding: ")
1124 << operands[2] <<
" and bitWidth " << bitWidth;
1127 typeMap[operands[0]] = floatTy;
1129 case spirv::Opcode::OpTypeVector: {
1130 if (operands.size() != 3) {
1133 "OpTypeVector must have element type and count parameters");
1137 return emitError(unknownLoc,
"OpTypeVector references undefined <id> ")
1140 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
1142 case spirv::Opcode::OpTypePointer: {
1145 case spirv::Opcode::OpTypeArray:
1147 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
1149 case spirv::Opcode::OpTypeFunction:
1151 case spirv::Opcode::OpTypeImage:
1153 case spirv::Opcode::OpTypeSampledImage:
1155 case spirv::Opcode::OpTypeRuntimeArray:
1157 case spirv::Opcode::OpTypeStruct:
1159 case spirv::Opcode::OpTypeMatrix:
1161 case spirv::Opcode::OpTypeTensorARM:
1163 case spirv::Opcode::OpTypeGraphARM:
1166 return emitError(unknownLoc,
"unhandled type instruction");
// OpTypePointer operands: operands[0] is the result <id>, operands[1] the
// storage class, operands[2] the pointee type <id>.
1173 if (operands.size() != 3)
1174 return emitError(unknownLoc,
"OpTypePointer must have two parameters");
1176 auto pointeeType =
getType(operands[2]);
1178 return emitError(unknownLoc,
"unknown OpTypePointer pointee type <id> ")
1181 uint32_t typePointerID = operands[0];
1182 auto storageClass =
static_cast<spirv::StorageClass
>(operands[1]);
// Struct types whose members referenced a forward-declared pointer <id> are
// parked in deferredStructTypesInfos. Each has unresolvedMemberTypes entries
// pairing that <id> with the member index. Now that this pointer <id> is
// defined, patch every matching member with typeMap[typePointerID]. Once a
// struct has no unresolved members left, set its body and drop it from the
// deferred list.
1185 for (
auto *deferredStructIt = std::begin(deferredStructTypesInfos);
1186 deferredStructIt != std::end(deferredStructTypesInfos);) {
1187 for (
auto *unresolvedMemberIt =
1188 std::begin(deferredStructIt->unresolvedMemberTypes);
1189 unresolvedMemberIt !=
1190 std::end(deferredStructIt->unresolvedMemberTypes);) {
1191 if (unresolvedMemberIt->first == typePointerID) {
1195 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
1196 typeMap[typePointerID];
// erase() returns the iterator following the removed element, so an explicit
// increment is only needed when nothing is erased.
1197 unresolvedMemberIt =
1198 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
1200 ++unresolvedMemberIt;
1204 if (deferredStructIt->unresolvedMemberTypes.empty()) {
1206 auto structType = deferredStructIt->deferredStructType;
1208 assert(structType &&
"expected a spirv::StructType");
// NOTE(review): "indentified" in the assert message below is a typo for
// "identified"; the runtime string is left unchanged here.
1209 assert(structType.isIdentified() &&
"expected an indentified struct");
1211 if (failed(structType.trySetBody(
1212 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
1213 deferredStructIt->memberDecorationsInfo,
1214 deferredStructIt->structDecorationsInfo)))
1217 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
// Handles OpTypeArray: operands[0] is the result <id> and operands[2] the <id>
// of the constant giving the element count. The count must be an integer
// constant and is read zero-extended. Any stride recorded for the result <id>
// by an ArrayStride decoration is looked up in typeDecorations.
1228 if (operands.size() != 3) {
1230 "OpTypeArray must have element type and count parameters");
1235 return emitError(unknownLoc,
"OpTypeArray references undefined <id> ")
1243 return emitError(unknownLoc,
"OpTypeArray count <id> ")
// NOTE(review): the next message lacks a leading space, so the <id> and the
// text run together (e.g. "<id> 5can only ..."); string left unchanged here.
1244 << operands[2] <<
"can only come from normal constant right now";
1247 if (
auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1248 count = intVal.getValue().getZExtValue();
1250 return emitError(unknownLoc,
"OpTypeArray count must come from a "
1251 "scalar integer constant instruction");
1255 elementTy, count, typeDecorations.lookup(operands[0]));
1261 assert(!operands.empty() &&
"No operands for processing function type");
1262 if (operands.size() == 1) {
1263 return emitError(unknownLoc,
"missing return type for OpTypeFunction");
1265 auto returnType =
getType(operands[1]);
1267 return emitError(unknownLoc,
"unknown return type in OpTypeFunction");
1270 for (
size_t i = 2, e = operands.size(); i < e; ++i) {
1271 auto ty =
getType(operands[i]);
1273 return emitError(unknownLoc,
"unknown argument type in OpTypeFunction");
1275 argTypes.push_back(ty);
1281 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1287 if (operands.size() != 6) {
1289 "OpTypeCooperativeMatrixKHR must have element type, "
1290 "scope, row and column parameters, and use");
1296 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1300 std::optional<spirv::Scope> scope =
1305 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1314 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Rows` references "
1315 "undefined constant <id> ")
1319 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Columns` "
1320 "references undefined constant <id> ")
1324 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Use` references "
1325 "undefined constant <id> ")
1328 unsigned rows = rowsAttr.getInt();
1329 unsigned columns = columnsAttr.getInt();
1331 std::optional<spirv::CooperativeMatrixUseKHR> use =
1332 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1336 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1340 typeMap[operands[0]] =
1347 if (operands.size() != 2) {
1348 return emitError(unknownLoc,
"OpTypeRuntimeArray must have two operands");
1353 "OpTypeRuntimeArray references undefined <id> ")
1357 memberType, typeDecorations.lookup(operands[0]));
1365 if (operands.empty()) {
1366 return emitError(unknownLoc,
"OpTypeStruct must have at least result <id>");
1369 if (operands.size() == 1) {
1371 typeMap[operands[0]] =
1380 for (
auto op : llvm::drop_begin(operands, 1)) {
1382 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1384 if (!memberType && !typeForwardPtr)
1385 return emitError(unknownLoc,
"OpTypeStruct references undefined <id> ")
1389 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1391 memberTypes.push_back(memberType);
1396 if (memberDecorationMap.count(operands[0])) {
1397 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1398 for (
auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1399 if (allMemberDecorations.count(memberIndex)) {
1400 for (
auto &memberDecoration : allMemberDecorations[memberIndex]) {
1402 if (memberDecoration.first == spirv::Decoration::Offset) {
1404 if (offsetInfo.empty()) {
1405 offsetInfo.resize(memberTypes.size());
1407 offsetInfo[memberIndex] = memberDecoration.second[0];
1409 auto intType = mlir::IntegerType::get(context, 32);
1410 if (!memberDecoration.second.empty()) {
1411 memberDecorationsInfo.emplace_back(
1412 memberIndex, memberDecoration.first,
1413 IntegerAttr::get(intType, memberDecoration.second[0]));
1415 memberDecorationsInfo.emplace_back(
1416 memberIndex, memberDecoration.first, UnitAttr::get(context));
1425 if (decorations.count(operands[0])) {
1428 std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
1429 llvm::convertToCamelFromSnakeCase(decorationAttr.getName(),
true));
1430 assert(decoration.has_value());
1431 structDecorationsInfo.emplace_back(decoration.value(),
1432 decorationAttr.getValue());
1436 uint32_t structID = operands[0];
1437 std::string structIdentifier = nameMap.lookup(structID).str();
1439 if (structIdentifier.empty()) {
1440 assert(unresolvedMemberTypes.empty() &&
1441 "didn't expect unresolved member types");
1443 memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
1446 typeMap[structID] = structTy;
1448 if (!unresolvedMemberTypes.empty())
1449 deferredStructTypesInfos.push_back(
1450 {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
1451 memberDecorationsInfo, structDecorationsInfo});
1452 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1453 memberDecorationsInfo,
1454 structDecorationsInfo)))
1465 if (operands.size() != 3) {
1467 return emitError(unknownLoc,
"OpTypeMatrix must have 3 operands"
1468 " (result_id, column_type, and column_count)");
1474 "OpTypeMatrix references undefined column type.")
1478 uint32_t colsCount = operands[2];
1485 unsigned size = operands.size();
1486 if (size < 2 || size > 4)
1487 return emitError(unknownLoc,
"OpTypeTensorARM must have 2-4 operands "
1488 "(result_id, element_type, (rank), (shape)) ")
1494 "OpTypeTensorARM references undefined element type ")
1504 return emitError(unknownLoc,
"OpTypeTensorARM rank must come from a "
1505 "scalar integer constant instruction");
1506 unsigned rank = rankAttr.getValue().getZExtValue();
1513 std::optional<std::pair<Attribute, Type>> shapeInfo =
1516 return emitError(unknownLoc,
"OpTypeTensorARM shape must come from a "
1517 "constant instruction of type OpTypeArray");
1519 ArrayAttr shapeArrayAttr = dyn_cast<ArrayAttr>(shapeInfo->first);
1521 for (
auto dimAttr : shapeArrayAttr.getValue()) {
1522 auto dimIntAttr = dyn_cast<IntegerAttr>(dimAttr);
1524 return emitError(unknownLoc,
"OpTypeTensorARM shape has an invalid "
1526 shape.push_back(dimIntAttr.getValue().getSExtValue());
1534 unsigned size = operands.size();
1536 return emitError(unknownLoc,
"OpTypeGraphARM must have at least 2 operands "
1537 "(result_id, num_inputs, (inout0_type, "
1538 "inout1_type, ...))")
1541 uint32_t numInputs = operands[1];
1544 for (
unsigned i = 2; i < size; ++i) {
1548 "OpTypeGraphARM references undefined element type.")
1551 if (i - 2 >= numInputs) {
1552 returnTypes.push_back(inOutTy);
1554 argTypes.push_back(inOutTy);
1557 typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
1563 if (operands.size() != 2)
1565 "OpTypeForwardPointer instruction must have two operands");
1567 typeForwardPointerIDs.insert(operands[0]);
1577 if (operands.size() != 8)
1580 "OpTypeImage with non-eight operands are not supported yet");
1584 return emitError(unknownLoc,
"OpTypeImage references undefined <id>: ")
1587 auto dim = spirv::symbolizeDim(operands[2]);
1589 return emitError(unknownLoc,
"unknown Dim for OpTypeImage: ")
1592 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1594 return emitError(unknownLoc,
"unknown Depth for OpTypeImage: ")
1597 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1599 return emitError(unknownLoc,
"unknown Arrayed for OpTypeImage: ")
1602 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1604 return emitError(unknownLoc,
"unknown MS for OpTypeImage: ") << operands[5];
1606 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1607 if (!samplerUseInfo)
1608 return emitError(unknownLoc,
"unknown Sampled for OpTypeImage: ")
1611 auto format = spirv::symbolizeImageFormat(operands[7]);
1613 return emitError(unknownLoc,
"unknown Format for OpTypeImage: ")
1617 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1618 samplingInfo.value(), samplerUseInfo.value(), format.value());
1624 if (operands.size() != 2)
1625 return emitError(unknownLoc,
"OpTypeSampledImage must have two operands");
1630 "OpTypeSampledImage references undefined <id>: ")
1643 StringRef opname = isSpec ?
"OpSpecConstant" :
"OpConstant";
1645 if (operands.size() < 2) {
1647 << opname <<
" must have type <id> and result <id>";
1649 if (operands.size() < 3) {
1651 << opname <<
" must have at least 1 more parameter";
1656 return emitError(unknownLoc,
"undefined result type from <id> ")
1660 auto checkOperandSizeForBitwidth = [&](
unsigned bitwidth) -> LogicalResult {
1661 if (bitwidth == 64) {
1662 if (operands.size() == 4) {
1666 << opname <<
" should have 2 parameters for 64-bit values";
1668 if (bitwidth <= 32) {
1669 if (operands.size() == 3) {
1675 <<
" should have 1 parameter for values with no more than 32 bits";
1677 return emitError(unknownLoc,
"unsupported OpConstant bitwidth: ")
1681 auto resultID = operands[1];
1683 if (
auto intType = dyn_cast<IntegerType>(resultType)) {
1684 auto bitwidth = intType.getWidth();
1685 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1690 if (bitwidth == 64) {
1697 } words = {operands[2], operands[3]};
1698 value = APInt(64, llvm::bit_cast<uint64_t>(words),
true);
1699 }
else if (bitwidth <= 32) {
1700 value = APInt(bitwidth, operands[2],
true,
1704 auto attr = opBuilder.getIntegerAttr(intType, value);
1711 constantMap.try_emplace(resultID, attr, intType);
1717 if (
auto floatType = dyn_cast<FloatType>(resultType)) {
1718 auto bitwidth = floatType.getWidth();
1719 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1724 if (floatType.isF64()) {
1731 } words = {operands[2], operands[3]};
1732 value = APFloat(llvm::bit_cast<double>(words));
1733 }
else if (floatType.isF32()) {
1734 value = APFloat(llvm::bit_cast<float>(operands[2]));
1735 }
else if (floatType.isF16()) {
1736 APInt data(16, operands[2]);
1737 value = APFloat(APFloat::IEEEhalf(), data);
1738 }
else if (floatType.isBF16()) {
1739 APInt data(16, operands[2]);
1740 value = APFloat(APFloat::BFloat(), data);
1741 }
else if (floatType.isF8E4M3FN()) {
1742 APInt data(8, operands[2]);
1743 value = APFloat(APFloat::Float8E4M3FN(), data);
1744 }
else if (floatType.isF8E5M2()) {
1745 APInt data(8, operands[2]);
1746 value = APFloat(APFloat::Float8E5M2(), data);
1749 auto attr = opBuilder.getFloatAttr(floatType, value);
1755 constantMap.try_emplace(resultID, attr, floatType);
1761 return emitError(unknownLoc,
"OpConstant can only generate values of "
1762 "scalar integer or floating-point type");
1767 if (operands.size() != 2) {
1769 << (isSpec ?
"Spec" :
"") <<
"Constant"
1770 << (isTrue ?
"True" :
"False")
1771 <<
" must have type <id> and result <id>";
1774 auto attr = opBuilder.getBoolAttr(isTrue);
1775 auto resultID = operands[1];
1781 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1789 if (operands.size() < 2) {
1791 "OpConstantComposite must have type <id> and result <id>");
1793 if (operands.size() < 3) {
1795 "OpConstantComposite must have at least 1 parameter");
1800 return emitError(unknownLoc,
"undefined result type from <id> ")
1805 elements.reserve(operands.size() - 2);
1806 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1809 return emitError(unknownLoc,
"OpConstantComposite component <id> ")
1810 << operands[i] <<
" must come from a normal constant";
1812 elements.push_back(elementInfo->first);
1815 auto resultID = operands[1];
1816 if (
auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1819 if (
auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
1820 for (
auto value : denseElemAttr.getValues<
Attribute>())
1821 flattenedElems.push_back(value);
1823 flattenedElems.push_back(element);
1827 constantMap.try_emplace(resultID, attr, tensorType);
1828 }
else if (
auto shapedType = dyn_cast<ShapedType>(resultType)) {
1832 constantMap.try_emplace(resultID, attr, shapedType);
1833 }
else if (
auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1834 auto attr = opBuilder.getArrayAttr(elements);
1835 constantMap.try_emplace(resultID, attr, resultType);
1837 return emitError(unknownLoc,
"unsupported OpConstantComposite type: ")
1846 if (operands.size() != 3) {
1849 "OpConstantCompositeReplicateEXT expects 3 operands but found ")
1855 return emitError(unknownLoc,
"undefined result type from <id> ")
1859 auto compositeType = dyn_cast<CompositeType>(resultType);
1860 if (!compositeType) {
1862 "result type from <id> is not a composite type")
1866 uint32_t resultID = operands[1];
1867 uint32_t constantID = operands[2];
1869 std::optional<std::pair<Attribute, Type>> constantInfo =
1871 if (constantInfo.has_value()) {
1872 constantCompositeReplicateMap.try_emplace(
1873 resultID, constantInfo.value().first, resultType);
1877 std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
1879 if (replicatedConstantCompositeInfo.has_value()) {
1880 constantCompositeReplicateMap.try_emplace(
1881 resultID, replicatedConstantCompositeInfo.value().first, resultType);
1885 return emitError(unknownLoc,
"OpConstantCompositeReplicateEXT operand <id> ")
1887 <<
" must come from a normal constant or a "
1888 "OpConstantCompositeReplicateEXT";
1893 if (operands.size() < 2) {
1896 "OpSpecConstantComposite must have type <id> and result <id>");
1898 if (operands.size() < 3) {
1900 "OpSpecConstantComposite must have at least 1 parameter");
1905 return emitError(unknownLoc,
"undefined result type from <id> ")
1909 auto resultID = operands[1];
1913 elements.reserve(operands.size() - 2);
1914 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1916 elements.push_back(SymbolRefAttr::get(elementInfo));
1919 auto op = spirv::SpecConstantCompositeOp::create(
1920 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1921 opBuilder.getArrayAttr(elements));
1922 specConstCompositeMap[resultID] = op;
1929 if (operands.size() != 3) {
1930 return emitError(unknownLoc,
"OpSpecConstantCompositeReplicateEXT expects "
1931 "3 operands but found ")
1937 return emitError(unknownLoc,
"undefined result type from <id> ")
1941 auto compositeType = dyn_cast<CompositeType>(resultType);
1942 if (!compositeType) {
1944 "result type from <id> is not a composite type")
1948 uint32_t resultID = operands[1];
1951 spirv::SpecConstantOp constituentSpecConstantOp =
1953 auto op = spirv::EXTSpecConstantCompositeReplicateOp::create(
1954 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1955 SymbolRefAttr::get(constituentSpecConstantOp));
1957 specConstCompositeReplicateMap[resultID] = op;
1964 if (operands.size() < 3)
1965 return emitError(unknownLoc,
"OpConstantOperation must have type <id>, "
1966 "result <id>, and operand opcode");
1968 uint32_t resultTypeID = operands[0];
1971 return emitError(unknownLoc,
"undefined result type from <id> ")
1974 uint32_t resultID = operands[1];
1975 spirv::Opcode enclosedOpcode =
static_cast<spirv::Opcode
>(operands[2]);
1976 auto emplaceResult = specConstOperationMap.try_emplace(
1979 enclosedOpcode, resultTypeID,
1982 if (!emplaceResult.second)
1983 return emitError(unknownLoc,
"value with <id>: ")
1984 << resultID <<
" is probably defined before.";
1990 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
2006 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
2007 constexpr uint32_t fakeID =
static_cast<uint32_t
>(-3);
2010 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
2011 enclosedOpResultTypeAndOperands.push_back(fakeID);
2012 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
2013 enclosedOpOperands.end());
2028 auto specConstOperationOp =
2029 spirv::SpecConstantOperationOp::create(opBuilder, loc, resultType);
2031 Region &body = specConstOperationOp.getBody();
2033 body.
getBlocks().splice(body.
end(), curBlock->getParent()->getBlocks(),
2040 opBuilder.setInsertionPointToEnd(&block);
2042 spirv::YieldOp::create(opBuilder, loc, block.
front().
getResult(0));
2043 return specConstOperationOp.getResult();
2048 if (operands.size() != 2) {
2050 "OpConstantNull must only have type <id> and result <id>");
2055 return emitError(unknownLoc,
"undefined result type from <id> ")
2059 auto resultID = operands[1];
2061 if (resultType.
isIntOrFloat() || isa<VectorType>(resultType)) {
2062 attr = opBuilder.getZeroAttr(resultType);
2063 }
else if (
auto tensorType = dyn_cast<TensorArmType>(resultType)) {
2064 if (
auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
2071 constantMap.try_emplace(resultID, attr, resultType);
2075 return emitError(unknownLoc,
"unsupported OpConstantNull type: ")
2081 if (operands.size() < 3) {
2083 <<
"OpGraphConstantARM must have at least 2 operands";
2088 return emitError(unknownLoc,
"undefined result type from <id> ")
2092 uint32_t resultID = operands[1];
2094 if (!dyn_cast<spirv::TensorArmType>(resultType)) {
2095 return emitError(unknownLoc,
"result must be of type OpTypeTensorARM");
2098 APInt graph_constant_id = APInt(32, operands[2],
true);
2099 Type i32Ty = opBuilder.getIntegerType(32);
2100 IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
2101 graphConstantMap.try_emplace(
2113 LLVM_DEBUG(logger.startLine() <<
"[block] got exiting block for id = " <<
id
2114 <<
" @ " << block <<
"\n");
2121 auto *block = curFunction->addBlock();
2122 LLVM_DEBUG(logger.startLine() <<
"[block] created block for id = " <<
id
2123 <<
" @ " << block <<
"\n");
2124 return blockMap[id] = block;
2129 return emitError(unknownLoc,
"OpBranch must appear inside a block");
2132 if (operands.size() != 1) {
2133 return emitError(unknownLoc,
"OpBranch must take exactly one target label");
2141 spirv::BranchOp::create(opBuilder, loc,
target);
2151 "OpBranchConditional must appear inside a block");
2154 if (operands.size() != 3 && operands.size() != 5) {
2156 "OpBranchConditional must have condition, true label, "
2157 "false label, and optionally two branch weights");
2160 auto condition =
getValue(operands[0]);
2164 std::optional<std::pair<uint32_t, uint32_t>> weights;
2165 if (operands.size() == 5) {
2166 weights = std::make_pair(operands[3], operands[4]);
2172 spirv::BranchConditionalOp::create(
2173 opBuilder, loc, condition, trueBlock,
2183 return emitError(unknownLoc,
"OpLabel must appear inside a function");
2186 if (operands.size() != 1) {
2187 return emitError(unknownLoc,
"OpLabel should only have result <id>");
2190 auto labelID = operands[0];
2193 LLVM_DEBUG(logger.startLine()
2194 <<
"[block] populating block " << block <<
"\n");
2196 assert(block->empty() &&
"re-deserialize the same block!");
2198 opBuilder.setInsertionPointToStart(block);
2199 blockMap[labelID] = curBlock = block;
2206 return emitError(unknownLoc,
"a graph block must appear inside a graph");
2211 LLVM_DEBUG(logger.startLine()
2212 <<
"[block] populating block " << block <<
"\n");
2214 assert(block->
empty() &&
"re-deserialize the same block!");
2216 opBuilder.setInsertionPointToStart(block);
2217 blockMap[graphID] = curBlock = block;
2225 return emitError(unknownLoc,
"OpSelectionMerge must appear in a block");
2228 if (operands.size() < 2) {
2231 "OpSelectionMerge must specify merge target and selection control");
2236 auto selectionControl = operands[1];
2238 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
2242 "a block cannot have more than one OpSelectionMerge instruction");
2251 return emitError(unknownLoc,
"OpLoopMerge must appear in a block");
2254 if (operands.size() < 3) {
2255 return emitError(unknownLoc,
"OpLoopMerge must specify merge target, "
2256 "continue target and loop control");
2262 uint32_t loopControl = operands[2];
2265 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
2269 "a block cannot have more than one OpLoopMerge instruction");
2277 return emitError(unknownLoc,
"OpPhi must appear in a block");
2280 if (operands.size() < 4) {
2281 return emitError(unknownLoc,
"OpPhi must specify result type, result <id>, "
2282 "and variable-parent pairs");
2287 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
2288 valueMap[operands[1]] = blockArg;
2289 LLVM_DEBUG(logger.startLine()
2290 <<
"[phi] created block argument " << blockArg
2291 <<
" id = " << operands[1] <<
" of type " << blockArgType <<
"\n");
2295 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
2296 uint32_t value = operands[i];
2298 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
2299 blockPhiInfo[predecessorTargetPair].push_back(value);
2300 LLVM_DEBUG(logger.startLine() <<
"[phi] predecessor @ " << predecessor
2301 <<
" with arg id = " << value <<
"\n");
2309 return emitError(unknownLoc,
"OpSwitch must appear in a block");
2311 if (operands.size() < 2)
2312 return emitError(unknownLoc,
"OpSwitch must at least specify selector and "
2313 "a default target");
2315 if (operands.size() % 2)
2317 "OpSwitch must at have an even number of operands: "
2318 "selector, default target and any number of literal and "
2319 "label <id> pairs");
2327 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
2328 literals.push_back(operands[i]);
2333 spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock,
2342class ControlFlowStructurizer {
2345 ControlFlowStructurizer(
Location loc, uint32_t control,
2348 llvm::ScopedPrinter &logger)
2349 : location(loc), control(control), blockMergeInfo(mergeInfo),
2350 headerBlock(header), mergeBlock(merge), continueBlock(cont),
2353 ControlFlowStructurizer(
Location loc, uint32_t control,
2356 : location(loc), control(control), blockMergeInfo(mergeInfo),
2357 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
2367 LogicalResult structurize();
2372 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
2375 spirv::LoopOp createLoopOp(uint32_t loopControl);
2378 void collectBlocksInConstruct();
2387 Block *continueBlock;
2393 llvm::ScopedPrinter &logger;
2399ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
2402 OpBuilder builder(&mergeBlock->front());
2404 auto control =
static_cast<spirv::SelectionControl
>(selectionControl);
2405 auto selectionOp = spirv::SelectionOp::create(builder, location, control);
2406 selectionOp.addMergeBlock(builder);
2411spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
2414 OpBuilder builder(&mergeBlock->front());
2416 auto control =
static_cast<spirv::LoopControl
>(loopControl);
2417 auto loopOp = spirv::LoopOp::create(builder, location, control);
2418 loopOp.addEntryAndMergeBlock(builder);
2423void ControlFlowStructurizer::collectBlocksInConstruct() {
2424 assert(constructBlocks.empty() &&
"expected empty constructBlocks");
2427 constructBlocks.insert(headerBlock);
2431 for (
unsigned i = 0; i < constructBlocks.size(); ++i) {
2432 for (
auto *successor : constructBlocks[i]->getSuccessors())
2433 if (successor != mergeBlock)
2434 constructBlocks.insert(successor);
2438LogicalResult ControlFlowStructurizer::structurize() {
2439 Operation *op =
nullptr;
2440 bool isLoop = continueBlock !=
nullptr;
2442 if (
auto loopOp = createLoopOp(control))
2443 op = loopOp.getOperation();
2445 if (
auto selectionOp = createSelectionOp(control))
2446 op = selectionOp.getOperation();
2455 mapper.
map(mergeBlock, &body.
back());
2457 collectBlocksInConstruct();
2478 OpBuilder builder(body);
2479 for (
auto *block : constructBlocks) {
2482 auto *newBlock = builder.createBlock(&body.
back());
2483 mapper.
map(block, newBlock);
2484 LLVM_DEBUG(logger.startLine() <<
"[cf] cloned block " << newBlock
2485 <<
" from block " << block <<
"\n");
2487 for (BlockArgument blockArg : block->getArguments()) {
2489 newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2490 mapper.
map(blockArg, newArg);
2491 LLVM_DEBUG(logger.startLine() <<
"[cf] remapped block argument "
2492 << blockArg <<
" to " << newArg <<
"\n");
2495 LLVM_DEBUG(logger.startLine()
2496 <<
"[cf] block " << block <<
" is a function entry block\n");
2499 for (
auto &op : *block)
2500 newBlock->push_back(op.
clone(mapper));
2504 auto remapOperands = [&](Operation *op) {
2506 if (Value mappedOp = mapper.
lookupOrNull(operand.get()))
2507 operand.set(mappedOp);
2510 succOp.set(mappedOp);
2512 for (
auto &block : body)
2513 block.walk(remapOperands);
2521 headerBlock->replaceAllUsesWith(mergeBlock);
2524 logger.startLine() <<
"[cf] after cloning and fixing references:\n";
2525 headerBlock->getParentOp()->print(logger.getOStream());
2526 logger.startLine() <<
"\n";
2530 if (!mergeBlock->args_empty()) {
2531 return mergeBlock->getParentOp()->emitError(
2532 "OpPhi in loop merge block unsupported");
2538 for (BlockArgument blockArg : headerBlock->getArguments())
2539 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2543 SmallVector<Value, 4> blockArgs;
2544 if (!headerBlock->args_empty())
2545 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2549 builder.setInsertionPointToEnd(&body.front());
2550 spirv::BranchOp::create(builder, location, mapper.
lookupOrNull(headerBlock),
2551 ArrayRef<Value>(blockArgs));
2556 SmallVector<Value> valuesToYield;
2559 SmallVector<Value> outsideUses;
2573 for (BlockArgument blockArg : mergeBlock->getArguments()) {
2578 body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2579 valuesToYield.push_back(body.back().getArguments().back());
2580 outsideUses.push_back(blockArg);
2585 LLVM_DEBUG(logger.startLine() <<
"[cf] cleaning up blocks after clone\n");
2588 for (
auto *block : constructBlocks)
2589 block->dropAllReferences();
2594 for (
Block *block : constructBlocks) {
2595 for (Operation &op : *block) {
2599 outsideUses.push_back(
result);
2602 for (BlockArgument &arg : block->getArguments()) {
2603 if (!arg.use_empty()) {
2605 outsideUses.push_back(arg);
2610 assert(valuesToYield.size() == outsideUses.size());
2614 if (!valuesToYield.empty()) {
2615 LLVM_DEBUG(logger.startLine()
2616 <<
"[cf] yielding values from the selection / loop region\n");
2619 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2620 Operation *merge = llvm::getSingleElement(mergeOps);
2622 merge->setOperands(valuesToYield);
2630 builder.setInsertionPoint(&mergeBlock->front());
2632 Operation *newOp =
nullptr;
2635 newOp = spirv::LoopOp::create(builder, location,
2637 static_cast<spirv::LoopControl
>(control));
2639 newOp = spirv::SelectionOp::create(
2641 static_cast<spirv::SelectionControl
>(control));
2651 for (
unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2652 outsideUses[i].replaceAllUsesWith(op->
getResult(i));
2658 mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2665 for (
auto *block : constructBlocks) {
2666 if (!block->use_empty())
2667 return emitError(block->getParent()->getLoc(),
2668 "failed control flow structurization: "
2669 "block has uses outside of the "
2670 "enclosing selection/loop construct");
2671 for (Operation &op : *block)
2673 return op.
emitOpError(
"failed control flow structurization: value has "
2674 "uses outside of the "
2675 "enclosing selection/loop construct");
2676 for (BlockArgument &arg : block->getArguments())
2677 if (!arg.use_empty())
2678 return emitError(arg.getLoc(),
"failed control flow structurization: "
2679 "block argument has uses outside of the "
2680 "enclosing selection/loop construct");
2684 for (
auto *block : constructBlocks) {
2724 auto updateMergeInfo = [&](
Block *block) -> WalkResult {
2725 auto it = blockMergeInfo.find(block);
2726 if (it != blockMergeInfo.end()) {
2728 Location loc = it->second.loc;
2732 return emitError(loc,
"failed control flow structurization: nested "
2733 "loop header block should be remapped!");
2735 Block *newContinue = it->second.continueBlock;
2739 return emitError(loc,
"failed control flow structurization: nested "
2740 "loop continue block should be remapped!");
2743 Block *newMerge = it->second.mergeBlock;
2745 newMerge = mappedTo;
2749 blockMergeInfo.
erase(it);
2750 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2757 if (block->walk(updateMergeInfo).wasInterrupted())
2765 LLVM_DEBUG(logger.startLine() <<
"[cf] changing entry block " << block
2766 <<
" to only contain a spirv.Branch op\n");
2770 builder.setInsertionPointToEnd(block);
2771 spirv::BranchOp::create(builder, location, mergeBlock);
2773 LLVM_DEBUG(logger.startLine() <<
"[cf] erasing block " << block <<
"\n");
2778 LLVM_DEBUG(logger.startLine()
2779 <<
"[cf] after structurizing construct with header block "
2780 << headerBlock <<
":\n"
2789 <<
"//----- [phi] start wiring up block arguments -----//\n";
2795 for (
const auto &info : blockPhiInfo) {
2796 Block *block = info.first.first;
2800 logger.startLine() <<
"[phi] block " << block <<
"\n";
2801 logger.startLine() <<
"[phi] before creating block argument:\n";
2803 logger.startLine() <<
"\n";
2809 opBuilder.setInsertionPoint(op);
2812 blockArgs.reserve(phiInfo.size());
2813 for (uint32_t valueId : phiInfo) {
2815 blockArgs.push_back(value);
2816 LLVM_DEBUG(logger.startLine() <<
"[phi] block argument " << value
2817 <<
" id = " << valueId <<
"\n");
2819 return emitError(unknownLoc,
"OpPhi references undefined value!");
2823 if (
auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2825 spirv::BranchOp::create(opBuilder, branchOp.getLoc(),
2826 branchOp.getTarget(), blockArgs);
2828 }
else if (
auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2829 assert((branchCondOp.getTrueBlock() ==
target ||
2830 branchCondOp.getFalseBlock() ==
target) &&
2831 "expected target to be either the true or false target");
2832 if (
target == branchCondOp.getTrueTarget())
2833 spirv::BranchConditionalOp::create(
2834 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2835 blockArgs, branchCondOp.getFalseBlockArguments(),
2836 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2837 branchCondOp.getFalseTarget());
2839 spirv::BranchConditionalOp::create(
2840 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2841 branchCondOp.getTrueBlockArguments(), blockArgs,
2842 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2843 branchCondOp.getFalseBlock());
2845 branchCondOp.erase();
2846 }
else if (
auto switchOp = dyn_cast<spirv::SwitchOp>(op)) {
2847 if (
target == switchOp.getDefaultTarget()) {
2851 spirv::SwitchOp::create(
2852 opBuilder, switchOp.getLoc(), switchOp.getSelector(),
2853 switchOp.getDefaultTarget(), blockArgs, literals,
2854 switchOp.getTargets(), targetOperands);
2858 auto it = llvm::find(targets,
target);
2859 assert(it != targets.end());
2860 size_t index = std::distance(targets.begin(), it);
2861 switchOp.getTargetOperandsMutable(
index).assign(blockArgs);
2864 return emitError(unknownLoc,
"unimplemented terminator for Phi creation");
2868 logger.startLine() <<
"[phi] after creating block argument:\n";
2870 logger.startLine() <<
"\n";
2873 blockPhiInfo.clear();
2878 <<
"//--- [phi] completed wiring up block arguments ---//\n";
2886 for (
auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2888 auto &[block, mergeInfo] = *it;
2891 if (mergeInfo.continueBlock)
2894 if (!block->mightHaveTerminator())
2897 Operation *terminator = block->getTerminator();
2900 if (!isa<spirv::BranchConditionalOp, spirv::SwitchOp>(terminator))
2904 bool splitHeaderMergeBlock =
false;
2905 for (
const auto &[_, mergeInfo] : blockMergeInfo) {
2906 if (mergeInfo.mergeBlock == block)
2907 splitHeaderMergeBlock =
true;
2914 if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2917 spirv::BranchOp::create(builder, block->getParent()->getLoc(), newBlock);
2921 blockMergeInfo.erase(block);
2922 blockMergeInfo.try_emplace(newBlock, mergeInfo);
2930 if (!options.enableControlFlowStructurization) {
2934 <<
"//----- [cf] skip structurizing control flow -----//\n";
2942 <<
"//----- [cf] start structurizing control flow -----//\n";
2947 logger.startLine() <<
"[cf] split conditional blocks\n";
2948 logger.startLine() <<
"\n";
2955 while (!blockMergeInfo.empty()) {
2956 Block *headerBlock = blockMergeInfo.
begin()->first;
2960 logger.startLine() <<
"[cf] header block " << headerBlock <<
":\n";
2961 headerBlock->
print(logger.getOStream());
2962 logger.startLine() <<
"\n";
2966 assert(mergeBlock &&
"merge block cannot be nullptr");
2968 return emitError(unknownLoc,
"OpPhi in loop merge block unimplemented");
2970 logger.startLine() <<
"[cf] merge block " << mergeBlock <<
":\n";
2971 mergeBlock->print(logger.getOStream());
2972 logger.startLine() <<
"\n";
2976 LLVM_DEBUG(
if (continueBlock) {
2977 logger.startLine() <<
"[cf] continue block " << continueBlock <<
":\n";
2978 continueBlock->print(logger.getOStream());
2979 logger.startLine() <<
"\n";
2983 blockMergeInfo.
erase(blockMergeInfo.begin());
2984 ControlFlowStructurizer structurizer(mergeInfo.
loc, mergeInfo.
control,
2985 blockMergeInfo, headerBlock,
2986 mergeBlock, continueBlock
2992 if (failed(structurizer.structurize()))
2999 <<
"//--- [cf] completed structurizing control flow ---//\n";
3012 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
3013 if (fileName.empty())
3014 fileName =
"<unknown>";
3026 if (operands.size() != 3)
3027 return emitError(unknownLoc,
"OpLine must have 3 operands");
3028 debugLine =
DebugLine{operands[0], operands[1], operands[2]};
3036 if (operands.size() < 2)
3037 return emitError(unknownLoc,
"OpString needs at least 2 operands");
3039 if (!debugInfoMap.lookup(operands[0]).empty())
3041 "duplicate debug string found for result <id> ")
3044 unsigned wordIndex = 1;
3046 if (wordIndex != operands.size())
3048 "unexpected trailing words in OpString instruction");
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Operation * getTerminator()
Get the terminator operation of this block.
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
A symbol reference with a reference path containing a single element.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
MutableArrayRef< BlockOperand > getBlockOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool use_empty()
Returns true if this operation has no uses.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MutableArrayRef< OpOperand > getOpOperands()
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
result_range getResults()
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class implements the successor iterators for Block.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static WalkResult advance()
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult wireUpBlockArgument()
Creates block arguments on predecessors previously recorded when handling OpPhi instructions.
Value materializeSpecConstantOperation(uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, ArrayRef< uint32_t > enclosedOpOperands)
Materializes/emits an OpSpecConstantOp instruction.
LogicalResult processOpTypePointer(ArrayRef< uint32_t > operands)
Value getValue(uint32_t id)
Get the Value associated with a result <id>.
LogicalResult processMatrixType(ArrayRef< uint32_t > operands)
LogicalResult processGlobalVariable(ArrayRef< uint32_t > operands)
Processes the OpVariable instructions at current offset into binary.
std::optional< SpecConstOperationMaterializationInfo > getSpecConstantOperation(uint32_t id)
Gets the info needed to materialize the spec constant operation op associated with the given <id>.
LogicalResult processConstantNull(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantNull instruction with the given operands.
LogicalResult processSpecConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantComposite instruction with the given operands.
LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef< uint32_t > operands, bool deferInstructions=true)
Processes a SPIR-V instruction with the given opcode and operands.
LogicalResult processBranchConditional(ArrayRef< uint32_t > operands)
spirv::GlobalVariableOp getGlobalVariable(uint32_t id)
Gets the global variable associated with a result <id> of OpVariable.
LogicalResult createGraphBlock(uint32_t graphID)
Creates a block for graph with the given graphID.
LogicalResult processStructType(ArrayRef< uint32_t > operands)
LogicalResult processGraphARM(ArrayRef< uint32_t > operands)
LogicalResult setFunctionArgAttrs(uint32_t argID, SmallVectorImpl< Attribute > &argAttrs, size_t argIndex)
Sets the function argument's attributes.
LogicalResult structurizeControlFlow()
Extracts blocks belonging to a structured selection/loop into a spirv.mlir.selection/spirv....
LogicalResult processLabel(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLabel instruction with the given operands.
LogicalResult processSampledImageType(ArrayRef< uint32_t > operands)
LogicalResult processTensorARMType(ArrayRef< uint32_t > operands)
std::optional< spirv::GraphConstantARMOpMaterializationInfo > getGraphConstantARM(uint32_t id)
Gets the GraphConstantARM ID attribute and result type with the given result <id>.
std::optional< std::pair< Attribute, Type > > getConstant(uint32_t id)
Gets the constant's attribute and type associated with the given <id>.
LogicalResult processType(spirv::Opcode opcode, ArrayRef< uint32_t > operands)
Processes a SPIR-V type instruction with given opcode and operands and registers the type into module...
LogicalResult processLoopMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLoopMerge instruction with the given operands.
LogicalResult processArrayType(ArrayRef< uint32_t > operands)
LogicalResult sliceInstruction(spirv::Opcode &opcode, ArrayRef< uint32_t > &operands, std::optional< spirv::Opcode > expectedOpcode=std::nullopt)
Slices the first instruction out of binary and returns its opcode and operands via opcode and operand...
spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id)
Gets the composite specialization constant with the given result <id>.
SmallVector< uint32_t, 2 > BlockPhiInfo
For OpPhi instructions, we use block arguments to represent them.
LogicalResult processSpecConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processCooperativeMatrixTypeKHR(ArrayRef< uint32_t > operands)
LogicalResult processGraphEntryPointARM(ArrayRef< uint32_t > operands)
LogicalResult processFunction(ArrayRef< uint32_t > operands)
Creates a deserializer for the given SPIR-V binary module.
StringAttr getSymbolDecoration(StringRef decorationName)
Gets the symbol name from the name of decoration.
Block * getOrCreateBlock(uint32_t id)
Gets or creates the block corresponding to the given label <id>.
bool isVoidType(Type type) const
Returns true if the given type is for SPIR-V void type.
std::string getSpecConstantSymbol(uint32_t id)
Returns a symbol to be used for the specialization constant with the given result <id>.
LogicalResult processDebugString(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpString instruction with the given operands.
LogicalResult processPhi(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpPhi instruction with the given operands.
std::string getFunctionSymbol(uint32_t id)
Returns a symbol to be used for the function name with the given result <id>.
void clearDebugLine()
Discontinues any source-level location information that might be active from a previous OpLine instru...
LogicalResult processFunctionType(ArrayRef< uint32_t > operands)
IntegerAttr getConstantInt(uint32_t id)
Gets the constant's integer attribute with the given <id>.
LogicalResult processTypeForwardPointer(ArrayRef< uint32_t > operands)
LogicalResult processSwitch(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSwitch instruction with the given operands.
LogicalResult processGraphEndARM(ArrayRef< uint32_t > operands)
LogicalResult processImageType(ArrayRef< uint32_t > operands)
LogicalResult processConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantComposite instruction with the given operands.
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, TypedAttr defaultValue)
Creates a spirv::SpecConstantOp.
Block * getBlock(uint32_t id) const
Returns the block for the given label <id>.
LogicalResult processGraphTypeARM(ArrayRef< uint32_t > operands)
LogicalResult processBranch(ArrayRef< uint32_t > operands)
std::optional< std::pair< Attribute, Type > > getConstantCompositeReplicate(uint32_t id)
Gets the replicated composite constant's attribute and type associated with the given <id>.
LogicalResult processFunctionEnd(ArrayRef< uint32_t > operands)
Processes OpFunctionEnd and finalizes function.
LogicalResult processRuntimeArrayType(ArrayRef< uint32_t > operands)
LogicalResult processSpecConstantOperation(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantOp instruction with the given operands.
LogicalResult processConstant(ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant instruction with the given operands.
Location createFileLineColLoc(OpBuilder opBuilder)
Creates a FileLineColLoc with the OpLine location information.
LogicalResult processGraphConstantARM(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpGraphConstantARM instruction with the given operands.
LogicalResult processConstantBool(bool isTrue, ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the given operands.
spirv::SpecConstantOp getSpecConstant(uint32_t id)
Gets the specialization constant with the given result <id>.
LogicalResult processConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processSelectionMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSelectionMerge instruction with the given operands.
LogicalResult processOpGraphSetOutputARM(ArrayRef< uint32_t > operands)
LogicalResult processDebugLine(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLine instruction with the given operands.
LogicalResult splitSelectionHeader()
Move a conditional branch or a switch into a separate basic block to avoid unnecessary sinking of def...
std::string getGraphSymbol(uint32_t id)
Returns a symbol to be used for the graph name with the given result <id>.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr uint32_t kMagicNumber
SPIR-V magic number.
llvm::MapVector< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
llvm::SetVector< T, Vector, Set, N > SetVector
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A struct for containing a header block's merge and continue targets.
A struct for containing OpLine instruction information.
A struct that collects the info needed to materialize/emit a GraphConstantARMOp.
A struct that collects the info needed to materialize/emit a SpecConstantOperation op.