23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/Sequence.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/bit.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/SaveAndRestore.h"
30#include "llvm/Support/raw_ostream.h"
35#define DEBUG_TYPE "spirv-deserialization"
44 isa_and_nonnull<spirv::FuncOp>(block->
getParentOp());
54 : binary(binary), context(context), unknownLoc(UnknownLoc::
get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion()),
options(
options)
63LogicalResult spirv::Deserializer::deserialize() {
67 <<
"//+++---------- start deserialization ----------+++//\n";
70 if (
failed(processHeader()))
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
74 ArrayRef<uint32_t> operands;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
89 for (
auto &deferred : deferredInstructions) {
97 LLVM_DEBUG(logger.startLine()
98 <<
"//+++-------- completed deserialization --------+++//\n");
102OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
103 return std::move(module);
110OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111 OpBuilder builder(context);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
117LogicalResult spirv::Deserializer::processHeader() {
120 "SPIR-V binary module must have a 5-word header");
123 return emitError(unknownLoc,
"incorrect magic number");
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130#define MIN_VERSION_CASE(v) \
132 version = spirv::Version::V_1_##v; \
142#undef MIN_VERSION_CASE
144 return emitError(unknownLoc,
"unsupported SPIR-V minor version: ")
148 return emitError(unknownLoc,
"unsupported SPIR-V major version: ")
158spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
159 if (operands.size() != 1)
160 return emitError(unknownLoc,
"OpCapability must have one parameter");
162 auto cap = spirv::symbolizeCapability(operands[0]);
164 return emitError(unknownLoc,
"unknown capability: ") << operands[0];
166 capabilities.insert(*cap);
170LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
174 "OpExtension must have a literal string for the extension name");
177 unsigned wordIndex = 0;
179 if (wordIndex != words.size())
181 "unexpected trailing words in OpExtension instruction");
182 auto ext = spirv::symbolizeExtension(extName);
184 return emitError(unknownLoc,
"unknown extension: ") << extName;
186 extensions.insert(*ext);
191spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
192 if (words.size() < 2) {
194 "OpExtInstImport must have a result <id> and a literal "
195 "string for the extended instruction set name");
198 unsigned wordIndex = 1;
200 if (wordIndex != words.size()) {
202 "unexpected trailing words in OpExtInstImport");
207void spirv::Deserializer::attachVCETriple() {
209 spirv::ModuleOp::getVCETripleAttrName(),
211 extensions.getArrayRef(), context));
215spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
216 if (operands.size() != 2)
217 return emitError(unknownLoc,
"OpMemoryModel must have two operands");
220 module->getAddressingModelAttrName(),
221 opBuilder.getAttr<spirv::AddressingModelAttr>(
222 static_cast<spirv::AddressingModel
>(operands.front())));
224 (*module)->setAttr(module->getMemoryModelAttrName(),
225 opBuilder.getAttr<spirv::MemoryModelAttr>(
226 static_cast<spirv::MemoryModel
>(operands.back())));
231template <
typename AttrTy,
typename EnumAttrTy,
typename EnumTy>
235 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
236 if (words.size() != 4) {
237 return emitError(loc,
"OpDecorate with ")
238 << decorationName <<
" needs a cache control integer literal and a "
239 << cacheControlKind <<
" cache control literal";
241 unsigned cacheLevel = words[2];
242 auto cacheControlAttr =
static_cast<EnumTy
>(words[3]);
243 auto value = opBuilder.
getAttr<AttrTy>(cacheLevel, cacheControlAttr);
246 dyn_cast_or_null<ArrayAttr>(decorations[words[0]].
get(symbol)))
247 llvm::append_range(attrs, attrList);
248 attrs.push_back(value);
249 decorations[words[0]].set(symbol, opBuilder.
getArrayAttr(attrs));
253LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
257 if (words.size() < 2) {
259 unknownLoc,
"OpDecorate must have at least result <id> and Decoration");
261 auto decorationName =
262 stringifyDecoration(
static_cast<spirv::Decoration
>(words[1]));
263 if (decorationName.empty()) {
264 return emitError(unknownLoc,
"invalid Decoration code : ") << words[1];
266 auto symbol = getSymbolDecoration(decorationName);
267 switch (
static_cast<spirv::Decoration
>(words[1])) {
268 case spirv::Decoration::FPFastMathMode:
269 if (words.size() != 3) {
270 return emitError(unknownLoc,
"OpDecorate with ")
271 << decorationName <<
" needs a single integer literal";
273 decorations[words[0]].set(
274 symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
275 static_cast<FPFastMathMode
>(words[2])));
277 case spirv::Decoration::FPRoundingMode:
278 if (words.size() != 3) {
279 return emitError(unknownLoc,
"OpDecorate with ")
280 << decorationName <<
" needs a single integer literal";
282 decorations[words[0]].set(
283 symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
284 static_cast<FPRoundingMode
>(words[2])));
286 case spirv::Decoration::DescriptorSet:
287 case spirv::Decoration::Binding:
288 case spirv::Decoration::Location:
289 case spirv::Decoration::SpecId:
290 case spirv::Decoration::Index:
291 if (words.size() != 3) {
292 return emitError(unknownLoc,
"OpDecorate with ")
293 << decorationName <<
" needs a single integer literal";
295 decorations[words[0]].set(
296 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
298 case spirv::Decoration::BuiltIn:
299 if (words.size() != 3) {
300 return emitError(unknownLoc,
"OpDecorate with ")
301 << decorationName <<
" needs a single integer literal";
303 decorations[words[0]].set(
304 symbol, opBuilder.getStringAttr(
305 stringifyBuiltIn(
static_cast<spirv::BuiltIn
>(words[2]))));
307 case spirv::Decoration::ArrayStride:
308 if (words.size() != 3) {
309 return emitError(unknownLoc,
"OpDecorate with ")
310 << decorationName <<
" needs a single integer literal";
312 typeDecorations[words[0]] = words[2];
314 case spirv::Decoration::LinkageAttributes: {
315 if (words.size() < 4) {
316 return emitError(unknownLoc,
"OpDecorate with ")
318 <<
" needs at least 1 string and 1 integer literal";
326 unsigned wordIndex = 2;
328 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
329 static_cast<::mlir::spirv::LinkageType
>(words[wordIndex++]));
330 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
331 StringAttr::get(context, linkageName), linkageTypeAttr);
332 decorations[words[0]].set(symbol, dyn_cast<Attribute>(linkageAttr));
335 case spirv::Decoration::Aliased:
336 case spirv::Decoration::AliasedPointer:
337 case spirv::Decoration::Block:
338 case spirv::Decoration::BufferBlock:
339 case spirv::Decoration::Flat:
340 case spirv::Decoration::NonReadable:
341 case spirv::Decoration::NonWritable:
342 case spirv::Decoration::NoPerspective:
343 case spirv::Decoration::NoSignedWrap:
344 case spirv::Decoration::NoUnsignedWrap:
345 case spirv::Decoration::RelaxedPrecision:
346 case spirv::Decoration::Restrict:
347 case spirv::Decoration::RestrictPointer:
348 case spirv::Decoration::NoContraction:
349 case spirv::Decoration::Constant:
350 case spirv::Decoration::Invariant:
351 case spirv::Decoration::Patch:
352 case spirv::Decoration::Coherent:
353 if (words.size() != 2) {
354 return emitError(unknownLoc,
"OpDecorate with ")
355 << decorationName <<
" needs a single target <id>";
357 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
359 case spirv::Decoration::CacheControlLoadINTEL: {
361 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
362 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
368 case spirv::Decoration::CacheControlStoreINTEL: {
370 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
371 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
378 return emitError(unknownLoc,
"unhandled Decoration : '") << decorationName;
384spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
386 if (words.size() < 3) {
388 "OpMemberDecorate must have at least 3 operands");
391 auto decoration =
static_cast<spirv::Decoration
>(words[2]);
392 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
394 " missing offset specification in OpMemberDecorate with "
395 "Offset decoration");
397 ArrayRef<uint32_t> decorationOperands;
398 if (words.size() > 3) {
399 decorationOperands = words.slice(3);
401 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
405LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
406 if (words.size() < 3) {
407 return emitError(unknownLoc,
"OpMemberName must have at least 3 operands");
409 unsigned wordIndex = 2;
411 if (wordIndex != words.size()) {
413 "unexpected trailing words in OpMemberName instruction");
415 memberNameMap[words[0]][words[1]] = name;
421 if (!decorations.contains(argID)) {
422 argAttrs[argIndex] = DictionaryAttr::get(context, {});
426 spirv::DecorationAttr foundDecorationAttr;
428 for (
auto decoration :
429 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
430 spirv::Decoration::AliasedPointer,
431 spirv::Decoration::RestrictPointer}) {
433 if (decAttr.getName() !=
437 if (foundDecorationAttr)
439 "more than one Aliased/Restrict decorations for "
440 "function argument with result <id> ")
443 foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
448 spirv::Decoration::RelaxedPrecision))) {
453 if (foundDecorationAttr)
454 return emitError(unknownLoc,
"already found a decoration for function "
455 "argument with result <id> ")
458 foundDecorationAttr = spirv::DecorationAttr::get(
459 context, spirv::Decoration::RelaxedPrecision);
463 if (!foundDecorationAttr)
464 return emitError(unknownLoc,
"unimplemented decoration support for "
465 "function argument with result <id> ")
468 NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
469 foundDecorationAttr);
470 argAttrs[argIndex] = DictionaryAttr::get(context, attr);
477 return emitError(unknownLoc,
"found function inside function");
481 if (operands.size() != 4) {
482 return emitError(unknownLoc,
"OpFunction must have 4 parameters");
486 return emitError(unknownLoc,
"undefined result type from <id> ")
490 uint32_t fnID = operands[1];
491 if (funcMap.count(fnID)) {
492 return emitError(unknownLoc,
"duplicate function definition/declaration");
495 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
497 return emitError(unknownLoc,
"unknown Function Control: ") << operands[2];
501 if (!fnType || !isa<FunctionType>(fnType)) {
502 return emitError(unknownLoc,
"unknown function type from <id> ")
505 auto functionType = cast<FunctionType>(fnType);
507 if ((
isVoidType(resultType) && functionType.getNumResults() != 0) ||
508 (functionType.getNumResults() == 1 &&
509 functionType.getResult(0) != resultType)) {
510 return emitError(unknownLoc,
"mismatch in function type ")
511 << functionType <<
" and return type " << resultType <<
" specified";
515 auto funcOp = spirv::FuncOp::create(opBuilder, unknownLoc, fnName,
516 functionType, fnControl.value());
518 if (decorations.count(fnID)) {
519 for (
auto attr : decorations[fnID].getAttrs()) {
520 funcOp->setAttr(attr.getName(), attr.getValue());
523 curFunction = funcMap[fnID] = funcOp;
524 auto *entryBlock = funcOp.addEntryBlock();
527 <<
"//===-------------------------------------------===//\n";
528 logger.startLine() <<
"[fn] name: " << fnName <<
"\n";
529 logger.startLine() <<
"[fn] type: " << fnType <<
"\n";
530 logger.startLine() <<
"[fn] ID: " << fnID <<
"\n";
531 logger.startLine() <<
"[fn] entry block: " << entryBlock <<
"\n";
536 argAttrs.resize(functionType.getNumInputs());
539 if (functionType.getNumInputs()) {
540 for (
size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
541 auto argType = functionType.getInput(i);
542 spirv::Opcode opcode = spirv::Opcode::OpNop;
545 spirv::Opcode::OpFunctionParameter))) {
548 if (opcode != spirv::Opcode::OpFunctionParameter) {
551 "missing OpFunctionParameter instruction for argument ")
554 if (operands.size() != 2) {
557 "expected result type and result <id> for OpFunctionParameter");
559 auto argDefinedType =
getType(operands[0]);
560 if (!argDefinedType || argDefinedType != argType) {
562 "mismatch in argument type between function type "
564 << functionType <<
" and argument type definition "
565 << argDefinedType <<
" at argument " << i;
568 return emitError(unknownLoc,
"duplicate definition of result <id> ")
575 auto argValue = funcOp.getArgument(i);
576 valueMap[operands[1]] = argValue;
580 if (llvm::any_of(argAttrs, [](
Attribute attr) {
581 auto argAttr = cast<DictionaryAttr>(attr);
582 return !argAttr.empty();
584 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
589 auto linkageAttr = funcOp.getLinkageAttributes();
590 auto hasImportLinkage =
591 linkageAttr && (linkageAttr.value().getLinkageType().
getValue() ==
592 spirv::LinkageType::Import);
593 if (hasImportLinkage)
600 spirv::Opcode opcode = spirv::Opcode::OpNop;
609 spirv::Opcode::OpFunctionEnd))) {
612 if (opcode == spirv::Opcode::OpFunctionEnd) {
615 if (opcode != spirv::Opcode::OpLabel) {
616 return emitError(unknownLoc,
"a basic block must start with OpLabel");
618 if (instOperands.size() != 1) {
619 return emitError(unknownLoc,
"OpLabel should only have result <id>");
621 blockMap[instOperands[0]] = entryBlock;
629 spirv::Opcode::OpFunctionEnd)) &&
630 opcode != spirv::Opcode::OpFunctionEnd) {
635 if (opcode != spirv::Opcode::OpFunctionEnd) {
645 if (!operands.empty()) {
646 return emitError(unknownLoc,
"unexpected operands for OpFunctionEnd");
657 curFunction = std::nullopt;
662 <<
"//===-------------------------------------------===//\n";
669 if (operands.size() < 2) {
671 "missing graph defintion in OpGraphEntryPointARM");
674 unsigned wordIndex = 0;
675 uint32_t graphID = operands[wordIndex++];
676 if (!graphMap.contains(graphID)) {
678 "missing graph definition/declaration with id ")
682 spirv::GraphARMOp graphARM = graphMap[graphID];
684 graphARM.setSymName(name);
685 graphARM.setEntryPoint(
true);
688 for (
int64_t size = operands.size(); wordIndex < size; ++wordIndex) {
690 interface.push_back(SymbolRefAttr::get(arg.getOperation()));
692 return emitError(unknownLoc,
"undefined result <id> ")
693 << operands[wordIndex] <<
" while decoding OpGraphEntryPoint";
699 opBuilder.setInsertionPoint(graphARM);
700 spirv::GraphEntryPointARMOp::create(
701 opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
702 opBuilder.getArrayAttr(interface));
710 return emitError(unknownLoc,
"found graph inside graph");
713 if (operands.size() < 2) {
714 return emitError(unknownLoc,
"OpGraphARM must have at least 2 parameters");
718 if (!type || !isa<GraphType>(type)) {
719 return emitError(unknownLoc,
"unknown graph type from <id> ")
722 auto graphType = cast<GraphType>(type);
723 if (graphType.getNumResults() <= 0) {
724 return emitError(unknownLoc,
"expected at least one result");
727 uint32_t graphID = operands[1];
728 if (graphMap.count(graphID)) {
729 return emitError(unknownLoc,
"duplicate graph definition/declaration");
734 spirv::GraphARMOp::create(opBuilder, unknownLoc, graphName, graphType);
735 curGraph = graphMap[graphID] = graphOp;
736 Block *entryBlock = graphOp.addEntryBlock();
739 <<
"//===-------------------------------------------===//\n";
740 logger.startLine() <<
"[graph] name: " << graphName <<
"\n";
741 logger.startLine() <<
"[graph] type: " << graphType <<
"\n";
742 logger.startLine() <<
"[graph] ID: " << graphID <<
"\n";
743 logger.startLine() <<
"[graph] entry block: " << entryBlock <<
"\n";
748 for (
auto [
index, argType] : llvm::enumerate(graphType.getInputs())) {
749 spirv::Opcode opcode;
752 spirv::Opcode::OpGraphInputARM))) {
755 if (operands.size() != 3) {
756 return emitError(unknownLoc,
"expected result type, result <id> and "
757 "input index for OpGraphInputARM");
761 if (!argDefinedType) {
762 return emitError(unknownLoc,
"unknown operand type <id> ") << operands[0];
765 if (argDefinedType != argType) {
767 "mismatch in argument type between graph type "
769 << graphType <<
" and argument type definition " << argDefinedType
770 <<
" at argument " <<
index;
773 return emitError(unknownLoc,
"duplicate definition of result <id> ")
778 if (!inputIndexAttr) {
780 "unable to read inputIndex value from constant op ")
783 BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
784 valueMap[operands[1]] = argValue;
787 graphOutputs.resize(graphType.getNumResults());
793 blockMap[graphID] = entryBlock;
800 spirv::Opcode opcode;
810 }
while (opcode != spirv::Opcode::OpGraphEndARM);
817 if (operands.size() != 2) {
820 "expected value id and output index for OpGraphSetOutputARM");
823 uint32_t
id = operands[0];
826 return emitError(unknownLoc,
"could not find result <id> ") << id;
830 if (!outputIndexAttr) {
832 "unable to read outputIndex value from constant op ")
835 graphOutputs[outputIndexAttr.getInt()] = value;
842 spirv::GraphOutputsARMOp::create(opBuilder, unknownLoc, graphOutputs);
845 if (!operands.empty()) {
846 return emitError(unknownLoc,
"unexpected operands for OpGraphEndARM");
850 curGraph = std::nullopt;
851 graphOutputs.clear();
856 <<
"//===-------------------------------------------===//\n";
861std::optional<std::pair<Attribute, Type>>
863 auto constIt = constantMap.find(
id);
864 if (constIt == constantMap.end())
866 return constIt->getSecond();
869std::optional<std::pair<Attribute, Type>>
871 if (
auto it = constantCompositeReplicateMap.find(
id);
872 it != constantCompositeReplicateMap.end())
877std::optional<spirv::SpecConstOperationMaterializationInfo>
879 auto constIt = specConstOperationMap.find(
id);
880 if (constIt == specConstOperationMap.end())
882 return constIt->getSecond();
886 auto funcName = nameMap.lookup(
id).str();
887 if (funcName.empty()) {
888 funcName =
"spirv_fn_" + std::to_string(
id);
894 std::string graphName = nameMap.lookup(
id).str();
895 if (graphName.empty()) {
896 graphName =
"spirv_graph_" + std::to_string(
id);
902 auto constName = nameMap.lookup(
id).str();
903 if (constName.empty()) {
904 constName =
"spirv_spec_const_" + std::to_string(
id);
911 TypedAttr defaultValue) {
913 auto op = spirv::SpecConstantOp::create(opBuilder, unknownLoc, symName,
915 if (decorations.count(resultID)) {
916 for (
auto attr : decorations[resultID].getAttrs())
917 op->setAttr(attr.getName(), attr.getValue());
919 specConstMap[resultID] = op;
923std::optional<spirv::GraphConstantARMOpMaterializationInfo>
925 auto graphConstIt = graphConstantMap.find(
id);
926 if (graphConstIt == graphConstantMap.end())
928 return graphConstIt->getSecond();
933 unsigned wordIndex = 0;
934 if (operands.size() < 3) {
937 "OpVariable needs at least 3 operands, type, <id> and storage class");
941 auto type =
getType(operands[wordIndex]);
943 return emitError(unknownLoc,
"unknown result type <id> : ")
944 << operands[wordIndex];
946 auto ptrType = dyn_cast<spirv::PointerType>(type);
949 "expected a result type <id> to be a spirv.ptr, found : ")
955 auto variableID = operands[wordIndex];
956 auto variableName = nameMap.lookup(variableID).str();
957 if (variableName.empty()) {
958 variableName =
"spirv_var_" + std::to_string(variableID);
963 auto storageClass =
static_cast<spirv::StorageClass
>(operands[wordIndex]);
964 if (ptrType.getStorageClass() != storageClass) {
965 return emitError(unknownLoc,
"mismatch in storage class of pointer type ")
966 << type <<
" and that specified in OpVariable instruction : "
967 << stringifyStorageClass(storageClass);
974 if (wordIndex < operands.size()) {
984 return emitError(unknownLoc,
"unknown <id> ")
985 << operands[wordIndex] <<
"used as initializer";
987 initializer = SymbolRefAttr::get(op);
990 if (wordIndex != operands.size()) {
992 "found more operands than expected when deserializing "
993 "OpVariable instruction, only ")
994 << wordIndex <<
" of " << operands.size() <<
" processed";
997 auto varOp = spirv::GlobalVariableOp::create(
998 opBuilder, loc, TypeAttr::get(type),
999 opBuilder.getStringAttr(variableName), initializer);
1002 if (decorations.count(variableID)) {
1003 for (
auto attr : decorations[variableID].getAttrs())
1004 varOp->setAttr(attr.getName(), attr.getValue());
1006 globalVariableMap[variableID] = varOp;
1015 return dyn_cast<IntegerAttr>(constInfo->first);
1019 if (operands.size() < 2) {
1020 return emitError(unknownLoc,
"OpName needs at least 2 operands");
1022 if (!nameMap.lookup(operands[0]).empty()) {
1023 return emitError(unknownLoc,
"duplicate name found for result <id> ")
1026 unsigned wordIndex = 1;
1028 if (wordIndex != operands.size()) {
1030 "unexpected trailing words in OpName instruction");
1032 nameMap[operands[0]] = name;
1042 if (operands.empty()) {
1043 return emitError(unknownLoc,
"type instruction with opcode ")
1044 << spirv::stringifyOpcode(opcode) <<
" needs at least one <id>";
1049 if (typeMap.count(operands[0])) {
1050 return emitError(unknownLoc,
"duplicate definition for result <id> ")
1055 case spirv::Opcode::OpTypeVoid:
1056 if (operands.size() != 1)
1057 return emitError(unknownLoc,
"OpTypeVoid must have no parameters");
1058 typeMap[operands[0]] = opBuilder.getNoneType();
1060 case spirv::Opcode::OpTypeBool:
1061 if (operands.size() != 1)
1062 return emitError(unknownLoc,
"OpTypeBool must have no parameters");
1063 typeMap[operands[0]] = opBuilder.getI1Type();
1065 case spirv::Opcode::OpTypeInt: {
1066 if (operands.size() != 3)
1068 unknownLoc,
"OpTypeInt must have bitwidth and signedness parameters");
1077 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
1078 : IntegerType::SignednessSemantics::Signless;
1079 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
1081 case spirv::Opcode::OpTypeFloat: {
1082 if (operands.size() != 2 && operands.size() != 3)
1084 "OpTypeFloat expects either 2 operands (type, bitwidth) "
1085 "or 3 operands (type, bitwidth, encoding), but got ")
1087 uint32_t bitWidth = operands[1];
1090 if (operands.size() == 2) {
1093 floatTy = opBuilder.getF16Type();
1096 floatTy = opBuilder.getF32Type();
1099 floatTy = opBuilder.getF64Type();
1102 return emitError(unknownLoc,
"unsupported OpTypeFloat bitwidth: ")
1107 if (operands.size() == 3) {
1108 if (spirv::FPEncoding(operands[2]) == spirv::FPEncoding::BFloat16KHR &&
1110 floatTy = opBuilder.getBF16Type();
1111 else if (spirv::FPEncoding(operands[2]) ==
1112 spirv::FPEncoding::Float8E4M3EXT &&
1114 floatTy = opBuilder.getF8E4M3FNType();
1115 else if (spirv::FPEncoding(operands[2]) ==
1116 spirv::FPEncoding::Float8E5M2EXT &&
1118 floatTy = opBuilder.getF8E5M2Type();
1120 return emitError(unknownLoc,
"unsupported OpTypeFloat FP encoding: ")
1121 << operands[2] <<
" and bitWidth " << bitWidth;
1124 typeMap[operands[0]] = floatTy;
1126 case spirv::Opcode::OpTypeVector: {
1127 if (operands.size() != 3) {
1130 "OpTypeVector must have element type and count parameters");
1134 return emitError(unknownLoc,
"OpTypeVector references undefined <id> ")
1137 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
1139 case spirv::Opcode::OpTypePointer: {
1142 case spirv::Opcode::OpTypeArray:
1144 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
1146 case spirv::Opcode::OpTypeFunction:
1148 case spirv::Opcode::OpTypeImage:
1150 case spirv::Opcode::OpTypeSampledImage:
1152 case spirv::Opcode::OpTypeRuntimeArray:
1154 case spirv::Opcode::OpTypeStruct:
1156 case spirv::Opcode::OpTypeMatrix:
1158 case spirv::Opcode::OpTypeTensorARM:
1160 case spirv::Opcode::OpTypeGraphARM:
1163 return emitError(unknownLoc,
"unhandled type instruction");
1170 if (operands.size() != 3)
1171 return emitError(unknownLoc,
"OpTypePointer must have two parameters");
1173 auto pointeeType =
getType(operands[2]);
1175 return emitError(unknownLoc,
"unknown OpTypePointer pointee type <id> ")
1178 uint32_t typePointerID = operands[0];
1179 auto storageClass =
static_cast<spirv::StorageClass
>(operands[1]);
1182 for (
auto *deferredStructIt = std::begin(deferredStructTypesInfos);
1183 deferredStructIt != std::end(deferredStructTypesInfos);) {
1184 for (
auto *unresolvedMemberIt =
1185 std::begin(deferredStructIt->unresolvedMemberTypes);
1186 unresolvedMemberIt !=
1187 std::end(deferredStructIt->unresolvedMemberTypes);) {
1188 if (unresolvedMemberIt->first == typePointerID) {
1192 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
1193 typeMap[typePointerID];
1194 unresolvedMemberIt =
1195 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
1197 ++unresolvedMemberIt;
1201 if (deferredStructIt->unresolvedMemberTypes.empty()) {
1203 auto structType = deferredStructIt->deferredStructType;
1205 assert(structType &&
"expected a spirv::StructType");
1206 assert(structType.isIdentified() &&
"expected an indentified struct");
1208 if (failed(structType.trySetBody(
1209 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
1210 deferredStructIt->memberDecorationsInfo,
1211 deferredStructIt->structDecorationsInfo)))
1214 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
1225 if (operands.size() != 3) {
1227 "OpTypeArray must have element type and count parameters");
1232 return emitError(unknownLoc,
"OpTypeArray references undefined <id> ")
1240 return emitError(unknownLoc,
"OpTypeArray count <id> ")
1241 << operands[2] <<
"can only come from normal constant right now";
1244 if (
auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1245 count = intVal.getValue().getZExtValue();
1247 return emitError(unknownLoc,
"OpTypeArray count must come from a "
1248 "scalar integer constant instruction");
1252 elementTy, count, typeDecorations.lookup(operands[0]));
1258 assert(!operands.empty() &&
"No operands for processing function type");
1259 if (operands.size() == 1) {
1260 return emitError(unknownLoc,
"missing return type for OpTypeFunction");
1262 auto returnType =
getType(operands[1]);
1264 return emitError(unknownLoc,
"unknown return type in OpTypeFunction");
1267 for (
size_t i = 2, e = operands.size(); i < e; ++i) {
1268 auto ty =
getType(operands[i]);
1270 return emitError(unknownLoc,
"unknown argument type in OpTypeFunction");
1272 argTypes.push_back(ty);
1278 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1284 if (operands.size() != 6) {
1286 "OpTypeCooperativeMatrixKHR must have element type, "
1287 "scope, row and column parameters, and use");
1293 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1297 std::optional<spirv::Scope> scope =
1302 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1311 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Rows` references "
1312 "undefined constant <id> ")
1316 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Columns` "
1317 "references undefined constant <id> ")
1321 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Use` references "
1322 "undefined constant <id> ")
1325 unsigned rows = rowsAttr.getInt();
1326 unsigned columns = columnsAttr.getInt();
1328 std::optional<spirv::CooperativeMatrixUseKHR> use =
1329 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1333 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1337 typeMap[operands[0]] =
1344 if (operands.size() != 2) {
1345 return emitError(unknownLoc,
"OpTypeRuntimeArray must have two operands");
1350 "OpTypeRuntimeArray references undefined <id> ")
1354 memberType, typeDecorations.lookup(operands[0]));
1362 if (operands.empty()) {
1363 return emitError(unknownLoc,
"OpTypeStruct must have at least result <id>");
1366 if (operands.size() == 1) {
1368 typeMap[operands[0]] =
1377 for (
auto op : llvm::drop_begin(operands, 1)) {
1379 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1381 if (!memberType && !typeForwardPtr)
1382 return emitError(unknownLoc,
"OpTypeStruct references undefined <id> ")
1386 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1388 memberTypes.push_back(memberType);
1393 if (memberDecorationMap.count(operands[0])) {
1394 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1395 for (
auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1396 if (allMemberDecorations.count(memberIndex)) {
1397 for (
auto &memberDecoration : allMemberDecorations[memberIndex]) {
1399 if (memberDecoration.first == spirv::Decoration::Offset) {
1401 if (offsetInfo.empty()) {
1402 offsetInfo.resize(memberTypes.size());
1404 offsetInfo[memberIndex] = memberDecoration.second[0];
1406 auto intType = mlir::IntegerType::get(context, 32);
1407 if (!memberDecoration.second.empty()) {
1408 memberDecorationsInfo.emplace_back(
1409 memberIndex, memberDecoration.first,
1410 IntegerAttr::get(intType, memberDecoration.second[0]));
1412 memberDecorationsInfo.emplace_back(
1413 memberIndex, memberDecoration.first, UnitAttr::get(context));
1422 if (decorations.count(operands[0])) {
1425 std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
1426 llvm::convertToCamelFromSnakeCase(decorationAttr.getName(),
true));
1427 assert(decoration.has_value());
1428 structDecorationsInfo.emplace_back(decoration.value(),
1429 decorationAttr.getValue());
1433 uint32_t structID = operands[0];
1434 std::string structIdentifier = nameMap.lookup(structID).str();
1436 if (structIdentifier.empty()) {
1437 assert(unresolvedMemberTypes.empty() &&
1438 "didn't expect unresolved member types");
1440 memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
1443 typeMap[structID] = structTy;
1445 if (!unresolvedMemberTypes.empty())
1446 deferredStructTypesInfos.push_back(
1447 {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
1448 memberDecorationsInfo, structDecorationsInfo});
1449 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1450 memberDecorationsInfo,
1451 structDecorationsInfo)))
1462 if (operands.size() != 3) {
1464 return emitError(unknownLoc,
"OpTypeMatrix must have 3 operands"
1465 " (result_id, column_type, and column_count)");
1471 "OpTypeMatrix references undefined column type.")
1475 uint32_t colsCount = operands[2];
1482 unsigned size = operands.size();
1483 if (size < 2 || size > 4)
1484 return emitError(unknownLoc,
"OpTypeTensorARM must have 2-4 operands "
1485 "(result_id, element_type, (rank), (shape)) ")
1491 "OpTypeTensorARM references undefined element type ")
1501 return emitError(unknownLoc,
"OpTypeTensorARM rank must come from a "
1502 "scalar integer constant instruction");
1503 unsigned rank = rankAttr.getValue().getZExtValue();
1510 std::optional<std::pair<Attribute, Type>> shapeInfo =
1513 return emitError(unknownLoc,
"OpTypeTensorARM shape must come from a "
1514 "constant instruction of type OpTypeArray");
1516 ArrayAttr shapeArrayAttr = dyn_cast<ArrayAttr>(shapeInfo->first);
1518 for (
auto dimAttr : shapeArrayAttr.getValue()) {
1519 auto dimIntAttr = dyn_cast<IntegerAttr>(dimAttr);
1521 return emitError(unknownLoc,
"OpTypeTensorARM shape has an invalid "
1523 shape.push_back(dimIntAttr.getValue().getSExtValue());
1531 unsigned size = operands.size();
1533 return emitError(unknownLoc,
"OpTypeGraphARM must have at least 2 operands "
1534 "(result_id, num_inputs, (inout0_type, "
1535 "inout1_type, ...))")
1538 uint32_t numInputs = operands[1];
1541 for (
unsigned i = 2; i < size; ++i) {
1545 "OpTypeGraphARM references undefined element type.")
1548 if (i - 2 >= numInputs) {
1549 returnTypes.push_back(inOutTy);
1551 argTypes.push_back(inOutTy);
1554 typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
1560 if (operands.size() != 2)
1562 "OpTypeForwardPointer instruction must have two operands");
1564 typeForwardPointerIDs.insert(operands[0]);
1574 if (operands.size() != 8)
1577 "OpTypeImage with non-eight operands are not supported yet");
1581 return emitError(unknownLoc,
"OpTypeImage references undefined <id>: ")
1584 auto dim = spirv::symbolizeDim(operands[2]);
1586 return emitError(unknownLoc,
"unknown Dim for OpTypeImage: ")
1589 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1591 return emitError(unknownLoc,
"unknown Depth for OpTypeImage: ")
1594 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1596 return emitError(unknownLoc,
"unknown Arrayed for OpTypeImage: ")
1599 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1601 return emitError(unknownLoc,
"unknown MS for OpTypeImage: ") << operands[5];
1603 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1604 if (!samplerUseInfo)
1605 return emitError(unknownLoc,
"unknown Sampled for OpTypeImage: ")
1608 auto format = spirv::symbolizeImageFormat(operands[7]);
1610 return emitError(unknownLoc,
"unknown Format for OpTypeImage: ")
1614 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1615 samplingInfo.value(), samplerUseInfo.value(), format.value());
1621 if (operands.size() != 2)
1622 return emitError(unknownLoc,
"OpTypeSampledImage must have two operands");
1627 "OpTypeSampledImage references undefined <id>: ")
1640 StringRef opname = isSpec ?
"OpSpecConstant" :
"OpConstant";
1642 if (operands.size() < 2) {
1644 << opname <<
" must have type <id> and result <id>";
1646 if (operands.size() < 3) {
1648 << opname <<
" must have at least 1 more parameter";
1653 return emitError(unknownLoc,
"undefined result type from <id> ")
1657 auto checkOperandSizeForBitwidth = [&](
unsigned bitwidth) -> LogicalResult {
1658 if (bitwidth == 64) {
1659 if (operands.size() == 4) {
1663 << opname <<
" should have 2 parameters for 64-bit values";
1665 if (bitwidth <= 32) {
1666 if (operands.size() == 3) {
1672 <<
" should have 1 parameter for values with no more than 32 bits";
1674 return emitError(unknownLoc,
"unsupported OpConstant bitwidth: ")
1678 auto resultID = operands[1];
1680 if (
auto intType = dyn_cast<IntegerType>(resultType)) {
1681 auto bitwidth = intType.getWidth();
1682 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1687 if (bitwidth == 64) {
1694 } words = {operands[2], operands[3]};
1695 value = APInt(64, llvm::bit_cast<uint64_t>(words),
true);
1696 }
else if (bitwidth <= 32) {
1697 value = APInt(bitwidth, operands[2],
true,
1701 auto attr = opBuilder.getIntegerAttr(intType, value);
1708 constantMap.try_emplace(resultID, attr, intType);
1714 if (
auto floatType = dyn_cast<FloatType>(resultType)) {
1715 auto bitwidth = floatType.getWidth();
1716 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1721 if (floatType.isF64()) {
1728 } words = {operands[2], operands[3]};
1729 value = APFloat(llvm::bit_cast<double>(words));
1730 }
else if (floatType.isF32()) {
1731 value = APFloat(llvm::bit_cast<float>(operands[2]));
1732 }
else if (floatType.isF16()) {
1733 APInt data(16, operands[2]);
1734 value = APFloat(APFloat::IEEEhalf(), data);
1735 }
else if (floatType.isBF16()) {
1736 APInt data(16, operands[2]);
1737 value = APFloat(APFloat::BFloat(), data);
1738 }
else if (floatType.isF8E4M3FN()) {
1739 APInt data(8, operands[2]);
1740 value = APFloat(APFloat::Float8E4M3FN(), data);
1741 }
else if (floatType.isF8E5M2()) {
1742 APInt data(8, operands[2]);
1743 value = APFloat(APFloat::Float8E5M2(), data);
1746 auto attr = opBuilder.getFloatAttr(floatType, value);
1752 constantMap.try_emplace(resultID, attr, floatType);
1758 return emitError(unknownLoc,
"OpConstant can only generate values of "
1759 "scalar integer or floating-point type");
1764 if (operands.size() != 2) {
1766 << (isSpec ?
"Spec" :
"") <<
"Constant"
1767 << (isTrue ?
"True" :
"False")
1768 <<
" must have type <id> and result <id>";
1771 auto attr = opBuilder.getBoolAttr(isTrue);
1772 auto resultID = operands[1];
1778 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1786 if (operands.size() < 2) {
1788 "OpConstantComposite must have type <id> and result <id>");
1790 if (operands.size() < 3) {
1792 "OpConstantComposite must have at least 1 parameter");
1797 return emitError(unknownLoc,
"undefined result type from <id> ")
1802 elements.reserve(operands.size() - 2);
1803 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1806 return emitError(unknownLoc,
"OpConstantComposite component <id> ")
1807 << operands[i] <<
" must come from a normal constant";
1809 elements.push_back(elementInfo->first);
1812 auto resultID = operands[1];
1813 if (
auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1816 if (
auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
1817 for (
auto value : denseElemAttr.getValues<
Attribute>())
1818 flattenedElems.push_back(value);
1820 flattenedElems.push_back(element);
1824 constantMap.try_emplace(resultID, attr, tensorType);
1825 }
else if (
auto shapedType = dyn_cast<ShapedType>(resultType)) {
1829 constantMap.try_emplace(resultID, attr, shapedType);
1830 }
else if (
auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1831 auto attr = opBuilder.getArrayAttr(elements);
1832 constantMap.try_emplace(resultID, attr, resultType);
1834 return emitError(unknownLoc,
"unsupported OpConstantComposite type: ")
1843 if (operands.size() != 3) {
1846 "OpConstantCompositeReplicateEXT expects 3 operands but found ")
1852 return emitError(unknownLoc,
"undefined result type from <id> ")
1856 auto compositeType = dyn_cast<CompositeType>(resultType);
1857 if (!compositeType) {
1859 "result type from <id> is not a composite type")
1863 uint32_t resultID = operands[1];
1864 uint32_t constantID = operands[2];
1866 std::optional<std::pair<Attribute, Type>> constantInfo =
1868 if (constantInfo.has_value()) {
1869 constantCompositeReplicateMap.try_emplace(
1870 resultID, constantInfo.value().first, resultType);
1874 std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
1876 if (replicatedConstantCompositeInfo.has_value()) {
1877 constantCompositeReplicateMap.try_emplace(
1878 resultID, replicatedConstantCompositeInfo.value().first, resultType);
1882 return emitError(unknownLoc,
"OpConstantCompositeReplicateEXT operand <id> ")
1884 <<
" must come from a normal constant or a "
1885 "OpConstantCompositeReplicateEXT";
1890 if (operands.size() < 2) {
1893 "OpSpecConstantComposite must have type <id> and result <id>");
1895 if (operands.size() < 3) {
1897 "OpSpecConstantComposite must have at least 1 parameter");
1902 return emitError(unknownLoc,
"undefined result type from <id> ")
1906 auto resultID = operands[1];
1910 elements.reserve(operands.size() - 2);
1911 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1913 elements.push_back(SymbolRefAttr::get(elementInfo));
1916 auto op = spirv::SpecConstantCompositeOp::create(
1917 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1918 opBuilder.getArrayAttr(elements));
1919 specConstCompositeMap[resultID] = op;
1926 if (operands.size() != 3) {
1927 return emitError(unknownLoc,
"OpSpecConstantCompositeReplicateEXT expects "
1928 "3 operands but found ")
1934 return emitError(unknownLoc,
"undefined result type from <id> ")
1938 auto compositeType = dyn_cast<CompositeType>(resultType);
1939 if (!compositeType) {
1941 "result type from <id> is not a composite type")
1945 uint32_t resultID = operands[1];
1948 spirv::SpecConstantOp constituentSpecConstantOp =
1950 auto op = spirv::EXTSpecConstantCompositeReplicateOp::create(
1951 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1952 SymbolRefAttr::get(constituentSpecConstantOp));
1954 specConstCompositeReplicateMap[resultID] = op;
1961 if (operands.size() < 3)
1962 return emitError(unknownLoc,
"OpConstantOperation must have type <id>, "
1963 "result <id>, and operand opcode");
1965 uint32_t resultTypeID = operands[0];
1968 return emitError(unknownLoc,
"undefined result type from <id> ")
1971 uint32_t resultID = operands[1];
1972 spirv::Opcode enclosedOpcode =
static_cast<spirv::Opcode
>(operands[2]);
1973 auto emplaceResult = specConstOperationMap.try_emplace(
1976 enclosedOpcode, resultTypeID,
1979 if (!emplaceResult.second)
1980 return emitError(unknownLoc,
"value with <id>: ")
1981 << resultID <<
" is probably defined before.";
1987 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
2003 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
2004 constexpr uint32_t fakeID =
static_cast<uint32_t
>(-3);
2007 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
2008 enclosedOpResultTypeAndOperands.push_back(fakeID);
2009 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
2010 enclosedOpOperands.end());
2025 auto specConstOperationOp =
2026 spirv::SpecConstantOperationOp::create(opBuilder, loc, resultType);
2028 Region &body = specConstOperationOp.getBody();
2030 body.
getBlocks().splice(body.
end(), curBlock->getParent()->getBlocks(),
2037 opBuilder.setInsertionPointToEnd(&block);
2039 spirv::YieldOp::create(opBuilder, loc, block.
front().
getResult(0));
2040 return specConstOperationOp.getResult();
2045 if (operands.size() != 2) {
2047 "OpConstantNull must only have type <id> and result <id>");
2052 return emitError(unknownLoc,
"undefined result type from <id> ")
2056 auto resultID = operands[1];
2058 if (resultType.
isIntOrFloat() || isa<VectorType>(resultType)) {
2059 attr = opBuilder.getZeroAttr(resultType);
2060 }
else if (
auto tensorType = dyn_cast<TensorArmType>(resultType)) {
2061 if (
auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
2068 constantMap.try_emplace(resultID, attr, resultType);
2072 return emitError(unknownLoc,
"unsupported OpConstantNull type: ")
2078 if (operands.size() < 3) {
2080 <<
"OpGraphConstantARM must have at least 2 operands";
2085 return emitError(unknownLoc,
"undefined result type from <id> ")
2089 uint32_t resultID = operands[1];
2091 if (!dyn_cast<spirv::TensorArmType>(resultType)) {
2092 return emitError(unknownLoc,
"result must be of type OpTypeTensorARM");
2095 APInt graph_constant_id = APInt(32, operands[2],
true);
2096 Type i32Ty = opBuilder.getIntegerType(32);
2097 IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
2098 graphConstantMap.try_emplace(
2110 LLVM_DEBUG(logger.startLine() <<
"[block] got exiting block for id = " <<
id
2111 <<
" @ " << block <<
"\n");
2118 auto *block = curFunction->addBlock();
2119 LLVM_DEBUG(logger.startLine() <<
"[block] created block for id = " <<
id
2120 <<
" @ " << block <<
"\n");
2121 return blockMap[id] = block;
2126 return emitError(unknownLoc,
"OpBranch must appear inside a block");
2129 if (operands.size() != 1) {
2130 return emitError(unknownLoc,
"OpBranch must take exactly one target label");
2138 spirv::BranchOp::create(opBuilder, loc,
target);
2148 "OpBranchConditional must appear inside a block");
2151 if (operands.size() != 3 && operands.size() != 5) {
2153 "OpBranchConditional must have condition, true label, "
2154 "false label, and optionally two branch weights");
2157 auto condition =
getValue(operands[0]);
2161 std::optional<std::pair<uint32_t, uint32_t>> weights;
2162 if (operands.size() == 5) {
2163 weights = std::make_pair(operands[3], operands[4]);
2169 spirv::BranchConditionalOp::create(
2170 opBuilder, loc, condition, trueBlock,
2180 return emitError(unknownLoc,
"OpLabel must appear inside a function");
2183 if (operands.size() != 1) {
2184 return emitError(unknownLoc,
"OpLabel should only have result <id>");
2187 auto labelID = operands[0];
2190 LLVM_DEBUG(logger.startLine()
2191 <<
"[block] populating block " << block <<
"\n");
2193 assert(block->empty() &&
"re-deserialize the same block!");
2195 opBuilder.setInsertionPointToStart(block);
2196 blockMap[labelID] = curBlock = block;
2203 return emitError(unknownLoc,
"a graph block must appear inside a graph");
2208 LLVM_DEBUG(logger.startLine()
2209 <<
"[block] populating block " << block <<
"\n");
2211 assert(block->
empty() &&
"re-deserialize the same block!");
2213 opBuilder.setInsertionPointToStart(block);
2214 blockMap[graphID] = curBlock = block;
2222 return emitError(unknownLoc,
"OpSelectionMerge must appear in a block");
2225 if (operands.size() < 2) {
2228 "OpSelectionMerge must specify merge target and selection control");
2233 auto selectionControl = operands[1];
2235 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
2239 "a block cannot have more than one OpSelectionMerge instruction");
2248 return emitError(unknownLoc,
"OpLoopMerge must appear in a block");
2251 if (operands.size() < 3) {
2252 return emitError(unknownLoc,
"OpLoopMerge must specify merge target, "
2253 "continue target and loop control");
2259 uint32_t loopControl = operands[2];
2262 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
2266 "a block cannot have more than one OpLoopMerge instruction");
2274 return emitError(unknownLoc,
"OpPhi must appear in a block");
2277 if (operands.size() < 4) {
2278 return emitError(unknownLoc,
"OpPhi must specify result type, result <id>, "
2279 "and variable-parent pairs");
2284 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
2285 valueMap[operands[1]] = blockArg;
2286 LLVM_DEBUG(logger.startLine()
2287 <<
"[phi] created block argument " << blockArg
2288 <<
" id = " << operands[1] <<
" of type " << blockArgType <<
"\n");
2292 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
2293 uint32_t value = operands[i];
2295 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
2296 blockPhiInfo[predecessorTargetPair].push_back(value);
2297 LLVM_DEBUG(logger.startLine() <<
"[phi] predecessor @ " << predecessor
2298 <<
" with arg id = " << value <<
"\n");
2306 return emitError(unknownLoc,
"OpSwitch must appear in a block");
2308 if (operands.size() < 2)
2309 return emitError(unknownLoc,
"OpSwitch must at least specify selector and "
2310 "a default target");
2312 if (operands.size() % 2)
2314 "OpSwitch must at have an even number of operands: "
2315 "selector, default target and any number of literal and "
2316 "label <id> pairs");
2324 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
2325 literals.push_back(operands[i]);
2330 spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock,
2339class ControlFlowStructurizer {
2342 ControlFlowStructurizer(
Location loc, uint32_t control,
2345 llvm::ScopedPrinter &logger)
2346 : location(loc), control(control), blockMergeInfo(mergeInfo),
2347 headerBlock(header), mergeBlock(merge), continueBlock(cont),
2350 ControlFlowStructurizer(
Location loc, uint32_t control,
2353 : location(loc), control(control), blockMergeInfo(mergeInfo),
2354 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
2364 LogicalResult structurize();
2369 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
2372 spirv::LoopOp createLoopOp(uint32_t loopControl);
2375 void collectBlocksInConstruct();
2384 Block *continueBlock;
2390 llvm::ScopedPrinter &logger;
2396ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
2399 OpBuilder builder(&mergeBlock->front());
2401 auto control =
static_cast<spirv::SelectionControl
>(selectionControl);
2402 auto selectionOp = spirv::SelectionOp::create(builder, location, control);
2403 selectionOp.addMergeBlock(builder);
2408spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
2411 OpBuilder builder(&mergeBlock->front());
2413 auto control =
static_cast<spirv::LoopControl
>(loopControl);
2414 auto loopOp = spirv::LoopOp::create(builder, location, control);
2415 loopOp.addEntryAndMergeBlock(builder);
2420void ControlFlowStructurizer::collectBlocksInConstruct() {
2421 assert(constructBlocks.empty() &&
"expected empty constructBlocks");
2424 constructBlocks.insert(headerBlock);
2428 for (
unsigned i = 0; i < constructBlocks.size(); ++i) {
2429 for (
auto *successor : constructBlocks[i]->getSuccessors())
2430 if (successor != mergeBlock)
2431 constructBlocks.insert(successor);
2435LogicalResult ControlFlowStructurizer::structurize() {
2436 Operation *op =
nullptr;
2437 bool isLoop = continueBlock !=
nullptr;
2439 if (
auto loopOp = createLoopOp(control))
2440 op = loopOp.getOperation();
2442 if (
auto selectionOp = createSelectionOp(control))
2443 op = selectionOp.getOperation();
2452 mapper.
map(mergeBlock, &body.
back());
2454 collectBlocksInConstruct();
2475 OpBuilder builder(body);
2476 for (
auto *block : constructBlocks) {
2479 auto *newBlock = builder.createBlock(&body.
back());
2480 mapper.
map(block, newBlock);
2481 LLVM_DEBUG(logger.startLine() <<
"[cf] cloned block " << newBlock
2482 <<
" from block " << block <<
"\n");
2484 for (BlockArgument blockArg : block->getArguments()) {
2486 newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2487 mapper.
map(blockArg, newArg);
2488 LLVM_DEBUG(logger.startLine() <<
"[cf] remapped block argument "
2489 << blockArg <<
" to " << newArg <<
"\n");
2492 LLVM_DEBUG(logger.startLine()
2493 <<
"[cf] block " << block <<
" is a function entry block\n");
2496 for (
auto &op : *block)
2497 newBlock->push_back(op.
clone(mapper));
2501 auto remapOperands = [&](Operation *op) {
2503 if (Value mappedOp = mapper.
lookupOrNull(operand.get()))
2504 operand.set(mappedOp);
2507 succOp.set(mappedOp);
2509 for (
auto &block : body)
2510 block.walk(remapOperands);
2518 headerBlock->replaceAllUsesWith(mergeBlock);
2521 logger.startLine() <<
"[cf] after cloning and fixing references:\n";
2522 headerBlock->getParentOp()->print(logger.getOStream());
2523 logger.startLine() <<
"\n";
2527 if (!mergeBlock->args_empty()) {
2528 return mergeBlock->getParentOp()->emitError(
2529 "OpPhi in loop merge block unsupported");
2535 for (BlockArgument blockArg : headerBlock->getArguments())
2536 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2540 SmallVector<Value, 4> blockArgs;
2541 if (!headerBlock->args_empty())
2542 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2546 builder.setInsertionPointToEnd(&body.front());
2547 spirv::BranchOp::create(builder, location, mapper.
lookupOrNull(headerBlock),
2548 ArrayRef<Value>(blockArgs));
2553 SmallVector<Value> valuesToYield;
2556 SmallVector<Value> outsideUses;
2570 for (BlockArgument blockArg : mergeBlock->getArguments()) {
2575 body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2576 valuesToYield.push_back(body.back().getArguments().back());
2577 outsideUses.push_back(blockArg);
2582 LLVM_DEBUG(logger.startLine() <<
"[cf] cleaning up blocks after clone\n");
2585 for (
auto *block : constructBlocks)
2586 block->dropAllReferences();
2591 for (
Block *block : constructBlocks) {
2592 for (Operation &op : *block) {
2596 outsideUses.push_back(
result);
2599 for (BlockArgument &arg : block->getArguments()) {
2600 if (!arg.use_empty()) {
2602 outsideUses.push_back(arg);
2607 assert(valuesToYield.size() == outsideUses.size());
2611 if (!valuesToYield.empty()) {
2612 LLVM_DEBUG(logger.startLine()
2613 <<
"[cf] yielding values from the selection / loop region\n");
2616 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2617 Operation *merge = llvm::getSingleElement(mergeOps);
2619 merge->setOperands(valuesToYield);
2627 builder.setInsertionPoint(&mergeBlock->front());
2629 Operation *newOp =
nullptr;
2632 newOp = spirv::LoopOp::create(builder, location,
2634 static_cast<spirv::LoopControl
>(control));
2636 newOp = spirv::SelectionOp::create(
2638 static_cast<spirv::SelectionControl
>(control));
2648 for (
unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2649 outsideUses[i].replaceAllUsesWith(op->
getResult(i));
2655 mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2662 for (
auto *block : constructBlocks) {
2663 if (!block->use_empty())
2664 return emitError(block->getParent()->getLoc(),
2665 "failed control flow structurization: "
2666 "block has uses outside of the "
2667 "enclosing selection/loop construct");
2668 for (Operation &op : *block)
2670 return op.
emitOpError(
"failed control flow structurization: value has "
2671 "uses outside of the "
2672 "enclosing selection/loop construct");
2673 for (BlockArgument &arg : block->getArguments())
2674 if (!arg.use_empty())
2675 return emitError(arg.getLoc(),
"failed control flow structurization: "
2676 "block argument has uses outside of the "
2677 "enclosing selection/loop construct");
2681 for (
auto *block : constructBlocks) {
2721 auto updateMergeInfo = [&](
Block *block) -> WalkResult {
2722 auto it = blockMergeInfo.find(block);
2723 if (it != blockMergeInfo.end()) {
2725 Location loc = it->second.loc;
2729 return emitError(loc,
"failed control flow structurization: nested "
2730 "loop header block should be remapped!");
2732 Block *newContinue = it->second.continueBlock;
2736 return emitError(loc,
"failed control flow structurization: nested "
2737 "loop continue block should be remapped!");
2740 Block *newMerge = it->second.mergeBlock;
2742 newMerge = mappedTo;
2746 blockMergeInfo.
erase(it);
2747 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2754 if (block->walk(updateMergeInfo).wasInterrupted())
2762 LLVM_DEBUG(logger.startLine() <<
"[cf] changing entry block " << block
2763 <<
" to only contain a spirv.Branch op\n");
2767 builder.setInsertionPointToEnd(block);
2768 spirv::BranchOp::create(builder, location, mergeBlock);
2770 LLVM_DEBUG(logger.startLine() <<
"[cf] erasing block " << block <<
"\n");
2775 LLVM_DEBUG(logger.startLine()
2776 <<
"[cf] after structurizing construct with header block "
2777 << headerBlock <<
":\n"
2786 <<
"//----- [phi] start wiring up block arguments -----//\n";
2792 for (
const auto &info : blockPhiInfo) {
2793 Block *block = info.first.first;
2797 logger.startLine() <<
"[phi] block " << block <<
"\n";
2798 logger.startLine() <<
"[phi] before creating block argument:\n";
2800 logger.startLine() <<
"\n";
2806 opBuilder.setInsertionPoint(op);
2809 blockArgs.reserve(phiInfo.size());
2810 for (uint32_t valueId : phiInfo) {
2812 blockArgs.push_back(value);
2813 LLVM_DEBUG(logger.startLine() <<
"[phi] block argument " << value
2814 <<
" id = " << valueId <<
"\n");
2816 return emitError(unknownLoc,
"OpPhi references undefined value!");
2820 if (
auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2822 spirv::BranchOp::create(opBuilder, branchOp.getLoc(),
2823 branchOp.getTarget(), blockArgs);
2825 }
else if (
auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2826 assert((branchCondOp.getTrueBlock() ==
target ||
2827 branchCondOp.getFalseBlock() ==
target) &&
2828 "expected target to be either the true or false target");
2829 if (
target == branchCondOp.getTrueTarget())
2830 spirv::BranchConditionalOp::create(
2831 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2832 blockArgs, branchCondOp.getFalseBlockArguments(),
2833 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2834 branchCondOp.getFalseTarget());
2836 spirv::BranchConditionalOp::create(
2837 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2838 branchCondOp.getTrueBlockArguments(), blockArgs,
2839 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2840 branchCondOp.getFalseBlock());
2842 branchCondOp.erase();
2843 }
else if (
auto switchOp = dyn_cast<spirv::SwitchOp>(op)) {
2844 if (
target == switchOp.getDefaultTarget()) {
2848 spirv::SwitchOp::create(
2849 opBuilder, switchOp.getLoc(), switchOp.getSelector(),
2850 switchOp.getDefaultTarget(), blockArgs, literals,
2851 switchOp.getTargets(), targetOperands);
2855 auto it = llvm::find(targets,
target);
2856 assert(it != targets.end());
2857 size_t index = std::distance(targets.begin(), it);
2858 switchOp.getTargetOperandsMutable(
index).assign(blockArgs);
2861 return emitError(unknownLoc,
"unimplemented terminator for Phi creation");
2865 logger.startLine() <<
"[phi] after creating block argument:\n";
2867 logger.startLine() <<
"\n";
2870 blockPhiInfo.clear();
2875 <<
"//--- [phi] completed wiring up block arguments ---//\n";
2883 for (
auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2885 auto &[block, mergeInfo] = *it;
2888 if (mergeInfo.continueBlock)
2891 if (!block->mightHaveTerminator())
2894 Operation *terminator = block->getTerminator();
2897 if (!isa<spirv::BranchConditionalOp, spirv::SwitchOp>(terminator))
2901 bool splitHeaderMergeBlock =
false;
2902 for (
const auto &[_, mergeInfo] : blockMergeInfo) {
2903 if (mergeInfo.mergeBlock == block)
2904 splitHeaderMergeBlock =
true;
2911 if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2914 spirv::BranchOp::create(builder, block->getParent()->getLoc(), newBlock);
2918 blockMergeInfo.erase(block);
2919 blockMergeInfo.try_emplace(newBlock, mergeInfo);
2927 if (!options.enableControlFlowStructurization) {
2931 <<
"//----- [cf] skip structurizing control flow -----//\n";
2939 <<
"//----- [cf] start structurizing control flow -----//\n";
2944 logger.startLine() <<
"[cf] split conditional blocks\n";
2945 logger.startLine() <<
"\n";
2952 while (!blockMergeInfo.empty()) {
2953 Block *headerBlock = blockMergeInfo.
begin()->first;
2957 logger.startLine() <<
"[cf] header block " << headerBlock <<
":\n";
2958 headerBlock->
print(logger.getOStream());
2959 logger.startLine() <<
"\n";
2963 assert(mergeBlock &&
"merge block cannot be nullptr");
2965 return emitError(unknownLoc,
"OpPhi in loop merge block unimplemented");
2967 logger.startLine() <<
"[cf] merge block " << mergeBlock <<
":\n";
2968 mergeBlock->print(logger.getOStream());
2969 logger.startLine() <<
"\n";
2973 LLVM_DEBUG(
if (continueBlock) {
2974 logger.startLine() <<
"[cf] continue block " << continueBlock <<
":\n";
2975 continueBlock->print(logger.getOStream());
2976 logger.startLine() <<
"\n";
2980 blockMergeInfo.
erase(blockMergeInfo.begin());
2981 ControlFlowStructurizer structurizer(mergeInfo.
loc, mergeInfo.
control,
2982 blockMergeInfo, headerBlock,
2983 mergeBlock, continueBlock
2989 if (failed(structurizer.structurize()))
2996 <<
"//--- [cf] completed structurizing control flow ---//\n";
3009 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
3010 if (fileName.empty())
3011 fileName =
"<unknown>";
3023 if (operands.size() != 3)
3024 return emitError(unknownLoc,
"OpLine must have 3 operands");
3025 debugLine =
DebugLine{operands[0], operands[1], operands[2]};
3033 if (operands.size() < 2)
3034 return emitError(unknownLoc,
"OpString needs at least 2 operands");
3036 if (!debugInfoMap.lookup(operands[0]).empty())
3038 "duplicate debug string found for result <id> ")
3041 unsigned wordIndex = 1;
3043 if (wordIndex != operands.size())
3045 "unexpected trailing words in OpString instruction");
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Operation * getTerminator()
Get the terminator operation of this block.
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
A symbol reference with a reference path containing a single element.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
MutableArrayRef< BlockOperand > getBlockOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool use_empty()
Returns true if this operation has no uses.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MutableArrayRef< OpOperand > getOpOperands()
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class implements the successor iterators for Block.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static WalkResult advance()
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult wireUpBlockArgument()
Creates block arguments on predecessors previously recorded when handling OpPhi instructions.
Value materializeSpecConstantOperation(uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, ArrayRef< uint32_t > enclosedOpOperands)
Materializes/emits an OpSpecConstantOp instruction.
LogicalResult processOpTypePointer(ArrayRef< uint32_t > operands)
Value getValue(uint32_t id)
Get the Value associated with a result <id>.
LogicalResult processMatrixType(ArrayRef< uint32_t > operands)
LogicalResult processGlobalVariable(ArrayRef< uint32_t > operands)
Processes the OpVariable instructions at current offset into binary.
std::optional< SpecConstOperationMaterializationInfo > getSpecConstantOperation(uint32_t id)
Gets the info needed to materialize the spec constant operation op associated with the given <id>.
LogicalResult processConstantNull(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantNull instruction with the given operands.
LogicalResult processSpecConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantComposite instruction with the given operands.
LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef< uint32_t > operands, bool deferInstructions=true)
Processes a SPIR-V instruction with the given opcode and operands.
LogicalResult processBranchConditional(ArrayRef< uint32_t > operands)
spirv::GlobalVariableOp getGlobalVariable(uint32_t id)
Gets the global variable associated with a result <id> of OpVariable.
LogicalResult createGraphBlock(uint32_t graphID)
Creates a block for graph with the given graphID.
LogicalResult processStructType(ArrayRef< uint32_t > operands)
LogicalResult processGraphARM(ArrayRef< uint32_t > operands)
LogicalResult setFunctionArgAttrs(uint32_t argID, SmallVectorImpl< Attribute > &argAttrs, size_t argIndex)
Sets the function argument's attributes.
LogicalResult structurizeControlFlow()
Extracts blocks belonging to a structured selection/loop into a spirv.mlir.selection/spirv....
LogicalResult processLabel(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLabel instruction with the given operands.
LogicalResult processSampledImageType(ArrayRef< uint32_t > operands)
LogicalResult processTensorARMType(ArrayRef< uint32_t > operands)
std::optional< spirv::GraphConstantARMOpMaterializationInfo > getGraphConstantARM(uint32_t id)
Gets the GraphConstantARM ID attribute and result type with the given result <id>.
std::optional< std::pair< Attribute, Type > > getConstant(uint32_t id)
Gets the constant's attribute and type associated with the given <id>.
LogicalResult processType(spirv::Opcode opcode, ArrayRef< uint32_t > operands)
Processes a SPIR-V type instruction with given opcode and operands and registers the type into module...
LogicalResult processLoopMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLoopMerge instruction with the given operands.
LogicalResult processArrayType(ArrayRef< uint32_t > operands)
LogicalResult sliceInstruction(spirv::Opcode &opcode, ArrayRef< uint32_t > &operands, std::optional< spirv::Opcode > expectedOpcode=std::nullopt)
Slices the first instruction out of binary and returns its opcode and operands via opcode and operand...
spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id)
Gets the composite specialization constant with the given result <id>.
SmallVector< uint32_t, 2 > BlockPhiInfo
For OpPhi instructions, we use block arguments to represent them.
LogicalResult processSpecConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processCooperativeMatrixTypeKHR(ArrayRef< uint32_t > operands)
LogicalResult processGraphEntryPointARM(ArrayRef< uint32_t > operands)
LogicalResult processFunction(ArrayRef< uint32_t > operands)
Creates a deserializer for the given SPIR-V binary module.
StringAttr getSymbolDecoration(StringRef decorationName)
Gets the symbol name from the name of decoration.
Block * getOrCreateBlock(uint32_t id)
Gets or creates the block corresponding to the given label <id>.
bool isVoidType(Type type) const
Returns true if the given type is for SPIR-V void type.
std::string getSpecConstantSymbol(uint32_t id)
Returns a symbol to be used for the specialization constant with the given result <id>.
LogicalResult processDebugString(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpString instruction with the given operands.
LogicalResult processPhi(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpPhi instruction with the given operands.
std::string getFunctionSymbol(uint32_t id)
Returns a symbol to be used for the function name with the given result <id>.
void clearDebugLine()
Discontinues any source-level location information that might be active from a previous OpLine instru...
LogicalResult processFunctionType(ArrayRef< uint32_t > operands)
IntegerAttr getConstantInt(uint32_t id)
Gets the constant's integer attribute with the given <id>.
LogicalResult processTypeForwardPointer(ArrayRef< uint32_t > operands)
LogicalResult processSwitch(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSwitch instruction with the given operands.
LogicalResult processGraphEndARM(ArrayRef< uint32_t > operands)
LogicalResult processImageType(ArrayRef< uint32_t > operands)
LogicalResult processConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantComposite instruction with the given operands.
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, TypedAttr defaultValue)
Creates a spirv::SpecConstantOp.
Block * getBlock(uint32_t id) const
Returns the block for the given label <id>.
LogicalResult processGraphTypeARM(ArrayRef< uint32_t > operands)
LogicalResult processBranch(ArrayRef< uint32_t > operands)
std::optional< std::pair< Attribute, Type > > getConstantCompositeReplicate(uint32_t id)
Gets the replicated composite constant's attribute and type associated with the given <id>.
LogicalResult processFunctionEnd(ArrayRef< uint32_t > operands)
Processes OpFunctionEnd and finalizes function.
LogicalResult processRuntimeArrayType(ArrayRef< uint32_t > operands)
LogicalResult processSpecConstantOperation(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantOp instruction with the given operands.
LogicalResult processConstant(ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant instruction with the given operands.
Location createFileLineColLoc(OpBuilder opBuilder)
Creates a FileLineColLoc with the OpLine location information.
LogicalResult processGraphConstantARM(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpGraphConstantARM instruction with the given operands.
LogicalResult processConstantBool(bool isTrue, ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the given operands.
spirv::SpecConstantOp getSpecConstant(uint32_t id)
Gets the specialization constant with the given result <id>.
LogicalResult processConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processSelectionMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSelectionMerge instruction with the given operands.
LogicalResult processOpGraphSetOutputARM(ArrayRef< uint32_t > operands)
LogicalResult processDebugLine(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLine instruction with the given operands.
LogicalResult splitSelectionHeader()
Move a conditional branch or a switch into a separate basic block to avoid unnecessary sinking of def...
std::string getGraphSymbol(uint32_t id)
Returns a symbol to be used for the graph name with the given result <id>.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr uint32_t kMagicNumber
SPIR-V magic number.
llvm::MapVector< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
llvm::SetVector< T, Vector, Set, N > SetVector
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A struct for containing a header block's merge and continue targets.
A struct for containing OpLine instruction information.
A struct that collects the info needed to materialize/emit a GraphConstantARMOp.
A struct that collects the info needed to materialize/emit a SpecConstantOperation op.