23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/Sequence.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/bit.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/SaveAndRestore.h"
30#include "llvm/Support/raw_ostream.h"
35#define DEBUG_TYPE "spirv-deserialization"
44 isa_and_nonnull<spirv::FuncOp>(block->
getParentOp());
54 : binary(binary), context(context), unknownLoc(UnknownLoc::
get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion()),
options(
options)
63LogicalResult spirv::Deserializer::deserialize() {
67 <<
"//+++---------- start deserialization ----------+++//\n";
70 if (
failed(processHeader()))
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
74 ArrayRef<uint32_t> operands;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
89 for (
auto &deferred : deferredInstructions) {
97 LLVM_DEBUG(logger.startLine()
98 <<
"//+++-------- completed deserialization --------+++//\n");
102OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
103 return std::move(module);
110OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111 OpBuilder builder(context);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
117LogicalResult spirv::Deserializer::processHeader() {
120 "SPIR-V binary module must have a 5-word header");
123 return emitError(unknownLoc,
"incorrect magic number");
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130#define MIN_VERSION_CASE(v) \
132 version = spirv::Version::V_1_##v; \
142#undef MIN_VERSION_CASE
144 return emitError(unknownLoc,
"unsupported SPIR-V minor version: ")
148 return emitError(unknownLoc,
"unsupported SPIR-V major version: ")
158spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
159 if (operands.size() != 1)
160 return emitError(unknownLoc,
"OpCapability must have one parameter");
162 auto cap = spirv::symbolizeCapability(operands[0]);
164 return emitError(unknownLoc,
"unknown capability: ") << operands[0];
166 capabilities.insert(*cap);
170LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
174 "OpExtension must have a literal string for the extension name");
177 unsigned wordIndex = 0;
179 if (wordIndex != words.size())
181 "unexpected trailing words in OpExtension instruction");
182 auto ext = spirv::symbolizeExtension(extName);
184 return emitError(unknownLoc,
"unknown extension: ") << extName;
186 extensions.insert(*ext);
191spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
192 if (words.size() < 2) {
194 "OpExtInstImport must have a result <id> and a literal "
195 "string for the extended instruction set name");
198 unsigned wordIndex = 1;
200 if (wordIndex != words.size()) {
202 "unexpected trailing words in OpExtInstImport");
207void spirv::Deserializer::attachVCETriple() {
209 spirv::ModuleOp::getVCETripleAttrName(),
211 extensions.getArrayRef(), context));
215spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
216 if (operands.size() != 2)
217 return emitError(unknownLoc,
"OpMemoryModel must have two operands");
220 module->getAddressingModelAttrName(),
221 opBuilder.getAttr<spirv::AddressingModelAttr>(
222 static_cast<spirv::AddressingModel
>(operands.front())));
224 (*module)->setAttr(module->getMemoryModelAttrName(),
225 opBuilder.getAttr<spirv::MemoryModelAttr>(
226 static_cast<spirv::MemoryModel
>(operands.back())));
231template <
typename AttrTy,
typename EnumAttrTy,
typename EnumTy>
235 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
236 if (words.size() != 4) {
237 return emitError(loc,
"OpDecorate with ")
238 << decorationName <<
" needs a cache control integer literal and a "
239 << cacheControlKind <<
" cache control literal";
241 unsigned cacheLevel = words[2];
242 auto cacheControlAttr =
static_cast<EnumTy
>(words[3]);
243 auto value = opBuilder.
getAttr<AttrTy>(cacheLevel, cacheControlAttr);
246 dyn_cast_or_null<ArrayAttr>(decorations[words[0]].
get(symbol)))
247 llvm::append_range(attrs, attrList);
248 attrs.push_back(value);
249 decorations[words[0]].set(symbol, opBuilder.
getArrayAttr(attrs));
253LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
257 if (words.size() < 2) {
259 unknownLoc,
"OpDecorate must have at least result <id> and Decoration");
261 auto decorationName =
262 stringifyDecoration(
static_cast<spirv::Decoration
>(words[1]));
263 if (decorationName.empty()) {
264 return emitError(unknownLoc,
"invalid Decoration code : ") << words[1];
266 auto symbol = getSymbolDecoration(decorationName);
267 switch (
static_cast<spirv::Decoration
>(words[1])) {
268 case spirv::Decoration::FPFastMathMode:
269 if (words.size() != 3) {
270 return emitError(unknownLoc,
"OpDecorate with ")
271 << decorationName <<
" needs a single integer literal";
273 decorations[words[0]].set(
274 symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
275 static_cast<FPFastMathMode
>(words[2])));
277 case spirv::Decoration::FPRoundingMode:
278 if (words.size() != 3) {
279 return emitError(unknownLoc,
"OpDecorate with ")
280 << decorationName <<
" needs a single integer literal";
282 decorations[words[0]].set(
283 symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
284 static_cast<FPRoundingMode
>(words[2])));
286 case spirv::Decoration::DescriptorSet:
287 case spirv::Decoration::Binding:
288 case spirv::Decoration::Location:
289 case spirv::Decoration::SpecId:
290 case spirv::Decoration::Index:
291 case spirv::Decoration::Offset:
292 case spirv::Decoration::XfbBuffer:
293 case spirv::Decoration::XfbStride:
294 if (words.size() != 3) {
295 return emitError(unknownLoc,
"OpDecorate with ")
296 << decorationName <<
" needs a single integer literal";
298 decorations[words[0]].set(
299 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
301 case spirv::Decoration::BuiltIn:
302 if (words.size() != 3) {
303 return emitError(unknownLoc,
"OpDecorate with ")
304 << decorationName <<
" needs a single integer literal";
306 decorations[words[0]].set(
307 symbol, opBuilder.getStringAttr(
308 stringifyBuiltIn(
static_cast<spirv::BuiltIn
>(words[2]))));
310 case spirv::Decoration::ArrayStride:
311 if (words.size() != 3) {
312 return emitError(unknownLoc,
"OpDecorate with ")
313 << decorationName <<
" needs a single integer literal";
315 typeDecorations[words[0]] = words[2];
317 case spirv::Decoration::LinkageAttributes: {
318 if (words.size() < 4) {
319 return emitError(unknownLoc,
"OpDecorate with ")
321 <<
" needs at least 1 string and 1 integer literal";
329 unsigned wordIndex = 2;
331 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
332 static_cast<::mlir::spirv::LinkageType
>(words[wordIndex++]));
333 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
334 StringAttr::get(context, linkageName), linkageTypeAttr);
335 decorations[words[0]].set(symbol, dyn_cast<Attribute>(linkageAttr));
338 case spirv::Decoration::Aliased:
339 case spirv::Decoration::AliasedPointer:
340 case spirv::Decoration::Block:
341 case spirv::Decoration::BufferBlock:
342 case spirv::Decoration::Flat:
343 case spirv::Decoration::NonReadable:
344 case spirv::Decoration::NonWritable:
345 case spirv::Decoration::NoPerspective:
346 case spirv::Decoration::NoSignedWrap:
347 case spirv::Decoration::NoUnsignedWrap:
348 case spirv::Decoration::RelaxedPrecision:
349 case spirv::Decoration::Restrict:
350 case spirv::Decoration::RestrictPointer:
351 case spirv::Decoration::NoContraction:
352 case spirv::Decoration::Constant:
353 case spirv::Decoration::Invariant:
354 case spirv::Decoration::Patch:
355 case spirv::Decoration::Coherent:
356 if (words.size() != 2) {
357 return emitError(unknownLoc,
"OpDecorate with ")
358 << decorationName <<
" needs a single target <id>";
360 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
362 case spirv::Decoration::CacheControlLoadINTEL: {
364 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
365 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
371 case spirv::Decoration::CacheControlStoreINTEL: {
373 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
374 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
381 return emitError(unknownLoc,
"unhandled Decoration : '") << decorationName;
387spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
389 if (words.size() < 3) {
391 "OpMemberDecorate must have at least 3 operands");
394 auto decoration =
static_cast<spirv::Decoration
>(words[2]);
395 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
397 " missing offset specification in OpMemberDecorate with "
398 "Offset decoration");
400 ArrayRef<uint32_t> decorationOperands;
401 if (words.size() > 3) {
402 decorationOperands = words.slice(3);
404 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
408LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
409 if (words.size() < 3) {
410 return emitError(unknownLoc,
"OpMemberName must have at least 3 operands");
412 unsigned wordIndex = 2;
414 if (wordIndex != words.size()) {
416 "unexpected trailing words in OpMemberName instruction");
418 memberNameMap[words[0]][words[1]] = name;
424 if (!decorations.contains(argID)) {
425 argAttrs[argIndex] = DictionaryAttr::get(context, {});
429 spirv::DecorationAttr foundDecorationAttr;
431 for (
auto decoration :
432 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
433 spirv::Decoration::AliasedPointer,
434 spirv::Decoration::RestrictPointer}) {
436 if (decAttr.getName() !=
440 if (foundDecorationAttr)
442 "more than one Aliased/Restrict decorations for "
443 "function argument with result <id> ")
446 foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
451 spirv::Decoration::RelaxedPrecision))) {
456 if (foundDecorationAttr)
457 return emitError(unknownLoc,
"already found a decoration for function "
458 "argument with result <id> ")
461 foundDecorationAttr = spirv::DecorationAttr::get(
462 context, spirv::Decoration::RelaxedPrecision);
466 if (!foundDecorationAttr)
467 return emitError(unknownLoc,
"unimplemented decoration support for "
468 "function argument with result <id> ")
471 NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
472 foundDecorationAttr);
473 argAttrs[argIndex] = DictionaryAttr::get(context, attr);
480 return emitError(unknownLoc,
"found function inside function");
484 if (operands.size() != 4) {
485 return emitError(unknownLoc,
"OpFunction must have 4 parameters");
489 return emitError(unknownLoc,
"undefined result type from <id> ")
493 uint32_t fnID = operands[1];
494 if (funcMap.count(fnID)) {
495 return emitError(unknownLoc,
"duplicate function definition/declaration");
498 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
500 return emitError(unknownLoc,
"unknown Function Control: ") << operands[2];
504 if (!fnType || !isa<FunctionType>(fnType)) {
505 return emitError(unknownLoc,
"unknown function type from <id> ")
508 auto functionType = cast<FunctionType>(fnType);
510 if ((
isVoidType(resultType) && functionType.getNumResults() != 0) ||
511 (functionType.getNumResults() == 1 &&
512 functionType.getResult(0) != resultType)) {
513 return emitError(unknownLoc,
"mismatch in function type ")
514 << functionType <<
" and return type " << resultType <<
" specified";
518 auto funcOp = spirv::FuncOp::create(opBuilder, unknownLoc, fnName,
519 functionType, fnControl.value());
521 if (decorations.count(fnID)) {
522 for (
auto attr : decorations[fnID].getAttrs()) {
523 funcOp->setAttr(attr.getName(), attr.getValue());
526 curFunction = funcMap[fnID] = funcOp;
527 auto *entryBlock = funcOp.addEntryBlock();
530 <<
"//===-------------------------------------------===//\n";
531 logger.startLine() <<
"[fn] name: " << fnName <<
"\n";
532 logger.startLine() <<
"[fn] type: " << fnType <<
"\n";
533 logger.startLine() <<
"[fn] ID: " << fnID <<
"\n";
534 logger.startLine() <<
"[fn] entry block: " << entryBlock <<
"\n";
539 argAttrs.resize(functionType.getNumInputs());
542 if (functionType.getNumInputs()) {
543 for (
size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
544 auto argType = functionType.getInput(i);
545 spirv::Opcode opcode = spirv::Opcode::OpNop;
548 spirv::Opcode::OpFunctionParameter))) {
551 if (opcode != spirv::Opcode::OpFunctionParameter) {
554 "missing OpFunctionParameter instruction for argument ")
557 if (operands.size() != 2) {
560 "expected result type and result <id> for OpFunctionParameter");
562 auto argDefinedType =
getType(operands[0]);
563 if (!argDefinedType || argDefinedType != argType) {
565 "mismatch in argument type between function type "
567 << functionType <<
" and argument type definition "
568 << argDefinedType <<
" at argument " << i;
571 return emitError(unknownLoc,
"duplicate definition of result <id> ")
578 auto argValue = funcOp.getArgument(i);
579 valueMap[operands[1]] = argValue;
583 if (llvm::any_of(argAttrs, [](
Attribute attr) {
584 auto argAttr = cast<DictionaryAttr>(attr);
585 return !argAttr.empty();
587 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
592 auto linkageAttr = funcOp.getLinkageAttributes();
593 auto hasImportLinkage =
594 linkageAttr && (linkageAttr.value().getLinkageType().
getValue() ==
595 spirv::LinkageType::Import);
596 if (hasImportLinkage)
603 spirv::Opcode opcode = spirv::Opcode::OpNop;
612 spirv::Opcode::OpFunctionEnd))) {
615 if (opcode == spirv::Opcode::OpFunctionEnd) {
618 if (opcode != spirv::Opcode::OpLabel) {
619 return emitError(unknownLoc,
"a basic block must start with OpLabel");
621 if (instOperands.size() != 1) {
622 return emitError(unknownLoc,
"OpLabel should only have result <id>");
624 blockMap[instOperands[0]] = entryBlock;
632 spirv::Opcode::OpFunctionEnd)) &&
633 opcode != spirv::Opcode::OpFunctionEnd) {
638 if (opcode != spirv::Opcode::OpFunctionEnd) {
648 if (!operands.empty()) {
649 return emitError(unknownLoc,
"unexpected operands for OpFunctionEnd");
660 curFunction = std::nullopt;
665 <<
"//===-------------------------------------------===//\n";
672 if (operands.size() < 2) {
674 "missing graph defintion in OpGraphEntryPointARM");
677 unsigned wordIndex = 0;
678 uint32_t graphID = operands[wordIndex++];
679 if (!graphMap.contains(graphID)) {
681 "missing graph definition/declaration with id ")
685 spirv::GraphARMOp graphARM = graphMap[graphID];
687 graphARM.setSymName(name);
688 graphARM.setEntryPoint(
true);
691 for (
int64_t size = operands.size(); wordIndex < size; ++wordIndex) {
693 interface.push_back(SymbolRefAttr::get(arg.getOperation()));
695 return emitError(unknownLoc,
"undefined result <id> ")
696 << operands[wordIndex] <<
" while decoding OpGraphEntryPoint";
702 opBuilder.setInsertionPoint(graphARM);
703 spirv::GraphEntryPointARMOp::create(
704 opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
705 opBuilder.getArrayAttr(interface));
713 return emitError(unknownLoc,
"found graph inside graph");
716 if (operands.size() < 2) {
717 return emitError(unknownLoc,
"OpGraphARM must have at least 2 parameters");
721 if (!type || !isa<GraphType>(type)) {
722 return emitError(unknownLoc,
"unknown graph type from <id> ")
725 auto graphType = cast<GraphType>(type);
726 if (graphType.getNumResults() <= 0) {
727 return emitError(unknownLoc,
"expected at least one result");
730 uint32_t graphID = operands[1];
731 if (graphMap.count(graphID)) {
732 return emitError(unknownLoc,
"duplicate graph definition/declaration");
737 spirv::GraphARMOp::create(opBuilder, unknownLoc, graphName, graphType);
738 curGraph = graphMap[graphID] = graphOp;
739 Block *entryBlock = graphOp.addEntryBlock();
742 <<
"//===-------------------------------------------===//\n";
743 logger.startLine() <<
"[graph] name: " << graphName <<
"\n";
744 logger.startLine() <<
"[graph] type: " << graphType <<
"\n";
745 logger.startLine() <<
"[graph] ID: " << graphID <<
"\n";
746 logger.startLine() <<
"[graph] entry block: " << entryBlock <<
"\n";
751 for (
auto [
index, argType] : llvm::enumerate(graphType.getInputs())) {
752 spirv::Opcode opcode;
755 spirv::Opcode::OpGraphInputARM))) {
758 if (operands.size() != 3) {
759 return emitError(unknownLoc,
"expected result type, result <id> and "
760 "input index for OpGraphInputARM");
764 if (!argDefinedType) {
765 return emitError(unknownLoc,
"unknown operand type <id> ") << operands[0];
768 if (argDefinedType != argType) {
770 "mismatch in argument type between graph type "
772 << graphType <<
" and argument type definition " << argDefinedType
773 <<
" at argument " <<
index;
776 return emitError(unknownLoc,
"duplicate definition of result <id> ")
781 if (!inputIndexAttr) {
783 "unable to read inputIndex value from constant op ")
786 BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
787 valueMap[operands[1]] = argValue;
790 graphOutputs.resize(graphType.getNumResults());
796 blockMap[graphID] = entryBlock;
803 spirv::Opcode opcode;
813 }
while (opcode != spirv::Opcode::OpGraphEndARM);
820 if (operands.size() != 2) {
823 "expected value id and output index for OpGraphSetOutputARM");
826 uint32_t
id = operands[0];
829 return emitError(unknownLoc,
"could not find result <id> ") << id;
833 if (!outputIndexAttr) {
835 "unable to read outputIndex value from constant op ")
838 graphOutputs[outputIndexAttr.getInt()] = value;
845 spirv::GraphOutputsARMOp::create(opBuilder, unknownLoc, graphOutputs);
848 if (!operands.empty()) {
849 return emitError(unknownLoc,
"unexpected operands for OpGraphEndARM");
853 curGraph = std::nullopt;
854 graphOutputs.clear();
859 <<
"//===-------------------------------------------===//\n";
864std::optional<std::pair<Attribute, Type>>
866 auto constIt = constantMap.find(
id);
867 if (constIt == constantMap.end())
869 return constIt->getSecond();
872std::optional<std::pair<Attribute, Type>>
874 if (
auto it = constantCompositeReplicateMap.find(
id);
875 it != constantCompositeReplicateMap.end())
880std::optional<spirv::SpecConstOperationMaterializationInfo>
882 auto constIt = specConstOperationMap.find(
id);
883 if (constIt == specConstOperationMap.end())
885 return constIt->getSecond();
889 auto funcName = nameMap.lookup(
id).str();
890 if (funcName.empty()) {
891 funcName =
"spirv_fn_" + std::to_string(
id);
897 std::string graphName = nameMap.lookup(
id).str();
898 if (graphName.empty()) {
899 graphName =
"spirv_graph_" + std::to_string(
id);
905 auto constName = nameMap.lookup(
id).str();
906 if (constName.empty()) {
907 constName =
"spirv_spec_const_" + std::to_string(
id);
914 TypedAttr defaultValue) {
916 auto op = spirv::SpecConstantOp::create(opBuilder, unknownLoc, symName,
918 if (decorations.count(resultID)) {
919 for (
auto attr : decorations[resultID].getAttrs())
920 op->setAttr(attr.getName(), attr.getValue());
922 specConstMap[resultID] = op;
926std::optional<spirv::GraphConstantARMOpMaterializationInfo>
928 auto graphConstIt = graphConstantMap.find(
id);
929 if (graphConstIt == graphConstantMap.end())
931 return graphConstIt->getSecond();
936 unsigned wordIndex = 0;
937 if (operands.size() < 3) {
940 "OpVariable needs at least 3 operands, type, <id> and storage class");
944 auto type =
getType(operands[wordIndex]);
946 return emitError(unknownLoc,
"unknown result type <id> : ")
947 << operands[wordIndex];
949 auto ptrType = dyn_cast<spirv::PointerType>(type);
952 "expected a result type <id> to be a spirv.ptr, found : ")
958 auto variableID = operands[wordIndex];
959 auto variableName = nameMap.lookup(variableID).str();
960 if (variableName.empty()) {
961 variableName =
"spirv_var_" + std::to_string(variableID);
966 auto storageClass =
static_cast<spirv::StorageClass
>(operands[wordIndex]);
967 if (ptrType.getStorageClass() != storageClass) {
968 return emitError(unknownLoc,
"mismatch in storage class of pointer type ")
969 << type <<
" and that specified in OpVariable instruction : "
970 << stringifyStorageClass(storageClass);
977 if (wordIndex < operands.size()) {
987 return emitError(unknownLoc,
"unknown <id> ")
988 << operands[wordIndex] <<
"used as initializer";
990 initializer = SymbolRefAttr::get(op);
993 if (wordIndex != operands.size()) {
995 "found more operands than expected when deserializing "
996 "OpVariable instruction, only ")
997 << wordIndex <<
" of " << operands.size() <<
" processed";
1000 auto varOp = spirv::GlobalVariableOp::create(
1001 opBuilder, loc, TypeAttr::get(type),
1002 opBuilder.getStringAttr(variableName), initializer);
1005 if (decorations.count(variableID)) {
1006 for (
auto attr : decorations[variableID].getAttrs())
1007 varOp->setAttr(attr.getName(), attr.getValue());
1009 globalVariableMap[variableID] = varOp;
1018 return dyn_cast<IntegerAttr>(constInfo->first);
1022 if (operands.size() < 2) {
1023 return emitError(unknownLoc,
"OpName needs at least 2 operands");
1026 unsigned wordIndex = 1;
1028 if (wordIndex != operands.size()) {
1030 "unexpected trailing words in OpName instruction");
1035 nameMap.emplace_or_assign(operands[0], name);
1046 if (operands.empty()) {
1047 return emitError(unknownLoc,
"type instruction with opcode ")
1048 << spirv::stringifyOpcode(opcode) <<
" needs at least one <id>";
1053 if (typeMap.count(operands[0])) {
1054 return emitError(unknownLoc,
"duplicate definition for result <id> ")
1059 case spirv::Opcode::OpTypeVoid:
1060 if (operands.size() != 1)
1061 return emitError(unknownLoc,
"OpTypeVoid must have no parameters");
1062 typeMap[operands[0]] = opBuilder.getNoneType();
1064 case spirv::Opcode::OpTypeBool:
1065 if (operands.size() != 1)
1066 return emitError(unknownLoc,
"OpTypeBool must have no parameters");
1067 typeMap[operands[0]] = opBuilder.getI1Type();
1069 case spirv::Opcode::OpTypeInt: {
1070 if (operands.size() != 3)
1072 unknownLoc,
"OpTypeInt must have bitwidth and signedness parameters");
1081 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
1082 : IntegerType::SignednessSemantics::Signless;
1083 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
1085 case spirv::Opcode::OpTypeFloat: {
1086 if (operands.size() != 2 && operands.size() != 3)
1088 "OpTypeFloat expects either 2 operands (type, bitwidth) "
1089 "or 3 operands (type, bitwidth, encoding), but got ")
1091 uint32_t bitWidth = operands[1];
1094 if (operands.size() == 2) {
1097 floatTy = opBuilder.getF16Type();
1100 floatTy = opBuilder.getF32Type();
1103 floatTy = opBuilder.getF64Type();
1106 return emitError(unknownLoc,
"unsupported OpTypeFloat bitwidth: ")
1111 if (operands.size() == 3) {
1112 if (spirv::FPEncoding(operands[2]) == spirv::FPEncoding::BFloat16KHR &&
1114 floatTy = opBuilder.getBF16Type();
1115 else if (spirv::FPEncoding(operands[2]) ==
1116 spirv::FPEncoding::Float8E4M3EXT &&
1118 floatTy = opBuilder.getF8E4M3FNType();
1119 else if (spirv::FPEncoding(operands[2]) ==
1120 spirv::FPEncoding::Float8E5M2EXT &&
1122 floatTy = opBuilder.getF8E5M2Type();
1124 return emitError(unknownLoc,
"unsupported OpTypeFloat FP encoding: ")
1125 << operands[2] <<
" and bitWidth " << bitWidth;
1128 typeMap[operands[0]] = floatTy;
1130 case spirv::Opcode::OpTypeVector: {
1131 if (operands.size() != 3) {
1134 "OpTypeVector must have element type and count parameters");
1138 return emitError(unknownLoc,
"OpTypeVector references undefined <id> ")
1141 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
1143 case spirv::Opcode::OpTypePointer: {
1146 case spirv::Opcode::OpTypeArray:
1148 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
1150 case spirv::Opcode::OpTypeFunction:
1152 case spirv::Opcode::OpTypeImage:
1154 case spirv::Opcode::OpTypeSampler:
1156 case spirv::Opcode::OpTypeSampledImage:
1158 case spirv::Opcode::OpTypeRuntimeArray:
1160 case spirv::Opcode::OpTypeStruct:
1162 case spirv::Opcode::OpTypeMatrix:
1164 case spirv::Opcode::OpTypeTensorARM:
1166 case spirv::Opcode::OpTypeGraphARM:
1169 return emitError(unknownLoc,
"unhandled type instruction");
1176 if (operands.size() != 3)
1177 return emitError(unknownLoc,
"OpTypePointer must have two parameters");
1179 auto pointeeType =
getType(operands[2]);
1181 return emitError(unknownLoc,
"unknown OpTypePointer pointee type <id> ")
1184 uint32_t typePointerID = operands[0];
1185 auto storageClass =
static_cast<spirv::StorageClass
>(operands[1]);
1188 for (
auto *deferredStructIt = std::begin(deferredStructTypesInfos);
1189 deferredStructIt != std::end(deferredStructTypesInfos);) {
1190 for (
auto *unresolvedMemberIt =
1191 std::begin(deferredStructIt->unresolvedMemberTypes);
1192 unresolvedMemberIt !=
1193 std::end(deferredStructIt->unresolvedMemberTypes);) {
1194 if (unresolvedMemberIt->first == typePointerID) {
1198 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
1199 typeMap[typePointerID];
1200 unresolvedMemberIt =
1201 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
1203 ++unresolvedMemberIt;
1207 if (deferredStructIt->unresolvedMemberTypes.empty()) {
1209 auto structType = deferredStructIt->deferredStructType;
1211 assert(structType &&
"expected a spirv::StructType");
1212 assert(structType.isIdentified() &&
"expected an indentified struct");
1214 if (failed(structType.trySetBody(
1215 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
1216 deferredStructIt->memberDecorationsInfo,
1217 deferredStructIt->structDecorationsInfo)))
1220 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
1231 if (operands.size() != 3) {
1233 "OpTypeArray must have element type and count parameters");
1238 return emitError(unknownLoc,
"OpTypeArray references undefined <id> ")
1246 return emitError(unknownLoc,
"OpTypeArray count <id> ")
1247 << operands[2] <<
"can only come from normal constant right now";
1250 if (
auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1251 count = intVal.getValue().getZExtValue();
1253 return emitError(unknownLoc,
"OpTypeArray count must come from a "
1254 "scalar integer constant instruction");
1258 elementTy, count, typeDecorations.lookup(operands[0]));
1264 assert(!operands.empty() &&
"No operands for processing function type");
1265 if (operands.size() == 1) {
1266 return emitError(unknownLoc,
"missing return type for OpTypeFunction");
1268 auto returnType =
getType(operands[1]);
1270 return emitError(unknownLoc,
"unknown return type in OpTypeFunction");
1273 for (
size_t i = 2, e = operands.size(); i < e; ++i) {
1274 auto ty =
getType(operands[i]);
1276 return emitError(unknownLoc,
"unknown argument type in OpTypeFunction");
1278 argTypes.push_back(ty);
1284 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1290 if (operands.size() != 6) {
1292 "OpTypeCooperativeMatrixKHR must have element type, "
1293 "scope, row and column parameters, and use");
1299 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1303 std::optional<spirv::Scope> scope =
1308 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1317 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Rows` references "
1318 "undefined constant <id> ")
1322 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Columns` "
1323 "references undefined constant <id> ")
1327 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Use` references "
1328 "undefined constant <id> ")
1331 unsigned rows = rowsAttr.getInt();
1332 unsigned columns = columnsAttr.getInt();
1334 std::optional<spirv::CooperativeMatrixUseKHR> use =
1335 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1339 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1343 typeMap[operands[0]] =
1350 if (operands.size() != 2) {
1351 return emitError(unknownLoc,
"OpTypeRuntimeArray must have two operands");
1356 "OpTypeRuntimeArray references undefined <id> ")
1360 memberType, typeDecorations.lookup(operands[0]));
1368 if (operands.empty()) {
1369 return emitError(unknownLoc,
"OpTypeStruct must have at least result <id>");
1372 if (operands.size() == 1) {
1374 typeMap[operands[0]] =
1383 for (
auto op : llvm::drop_begin(operands, 1)) {
1385 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1387 if (!memberType && !typeForwardPtr)
1388 return emitError(unknownLoc,
"OpTypeStruct references undefined <id> ")
1392 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1394 memberTypes.push_back(memberType);
1399 if (memberDecorationMap.count(operands[0])) {
1400 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1401 for (
auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1402 if (allMemberDecorations.count(memberIndex)) {
1403 for (
auto &memberDecoration : allMemberDecorations[memberIndex]) {
1405 if (memberDecoration.first == spirv::Decoration::Offset) {
1407 if (offsetInfo.empty()) {
1408 offsetInfo.resize(memberTypes.size());
1410 offsetInfo[memberIndex] = memberDecoration.second[0];
1412 auto intType = mlir::IntegerType::get(context, 32);
1413 if (!memberDecoration.second.empty()) {
1414 memberDecorationsInfo.emplace_back(
1415 memberIndex, memberDecoration.first,
1416 IntegerAttr::get(intType, memberDecoration.second[0]));
1418 memberDecorationsInfo.emplace_back(
1419 memberIndex, memberDecoration.first, UnitAttr::get(context));
1428 if (decorations.count(operands[0])) {
1431 std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
1432 llvm::convertToCamelFromSnakeCase(decorationAttr.getName(),
true));
1433 assert(decoration.has_value());
1434 structDecorationsInfo.emplace_back(decoration.value(),
1435 decorationAttr.getValue());
1439 uint32_t structID = operands[0];
1440 std::string structIdentifier = nameMap.lookup(structID).str();
1442 if (structIdentifier.empty()) {
1443 assert(unresolvedMemberTypes.empty() &&
1444 "didn't expect unresolved member types");
1446 memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
1449 typeMap[structID] = structTy;
1451 if (!unresolvedMemberTypes.empty())
1452 deferredStructTypesInfos.push_back(
1453 {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
1454 memberDecorationsInfo, structDecorationsInfo});
1455 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1456 memberDecorationsInfo,
1457 structDecorationsInfo)))
1468 if (operands.size() != 3) {
1470 return emitError(unknownLoc,
"OpTypeMatrix must have 3 operands"
1471 " (result_id, column_type, and column_count)");
1477 "OpTypeMatrix references undefined column type.")
1481 uint32_t colsCount = operands[2];
1488 unsigned size = operands.size();
1489 if (size < 2 || size > 4)
1490 return emitError(unknownLoc,
"OpTypeTensorARM must have 2-4 operands "
1491 "(result_id, element_type, (rank), (shape)) ")
1497 "OpTypeTensorARM references undefined element type ")
1507 return emitError(unknownLoc,
"OpTypeTensorARM rank must come from a "
1508 "scalar integer constant instruction");
1509 unsigned rank = rankAttr.getValue().getZExtValue();
1516 std::optional<std::pair<Attribute, Type>> shapeInfo =
1519 return emitError(unknownLoc,
"OpTypeTensorARM shape must come from a "
1520 "constant instruction of type OpTypeArray");
1522 ArrayAttr shapeArrayAttr = dyn_cast<ArrayAttr>(shapeInfo->first);
1524 for (
auto dimAttr : shapeArrayAttr.getValue()) {
1525 auto dimIntAttr = dyn_cast<IntegerAttr>(dimAttr);
1527 return emitError(unknownLoc,
"OpTypeTensorARM shape has an invalid "
1529 shape.push_back(dimIntAttr.getValue().getSExtValue());
1537 unsigned size = operands.size();
1539 return emitError(unknownLoc,
"OpTypeGraphARM must have at least 2 operands "
1540 "(result_id, num_inputs, (inout0_type, "
1541 "inout1_type, ...))")
1544 uint32_t numInputs = operands[1];
1547 for (
unsigned i = 2; i < size; ++i) {
1551 "OpTypeGraphARM references undefined element type.")
1554 if (i - 2 >= numInputs) {
1555 returnTypes.push_back(inOutTy);
1557 argTypes.push_back(inOutTy);
1560 typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
1566 if (operands.size() != 2)
1568 "OpTypeForwardPointer instruction must have two operands");
1570 typeForwardPointerIDs.insert(operands[0]);
1580 if (operands.size() != 8)
1583 "OpTypeImage with non-eight operands are not supported yet");
1587 return emitError(unknownLoc,
"OpTypeImage references undefined <id>: ")
1590 auto dim = spirv::symbolizeDim(operands[2]);
1592 return emitError(unknownLoc,
"unknown Dim for OpTypeImage: ")
1595 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1597 return emitError(unknownLoc,
"unknown Depth for OpTypeImage: ")
1600 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1602 return emitError(unknownLoc,
"unknown Arrayed for OpTypeImage: ")
1605 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1607 return emitError(unknownLoc,
"unknown MS for OpTypeImage: ") << operands[5];
1609 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1610 if (!samplerUseInfo)
1611 return emitError(unknownLoc,
"unknown Sampled for OpTypeImage: ")
1614 auto format = spirv::symbolizeImageFormat(operands[7]);
1616 return emitError(unknownLoc,
"unknown Format for OpTypeImage: ")
1620 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1621 samplingInfo.value(), samplerUseInfo.value(), format.value());
1627 if (operands.size() != 2)
1628 return emitError(unknownLoc,
"OpTypeSampledImage must have two operands");
1633 "OpTypeSampledImage references undefined <id>: ")
1642 if (operands.size() != 1)
1643 return emitError(unknownLoc,
"OpTypeSampler must have no parameters");
1655 StringRef opname = isSpec ?
"OpSpecConstant" :
"OpConstant";
1657 if (operands.size() < 2) {
1659 << opname <<
" must have type <id> and result <id>";
1661 if (operands.size() < 3) {
1663 << opname <<
" must have at least 1 more parameter";
1668 return emitError(unknownLoc,
"undefined result type from <id> ")
1672 auto checkOperandSizeForBitwidth = [&](
unsigned bitwidth) -> LogicalResult {
1673 if (bitwidth == 64) {
1674 if (operands.size() == 4) {
1678 << opname <<
" should have 2 parameters for 64-bit values";
1680 if (bitwidth <= 32) {
1681 if (operands.size() == 3) {
1687 <<
" should have 1 parameter for values with no more than 32 bits";
1689 return emitError(unknownLoc,
"unsupported OpConstant bitwidth: ")
1693 auto resultID = operands[1];
1695 if (
auto intType = dyn_cast<IntegerType>(resultType)) {
1696 auto bitwidth = intType.getWidth();
1697 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1702 if (bitwidth == 64) {
1709 } words = {operands[2], operands[3]};
1710 value = APInt(64, llvm::bit_cast<uint64_t>(words),
true);
1711 }
else if (bitwidth <= 32) {
1712 value = APInt(bitwidth, operands[2],
true,
1716 auto attr = opBuilder.getIntegerAttr(intType, value);
1723 constantMap.try_emplace(resultID, attr, intType);
1729 if (
auto floatType = dyn_cast<FloatType>(resultType)) {
1730 auto bitwidth = floatType.getWidth();
1731 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1736 if (floatType.isF64()) {
1743 } words = {operands[2], operands[3]};
1744 value = APFloat(llvm::bit_cast<double>(words));
1745 }
else if (floatType.isF32()) {
1746 value = APFloat(llvm::bit_cast<float>(operands[2]));
1747 }
else if (floatType.isF16()) {
1748 APInt data(16, operands[2]);
1749 value = APFloat(APFloat::IEEEhalf(), data);
1750 }
else if (floatType.isBF16()) {
1751 APInt data(16, operands[2]);
1752 value = APFloat(APFloat::BFloat(), data);
1753 }
else if (floatType.isF8E4M3FN()) {
1754 APInt data(8, operands[2]);
1755 value = APFloat(APFloat::Float8E4M3FN(), data);
1756 }
else if (floatType.isF8E5M2()) {
1757 APInt data(8, operands[2]);
1758 value = APFloat(APFloat::Float8E5M2(), data);
1761 auto attr = opBuilder.getFloatAttr(floatType, value);
1767 constantMap.try_emplace(resultID, attr, floatType);
1773 return emitError(unknownLoc,
"OpConstant can only generate values of "
1774 "scalar integer or floating-point type");
1779 if (operands.size() != 2) {
1781 << (isSpec ?
"Spec" :
"") <<
"Constant"
1782 << (isTrue ?
"True" :
"False")
1783 <<
" must have type <id> and result <id>";
1786 auto attr = opBuilder.getBoolAttr(isTrue);
1787 auto resultID = operands[1];
1793 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1801 if (operands.size() < 2) {
1803 "OpConstantComposite must have type <id> and result <id>");
1805 if (operands.size() < 3) {
1807 "OpConstantComposite must have at least 1 parameter");
1812 return emitError(unknownLoc,
"undefined result type from <id> ")
1817 elements.reserve(operands.size() - 2);
1818 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1821 return emitError(unknownLoc,
"OpConstantComposite component <id> ")
1822 << operands[i] <<
" must come from a normal constant";
1824 elements.push_back(elementInfo->first);
1827 auto resultID = operands[1];
1828 if (
auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1831 if (
auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
1832 for (
auto value : denseElemAttr.getValues<
Attribute>())
1833 flattenedElems.push_back(value);
1835 flattenedElems.push_back(element);
1839 constantMap.try_emplace(resultID, attr, tensorType);
1840 }
else if (
auto shapedType = dyn_cast<ShapedType>(resultType)) {
1844 constantMap.try_emplace(resultID, attr, shapedType);
1845 }
else if (
auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1846 auto attr = opBuilder.getArrayAttr(elements);
1847 constantMap.try_emplace(resultID, attr, resultType);
1849 return emitError(unknownLoc,
"unsupported OpConstantComposite type: ")
1858 if (operands.size() != 3) {
1861 "OpConstantCompositeReplicateEXT expects 3 operands but found ")
1867 return emitError(unknownLoc,
"undefined result type from <id> ")
1871 auto compositeType = dyn_cast<CompositeType>(resultType);
1872 if (!compositeType) {
1874 "result type from <id> is not a composite type")
1878 uint32_t resultID = operands[1];
1879 uint32_t constantID = operands[2];
1881 std::optional<std::pair<Attribute, Type>> constantInfo =
1883 if (constantInfo.has_value()) {
1884 constantCompositeReplicateMap.try_emplace(
1885 resultID, constantInfo.value().first, resultType);
1889 std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
1891 if (replicatedConstantCompositeInfo.has_value()) {
1892 constantCompositeReplicateMap.try_emplace(
1893 resultID, replicatedConstantCompositeInfo.value().first, resultType);
1897 return emitError(unknownLoc,
"OpConstantCompositeReplicateEXT operand <id> ")
1899 <<
" must come from a normal constant or a "
1900 "OpConstantCompositeReplicateEXT";
1905 if (operands.size() < 2) {
1908 "OpSpecConstantComposite must have type <id> and result <id>");
1910 if (operands.size() < 3) {
1912 "OpSpecConstantComposite must have at least 1 parameter");
1917 return emitError(unknownLoc,
"undefined result type from <id> ")
1921 auto resultID = operands[1];
1925 elements.reserve(operands.size() - 2);
1926 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1928 elements.push_back(SymbolRefAttr::get(elementInfo));
1931 auto op = spirv::SpecConstantCompositeOp::create(
1932 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1933 opBuilder.getArrayAttr(elements));
1934 specConstCompositeMap[resultID] = op;
1941 if (operands.size() != 3) {
1942 return emitError(unknownLoc,
"OpSpecConstantCompositeReplicateEXT expects "
1943 "3 operands but found ")
1949 return emitError(unknownLoc,
"undefined result type from <id> ")
1953 auto compositeType = dyn_cast<CompositeType>(resultType);
1954 if (!compositeType) {
1956 "result type from <id> is not a composite type")
1960 uint32_t resultID = operands[1];
1963 spirv::SpecConstantOp constituentSpecConstantOp =
1965 auto op = spirv::EXTSpecConstantCompositeReplicateOp::create(
1966 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1967 SymbolRefAttr::get(constituentSpecConstantOp));
1969 specConstCompositeReplicateMap[resultID] = op;
1976 if (operands.size() < 3)
1977 return emitError(unknownLoc,
"OpConstantOperation must have type <id>, "
1978 "result <id>, and operand opcode");
1980 uint32_t resultTypeID = operands[0];
1983 return emitError(unknownLoc,
"undefined result type from <id> ")
1986 uint32_t resultID = operands[1];
1987 spirv::Opcode enclosedOpcode =
static_cast<spirv::Opcode
>(operands[2]);
1988 auto emplaceResult = specConstOperationMap.try_emplace(
1991 enclosedOpcode, resultTypeID,
1994 if (!emplaceResult.second)
1995 return emitError(unknownLoc,
"value with <id>: ")
1996 << resultID <<
" is probably defined before.";
2002 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
2018 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
2019 constexpr uint32_t fakeID =
static_cast<uint32_t
>(-3);
2022 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
2023 enclosedOpResultTypeAndOperands.push_back(fakeID);
2024 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
2025 enclosedOpOperands.end());
2040 auto specConstOperationOp =
2041 spirv::SpecConstantOperationOp::create(opBuilder, loc, resultType);
2043 Region &body = specConstOperationOp.getBody();
2045 body.
getBlocks().splice(body.
end(), curBlock->getParent()->getBlocks(),
2052 opBuilder.setInsertionPointToEnd(&block);
2054 spirv::YieldOp::create(opBuilder, loc, block.
front().
getResult(0));
2055 return specConstOperationOp.getResult();
2060 if (operands.size() != 2) {
2062 "OpConstantNull must only have type <id> and result <id>");
2067 return emitError(unknownLoc,
"undefined result type from <id> ")
2071 auto resultID = operands[1];
2073 if (resultType.
isIntOrFloat() || isa<VectorType>(resultType)) {
2074 attr = opBuilder.getZeroAttr(resultType);
2075 }
else if (
auto tensorType = dyn_cast<TensorArmType>(resultType)) {
2076 if (
auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
2083 constantMap.try_emplace(resultID, attr, resultType);
2087 return emitError(unknownLoc,
"unsupported OpConstantNull type: ")
2093 if (operands.size() < 3) {
2095 <<
"OpGraphConstantARM must have at least 2 operands";
2100 return emitError(unknownLoc,
"undefined result type from <id> ")
2104 uint32_t resultID = operands[1];
2106 if (!dyn_cast<spirv::TensorArmType>(resultType)) {
2107 return emitError(unknownLoc,
"result must be of type OpTypeTensorARM");
2110 APInt graph_constant_id = APInt(32, operands[2],
true);
2111 Type i32Ty = opBuilder.getIntegerType(32);
2112 IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
2113 graphConstantMap.try_emplace(
2125 LLVM_DEBUG(logger.startLine() <<
"[block] got exiting block for id = " <<
id
2126 <<
" @ " << block <<
"\n");
2133 auto *block = curFunction->addBlock();
2134 LLVM_DEBUG(logger.startLine() <<
"[block] created block for id = " <<
id
2135 <<
" @ " << block <<
"\n");
2136 return blockMap[id] = block;
2141 return emitError(unknownLoc,
"OpBranch must appear inside a block");
2144 if (operands.size() != 1) {
2145 return emitError(unknownLoc,
"OpBranch must take exactly one target label");
2153 spirv::BranchOp::create(opBuilder, loc,
target);
2163 "OpBranchConditional must appear inside a block");
2166 if (operands.size() != 3 && operands.size() != 5) {
2168 "OpBranchConditional must have condition, true label, "
2169 "false label, and optionally two branch weights");
2172 auto condition =
getValue(operands[0]);
2176 std::optional<std::pair<uint32_t, uint32_t>> weights;
2177 if (operands.size() == 5) {
2178 weights = std::make_pair(operands[3], operands[4]);
2184 spirv::BranchConditionalOp::create(
2185 opBuilder, loc, condition, trueBlock,
2195 return emitError(unknownLoc,
"OpLabel must appear inside a function");
2198 if (operands.size() != 1) {
2199 return emitError(unknownLoc,
"OpLabel should only have result <id>");
2202 auto labelID = operands[0];
2205 LLVM_DEBUG(logger.startLine()
2206 <<
"[block] populating block " << block <<
"\n");
2208 assert(block->empty() &&
"re-deserialize the same block!");
2210 opBuilder.setInsertionPointToStart(block);
2211 blockMap[labelID] = curBlock = block;
2218 return emitError(unknownLoc,
"a graph block must appear inside a graph");
2223 LLVM_DEBUG(logger.startLine()
2224 <<
"[block] populating block " << block <<
"\n");
2226 assert(block->
empty() &&
"re-deserialize the same block!");
2228 opBuilder.setInsertionPointToStart(block);
2229 blockMap[graphID] = curBlock = block;
2237 return emitError(unknownLoc,
"OpSelectionMerge must appear in a block");
2240 if (operands.size() < 2) {
2243 "OpSelectionMerge must specify merge target and selection control");
2248 auto selectionControl = operands[1];
2250 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
2254 "a block cannot have more than one OpSelectionMerge instruction");
2263 return emitError(unknownLoc,
"OpLoopMerge must appear in a block");
2266 if (operands.size() < 3) {
2267 return emitError(unknownLoc,
"OpLoopMerge must specify merge target, "
2268 "continue target and loop control");
2274 uint32_t loopControl = operands[2];
2277 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
2281 "a block cannot have more than one OpLoopMerge instruction");
2289 return emitError(unknownLoc,
"OpPhi must appear in a block");
2292 if (operands.size() < 4) {
2293 return emitError(unknownLoc,
"OpPhi must specify result type, result <id>, "
2294 "and variable-parent pairs");
2299 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
2300 valueMap[operands[1]] = blockArg;
2301 LLVM_DEBUG(logger.startLine()
2302 <<
"[phi] created block argument " << blockArg
2303 <<
" id = " << operands[1] <<
" of type " << blockArgType <<
"\n");
2307 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
2308 uint32_t value = operands[i];
2310 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
2311 blockPhiInfo[predecessorTargetPair].push_back(value);
2312 LLVM_DEBUG(logger.startLine() <<
"[phi] predecessor @ " << predecessor
2313 <<
" with arg id = " << value <<
"\n");
2321 return emitError(unknownLoc,
"OpSwitch must appear in a block");
2323 if (operands.size() < 2)
2324 return emitError(unknownLoc,
"OpSwitch must at least specify selector and "
2325 "a default target");
2327 if (operands.size() % 2)
2329 "OpSwitch must at have an even number of operands: "
2330 "selector, default target and any number of literal and "
2331 "label <id> pairs");
2339 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
2340 literals.push_back(operands[i]);
2345 spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock,
2354class ControlFlowStructurizer {
2357 ControlFlowStructurizer(
Location loc, uint32_t control,
2360 llvm::ScopedPrinter &logger)
2361 : location(loc), control(control), blockMergeInfo(mergeInfo),
2362 headerBlock(header), mergeBlock(merge), continueBlock(cont),
2365 ControlFlowStructurizer(
Location loc, uint32_t control,
2368 : location(loc), control(control), blockMergeInfo(mergeInfo),
2369 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
2379 LogicalResult structurize();
2384 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
2387 spirv::LoopOp createLoopOp(uint32_t loopControl);
2390 void collectBlocksInConstruct();
2399 Block *continueBlock;
2405 llvm::ScopedPrinter &logger;
2411ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
2414 OpBuilder builder(&mergeBlock->front());
2416 auto control =
static_cast<spirv::SelectionControl
>(selectionControl);
2417 auto selectionOp = spirv::SelectionOp::create(builder, location, control);
2418 selectionOp.addMergeBlock(builder);
2423spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
2426 OpBuilder builder(&mergeBlock->front());
2428 auto control =
static_cast<spirv::LoopControl
>(loopControl);
2429 auto loopOp = spirv::LoopOp::create(builder, location, control);
2430 loopOp.addEntryAndMergeBlock(builder);
2435void ControlFlowStructurizer::collectBlocksInConstruct() {
2436 assert(constructBlocks.empty() &&
"expected empty constructBlocks");
2439 constructBlocks.insert(headerBlock);
2443 for (
unsigned i = 0; i < constructBlocks.size(); ++i) {
2444 for (
auto *successor : constructBlocks[i]->getSuccessors())
2445 if (successor != mergeBlock)
2446 constructBlocks.insert(successor);
2450LogicalResult ControlFlowStructurizer::structurize() {
2451 Operation *op =
nullptr;
2452 bool isLoop = continueBlock !=
nullptr;
2454 if (
auto loopOp = createLoopOp(control))
2455 op = loopOp.getOperation();
2457 if (
auto selectionOp = createSelectionOp(control))
2458 op = selectionOp.getOperation();
2467 mapper.
map(mergeBlock, &body.
back());
2469 collectBlocksInConstruct();
2490 OpBuilder builder(body);
2491 for (
auto *block : constructBlocks) {
2494 auto *newBlock = builder.createBlock(&body.
back());
2495 mapper.
map(block, newBlock);
2496 LLVM_DEBUG(logger.startLine() <<
"[cf] cloned block " << newBlock
2497 <<
" from block " << block <<
"\n");
2499 for (BlockArgument blockArg : block->getArguments()) {
2501 newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2502 mapper.
map(blockArg, newArg);
2503 LLVM_DEBUG(logger.startLine() <<
"[cf] remapped block argument "
2504 << blockArg <<
" to " << newArg <<
"\n");
2507 LLVM_DEBUG(logger.startLine()
2508 <<
"[cf] block " << block <<
" is a function entry block\n");
2511 for (
auto &op : *block)
2512 newBlock->push_back(op.
clone(mapper));
2516 auto remapOperands = [&](Operation *op) {
2518 if (Value mappedOp = mapper.
lookupOrNull(operand.get()))
2519 operand.set(mappedOp);
2522 succOp.set(mappedOp);
2524 for (
auto &block : body)
2525 block.walk(remapOperands);
2533 headerBlock->replaceAllUsesWith(mergeBlock);
2536 logger.startLine() <<
"[cf] after cloning and fixing references:\n";
2537 headerBlock->getParentOp()->print(logger.getOStream());
2538 logger.startLine() <<
"\n";
2542 if (!mergeBlock->args_empty()) {
2543 return mergeBlock->getParentOp()->emitError(
2544 "OpPhi in loop merge block unsupported");
2550 for (BlockArgument blockArg : headerBlock->getArguments())
2551 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2555 SmallVector<Value, 4> blockArgs;
2556 if (!headerBlock->args_empty())
2557 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2561 builder.setInsertionPointToEnd(&body.front());
2562 spirv::BranchOp::create(builder, location, mapper.
lookupOrNull(headerBlock),
2563 ArrayRef<Value>(blockArgs));
2568 SmallVector<Value> valuesToYield;
2571 SmallVector<Value> outsideUses;
2585 for (BlockArgument blockArg : mergeBlock->getArguments()) {
2590 body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2591 valuesToYield.push_back(body.back().getArguments().back());
2592 outsideUses.push_back(blockArg);
2597 LLVM_DEBUG(logger.startLine() <<
"[cf] cleaning up blocks after clone\n");
2600 for (
auto *block : constructBlocks)
2601 block->dropAllReferences();
2606 for (
Block *block : constructBlocks) {
2607 for (Operation &op : *block) {
2611 outsideUses.push_back(
result);
2614 for (BlockArgument &arg : block->getArguments()) {
2615 if (!arg.use_empty()) {
2617 outsideUses.push_back(arg);
2622 assert(valuesToYield.size() == outsideUses.size());
2626 if (!valuesToYield.empty()) {
2627 LLVM_DEBUG(logger.startLine()
2628 <<
"[cf] yielding values from the selection / loop region\n");
2631 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2632 Operation *merge = llvm::getSingleElement(mergeOps);
2634 merge->setOperands(valuesToYield);
2642 builder.setInsertionPoint(&mergeBlock->front());
2644 Operation *newOp =
nullptr;
2647 newOp = spirv::LoopOp::create(builder, location,
2649 static_cast<spirv::LoopControl
>(control));
2651 newOp = spirv::SelectionOp::create(
2653 static_cast<spirv::SelectionControl
>(control));
2663 for (
unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2664 outsideUses[i].replaceAllUsesWith(op->
getResult(i));
2670 mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2677 for (
auto *block : constructBlocks) {
2678 if (!block->use_empty())
2679 return emitError(block->getParent()->getLoc(),
2680 "failed control flow structurization: "
2681 "block has uses outside of the "
2682 "enclosing selection/loop construct");
2683 for (Operation &op : *block)
2685 return op.
emitOpError(
"failed control flow structurization: value has "
2686 "uses outside of the "
2687 "enclosing selection/loop construct");
2688 for (BlockArgument &arg : block->getArguments())
2689 if (!arg.use_empty())
2690 return emitError(arg.getLoc(),
"failed control flow structurization: "
2691 "block argument has uses outside of the "
2692 "enclosing selection/loop construct");
2696 for (
auto *block : constructBlocks) {
2736 auto updateMergeInfo = [&](
Block *block) -> WalkResult {
2737 auto it = blockMergeInfo.find(block);
2738 if (it != blockMergeInfo.end()) {
2740 Location loc = it->second.loc;
2744 return emitError(loc,
"failed control flow structurization: nested "
2745 "loop header block should be remapped!");
2747 Block *newContinue = it->second.continueBlock;
2751 return emitError(loc,
"failed control flow structurization: nested "
2752 "loop continue block should be remapped!");
2755 Block *newMerge = it->second.mergeBlock;
2757 newMerge = mappedTo;
2761 blockMergeInfo.
erase(it);
2762 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2769 if (block->walk(updateMergeInfo).wasInterrupted())
2777 LLVM_DEBUG(logger.startLine() <<
"[cf] changing entry block " << block
2778 <<
" to only contain a spirv.Branch op\n");
2782 builder.setInsertionPointToEnd(block);
2783 spirv::BranchOp::create(builder, location, mergeBlock);
2785 LLVM_DEBUG(logger.startLine() <<
"[cf] erasing block " << block <<
"\n");
2790 LLVM_DEBUG(logger.startLine()
2791 <<
"[cf] after structurizing construct with header block "
2792 << headerBlock <<
":\n"
2801 <<
"//----- [phi] start wiring up block arguments -----//\n";
2807 for (
const auto &info : blockPhiInfo) {
2808 Block *block = info.first.first;
2812 logger.startLine() <<
"[phi] block " << block <<
"\n";
2813 logger.startLine() <<
"[phi] before creating block argument:\n";
2815 logger.startLine() <<
"\n";
2821 opBuilder.setInsertionPoint(op);
2824 blockArgs.reserve(phiInfo.size());
2825 for (uint32_t valueId : phiInfo) {
2827 blockArgs.push_back(value);
2828 LLVM_DEBUG(logger.startLine() <<
"[phi] block argument " << value
2829 <<
" id = " << valueId <<
"\n");
2831 return emitError(unknownLoc,
"OpPhi references undefined value!");
2835 if (
auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2837 spirv::BranchOp::create(opBuilder, branchOp.getLoc(),
2838 branchOp.getTarget(), blockArgs);
2840 }
else if (
auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2841 assert((branchCondOp.getTrueBlock() ==
target ||
2842 branchCondOp.getFalseBlock() ==
target) &&
2843 "expected target to be either the true or false target");
2844 if (
target == branchCondOp.getTrueTarget())
2845 spirv::BranchConditionalOp::create(
2846 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2847 blockArgs, branchCondOp.getFalseBlockArguments(),
2848 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2849 branchCondOp.getFalseTarget());
2851 spirv::BranchConditionalOp::create(
2852 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2853 branchCondOp.getTrueBlockArguments(), blockArgs,
2854 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2855 branchCondOp.getFalseBlock());
2857 branchCondOp.erase();
2858 }
else if (
auto switchOp = dyn_cast<spirv::SwitchOp>(op)) {
2859 if (
target == switchOp.getDefaultTarget()) {
2863 spirv::SwitchOp::create(
2864 opBuilder, switchOp.getLoc(), switchOp.getSelector(),
2865 switchOp.getDefaultTarget(), blockArgs, literals,
2866 switchOp.getTargets(), targetOperands);
2870 auto it = llvm::find(targets,
target);
2871 assert(it != targets.end());
2872 size_t index = std::distance(targets.begin(), it);
2873 switchOp.getTargetOperandsMutable(
index).assign(blockArgs);
2876 return emitError(unknownLoc,
"unimplemented terminator for Phi creation");
2880 logger.startLine() <<
"[phi] after creating block argument:\n";
2882 logger.startLine() <<
"\n";
2885 blockPhiInfo.clear();
2890 <<
"//--- [phi] completed wiring up block arguments ---//\n";
2898 for (
auto [block, mergeInfo] : blockMergeInfoCopy) {
2900 if (mergeInfo.continueBlock)
2903 if (!block->mightHaveTerminator())
2906 Operation *terminator = block->getTerminator();
2909 if (!isa<spirv::BranchConditionalOp, spirv::SwitchOp>(terminator))
2913 bool splitHeaderMergeBlock =
false;
2914 for (
const auto &[_, mergeInfo] : blockMergeInfo) {
2915 if (mergeInfo.mergeBlock == block)
2916 splitHeaderMergeBlock =
true;
2923 if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2926 spirv::BranchOp::create(builder, block->getParent()->getLoc(), newBlock);
2930 blockMergeInfo.erase(block);
2931 blockMergeInfo.try_emplace(newBlock, mergeInfo);
2939 if (!options.enableControlFlowStructurization) {
2943 <<
"//----- [cf] skip structurizing control flow -----//\n";
2951 <<
"//----- [cf] start structurizing control flow -----//\n";
2956 logger.startLine() <<
"[cf] split conditional blocks\n";
2957 logger.startLine() <<
"\n";
2964 while (!blockMergeInfo.empty()) {
2965 Block *headerBlock = blockMergeInfo.
begin()->first;
2969 logger.startLine() <<
"[cf] header block " << headerBlock <<
":\n";
2970 headerBlock->
print(logger.getOStream());
2971 logger.startLine() <<
"\n";
2975 assert(mergeBlock &&
"merge block cannot be nullptr");
2977 return emitError(unknownLoc,
"OpPhi in loop merge block unimplemented");
2979 logger.startLine() <<
"[cf] merge block " << mergeBlock <<
":\n";
2980 mergeBlock->print(logger.getOStream());
2981 logger.startLine() <<
"\n";
2985 LLVM_DEBUG(
if (continueBlock) {
2986 logger.startLine() <<
"[cf] continue block " << continueBlock <<
":\n";
2987 continueBlock->print(logger.getOStream());
2988 logger.startLine() <<
"\n";
2992 blockMergeInfo.
erase(blockMergeInfo.begin());
2993 ControlFlowStructurizer structurizer(mergeInfo.
loc, mergeInfo.
control,
2994 blockMergeInfo, headerBlock,
2995 mergeBlock, continueBlock
3001 if (failed(structurizer.structurize()))
3008 <<
"//--- [cf] completed structurizing control flow ---//\n";
3021 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
3022 if (fileName.empty())
3023 fileName =
"<unknown>";
3035 if (operands.size() != 3)
3036 return emitError(unknownLoc,
"OpLine must have 3 operands");
3037 debugLine =
DebugLine{operands[0], operands[1], operands[2]};
3045 if (operands.size() < 2)
3046 return emitError(unknownLoc,
"OpString needs at least 2 operands");
3048 if (!debugInfoMap.lookup(operands[0]).empty())
3050 "duplicate debug string found for result <id> ")
3053 unsigned wordIndex = 1;
3055 if (wordIndex != operands.size())
3057 "unexpected trailing words in OpString instruction");
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Operation * getTerminator()
Get the terminator operation of this block.
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
A symbol reference with a reference path containing a single element.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
MutableArrayRef< BlockOperand > getBlockOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool use_empty()
Returns true if this operation has no uses.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MutableArrayRef< OpOperand > getOpOperands()
void print(raw_ostream &os, const OpPrintingFlags &flags={})
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, PropertyRef properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
result_range getResults()
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class implements the successor iterators for Block.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static WalkResult advance()
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult wireUpBlockArgument()
Creates block arguments on predecessors previously recorded when handling OpPhi instructions.
Value materializeSpecConstantOperation(uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, ArrayRef< uint32_t > enclosedOpOperands)
Materializes/emits an OpSpecConstantOp instruction.
LogicalResult processOpTypePointer(ArrayRef< uint32_t > operands)
Value getValue(uint32_t id)
Get the Value associated with a result <id>.
LogicalResult processMatrixType(ArrayRef< uint32_t > operands)
LogicalResult processGlobalVariable(ArrayRef< uint32_t > operands)
Processes the OpVariable instructions at current offset into binary.
std::optional< SpecConstOperationMaterializationInfo > getSpecConstantOperation(uint32_t id)
Gets the info needed to materialize the spec constant operation op associated with the given <id>.
LogicalResult processConstantNull(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantNull instruction with the given operands.
LogicalResult processSpecConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantComposite instruction with the given operands.
LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef< uint32_t > operands, bool deferInstructions=true)
Processes a SPIR-V instruction with the given opcode and operands.
LogicalResult processBranchConditional(ArrayRef< uint32_t > operands)
spirv::GlobalVariableOp getGlobalVariable(uint32_t id)
Gets the global variable associated with a result <id> of OpVariable.
LogicalResult createGraphBlock(uint32_t graphID)
Creates a block for graph with the given graphID.
LogicalResult processStructType(ArrayRef< uint32_t > operands)
LogicalResult processGraphARM(ArrayRef< uint32_t > operands)
LogicalResult processSamplerType(ArrayRef< uint32_t > operands)
LogicalResult setFunctionArgAttrs(uint32_t argID, SmallVectorImpl< Attribute > &argAttrs, size_t argIndex)
Sets the function argument's attributes.
LogicalResult structurizeControlFlow()
Extracts blocks belonging to a structured selection/loop into a spirv.mlir.selection/spirv....
LogicalResult processLabel(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLabel instruction with the given operands.
LogicalResult processSampledImageType(ArrayRef< uint32_t > operands)
LogicalResult processTensorARMType(ArrayRef< uint32_t > operands)
std::optional< spirv::GraphConstantARMOpMaterializationInfo > getGraphConstantARM(uint32_t id)
Gets the GraphConstantARM ID attribute and result type with the given result <id>.
std::optional< std::pair< Attribute, Type > > getConstant(uint32_t id)
Gets the constant's attribute and type associated with the given <id>.
LogicalResult processType(spirv::Opcode opcode, ArrayRef< uint32_t > operands)
Processes a SPIR-V type instruction with given opcode and operands and registers the type into module...
LogicalResult processLoopMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLoopMerge instruction with the given operands.
LogicalResult processArrayType(ArrayRef< uint32_t > operands)
LogicalResult sliceInstruction(spirv::Opcode &opcode, ArrayRef< uint32_t > &operands, std::optional< spirv::Opcode > expectedOpcode=std::nullopt)
Slices the first instruction out of binary and returns its opcode and operands via opcode and operand...
spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id)
Gets the composite specialization constant with the given result <id>.
SmallVector< uint32_t, 2 > BlockPhiInfo
For OpPhi instructions, we use block arguments to represent them.
LogicalResult processSpecConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processCooperativeMatrixTypeKHR(ArrayRef< uint32_t > operands)
LogicalResult processGraphEntryPointARM(ArrayRef< uint32_t > operands)
LogicalResult processFunction(ArrayRef< uint32_t > operands)
Creates a deserializer for the given SPIR-V binary module.
StringAttr getSymbolDecoration(StringRef decorationName)
Gets the symbol name from the name of decoration.
Block * getOrCreateBlock(uint32_t id)
Gets or creates the block corresponding to the given label <id>.
bool isVoidType(Type type) const
Returns true if the given type is for SPIR-V void type.
std::string getSpecConstantSymbol(uint32_t id)
Returns a symbol to be used for the specialization constant with the given result <id>.
LogicalResult processDebugString(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpString instruction with the given operands.
LogicalResult processPhi(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpPhi instruction with the given operands.
std::string getFunctionSymbol(uint32_t id)
Returns a symbol to be used for the function name with the given result <id>.
void clearDebugLine()
Discontinues any source-level location information that might be active from a previous OpLine instru...
LogicalResult processFunctionType(ArrayRef< uint32_t > operands)
IntegerAttr getConstantInt(uint32_t id)
Gets the constant's integer attribute with the given <id>.
LogicalResult processTypeForwardPointer(ArrayRef< uint32_t > operands)
LogicalResult processSwitch(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSwitch instruction with the given operands.
LogicalResult processGraphEndARM(ArrayRef< uint32_t > operands)
LogicalResult processImageType(ArrayRef< uint32_t > operands)
LogicalResult processConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantComposite instruction with the given operands.
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, TypedAttr defaultValue)
Creates a spirv::SpecConstantOp.
Block * getBlock(uint32_t id) const
Returns the block for the given label <id>.
LogicalResult processGraphTypeARM(ArrayRef< uint32_t > operands)
LogicalResult processBranch(ArrayRef< uint32_t > operands)
std::optional< std::pair< Attribute, Type > > getConstantCompositeReplicate(uint32_t id)
Gets the replicated composite constant's attribute and type associated with the given <id>.
LogicalResult processFunctionEnd(ArrayRef< uint32_t > operands)
Processes OpFunctionEnd and finalizes function.
LogicalResult processRuntimeArrayType(ArrayRef< uint32_t > operands)
LogicalResult processSpecConstantOperation(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantOp instruction with the given operands.
LogicalResult processConstant(ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant instruction with the given operands.
Location createFileLineColLoc(OpBuilder opBuilder)
Creates a FileLineColLoc with the OpLine location information.
LogicalResult processGraphConstantARM(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpGraphConstantARM instruction with the given operands.
LogicalResult processConstantBool(bool isTrue, ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the given operands.
spirv::SpecConstantOp getSpecConstant(uint32_t id)
Gets the specialization constant with the given result <id>.
LogicalResult processConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processSelectionMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSelectionMerge instruction with the given operands.
LogicalResult processOpGraphSetOutputARM(ArrayRef< uint32_t > operands)
LogicalResult processDebugLine(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLine instruction with the given operands.
LogicalResult splitSelectionHeader()
Move a conditional branch or a switch into a separate basic block to avoid unnecessary sinking of def...
std::string getGraphSymbol(uint32_t id)
Returns a symbol to be used for the graph name with the given result <id>.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static SamplerType get(MLIRContext *context)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr uint32_t kMagicNumber
SPIR-V magic number.
llvm::MapVector< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
llvm::SetVector< T, Vector, Set, N > SetVector
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A struct for containing a header block's merge and continue targets.
A struct for containing OpLine instruction information.
A struct that collects the info needed to materialize/emit a GraphConstantARMOp.
A struct that collects the info needed to materialize/emit a SpecConstantOperation op.