23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/Sequence.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/bit.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/SaveAndRestore.h"
30#include "llvm/Support/raw_ostream.h"
35#define DEBUG_TYPE "spirv-deserialization"
// Fragment (enclosing signature not visible in this excerpt): true when the op
// owning `block` is a spirv::FuncOp; isa_and_nonnull also tolerates a null
// parent op.
44 isa_and_nonnull<spirv::FuncOp>(block->
getParentOp());
// Constructor member initializers: keep the input binary and context, create
// the UnknownLoc used as the location for diagnostics and created ops, build
// the top-level module op, point opBuilder at the module's region, and copy the
// deserialization options.
// NOTE(review): opBuilder is constructed from `module`, so `module` must be
// declared before `opBuilder` in the class (declaration not visible here).
54 : binary(binary), context(context), unknownLoc(UnknownLoc::
get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion()),
options(
options)
// Top-level driver (only partially visible). Validates the module header, then
// consumes instructions while curOffset < binary.size(). The walk must finish
// exactly at the end of the binary (asserted). Afterwards the entries in
// deferredInstructions are processed, and the ID-valued decorations queued by
// processDecoration are resolved.
63LogicalResult spirv::Deserializer::deserialize() {
67 <<
"//+++---------- start deserialization ----------+++//\n";
70 if (
failed(processHeader()))
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
74 ArrayRef<uint32_t> operands;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
// Instructions postponed during the main walk (filled elsewhere).
89 for (
auto &deferred : deferredInstructions) {
// ID-valued decorations may reference ops defined anywhere in the module, so
// they are resolved only after every instruction has been processed.
95 if (
failed(resolveDeferredIdDecorations()))
100 LLVM_DEBUG(logger.startLine()
101 <<
"//+++-------- completed deserialization --------+++//\n");
// Hands the deserialized module to the caller by moving the owning reference
// out of the `module` member; the member no longer owns it afterwards.
105OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
106 return std::move(module);
// Builds the spirv.module op (at unknownLoc, with a standalone OpBuilder) that
// initializes the `module` member; the remaining lines are not visible here.
113OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
114 OpBuilder builder(context);
115 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
116 spirv::ModuleOp::build(builder, state);
// Validates the SPIR-V module header: it checks that the header has at least
// five words, checks the magic number, and maps a 1.x version word onto
// spirv::Version. Other major or minor versions are rejected.
120LogicalResult spirv::Deserializer::processHeader() {
123 "SPIR-V binary module must have a 5-word header");
126 return emitError(unknownLoc,
"incorrect magic number");
// The version word is laid out as 0x00MMmm00: bits 16-23 hold the major and
// bits 8-15 the minor version (assumes binary[1] is a 32-bit unsigned word).
129 uint32_t majorVersion = (binary[1] << 8) >> 24;
130 uint32_t minorVersion = (binary[1] << 16) >> 24;
131 if (majorVersion == 1) {
132 switch (minorVersion) {
133#define MIN_VERSION_CASE(v) \
135 version = spirv::Version::V_1_##v; \
145#undef MIN_VERSION_CASE
147 return emitError(unknownLoc,
"unsupported SPIR-V minor version: ")
151 return emitError(unknownLoc,
"unsupported SPIR-V major version: ")
// Handles OpCapability. It expects exactly one operand, maps it to a
// spirv::Capability (rejecting unknown values), and records the result in
// `capabilities`.
161spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
162 if (operands.size() != 1)
163 return emitError(unknownLoc,
"OpCapability must have one parameter");
165 auto cap = spirv::symbolizeCapability(operands[0]);
167 return emitError(unknownLoc,
"unknown capability: ") << operands[0];
169 capabilities.insert(*cap);
// Handles OpExtension. It decodes the literal extension-name string from
// `words`, rejects trailing words and unknown extension names, and records the
// extension in `extensions`.
173LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
177 "OpExtension must have a literal string for the extension name");
180 unsigned wordIndex = 0;
182 if (wordIndex != words.size())
184 "unexpected trailing words in OpExtension instruction");
185 auto ext = spirv::symbolizeExtension(extName);
187 return emitError(unknownLoc,
"unknown extension: ") << extName;
189 extensions.insert(*ext);
// Handles OpExtInstImport. words[0] is the result <id>, followed by the literal
// name of the extended instruction set (decoded starting at word 1). Trailing
// words are rejected.
194spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
195 if (words.size() < 2) {
197 "OpExtInstImport must have a result <id> and a literal "
198 "string for the extended instruction set name");
201 unsigned wordIndex = 1;
203 if (wordIndex != words.size()) {
205 "unexpected trailing words in OpExtInstImport");
// Stores the VCE-triple attribute on the module under getVCETripleAttrName().
// The visible lines show the collected extensions being passed in; presumably
// the version and capabilities are too (those lines are not shown).
210void spirv::Deserializer::attachVCETriple() {
212 spirv::ModuleOp::getVCETripleAttrName(),
214 extensions.getArrayRef(), context));
// Handles OpMemoryModel. It requires exactly two operands and stores them on
// the module as the addressing-model (first) and memory-model (second)
// attributes.
// NOTE(review): unlike processCapability, the operands are static_cast to the
// enums without a symbolize* check, so out-of-range values are not rejected.
218spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
219 if (operands.size() != 2)
220 return emitError(unknownLoc,
"OpMemoryModel must have two operands");
223 module->getAddressingModelAttrName(),
224 opBuilder.getAttr<spirv::AddressingModelAttr>(
225 static_cast<spirv::AddressingModel
>(operands.front())));
227 (*module)->setAttr(module->getMemoryModelAttrName(),
228 opBuilder.getAttr<spirv::MemoryModelAttr>(
229 static_cast<spirv::MemoryModel
>(operands.back())));
// Shared helper for the CacheControl{Load,Store}INTEL decorations. It expects
// exactly four words (target <id>, decoration, cache level, cache-control
// value) and builds AttrTy(cacheLevel, value). The new attribute is appended to
// any ArrayAttr already stored under `symbol` for the target, so repeated
// decorations accumulate.
// NOTE(review): EnumAttrTy appears unused in the visible lines, and words[3] is
// static_cast to EnumTy without validation; `cacheControlAttr` holds an enum
// value, not an attribute.
234template <
typename AttrTy,
typename EnumAttrTy,
typename EnumTy>
238 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
239 if (words.size() != 4) {
240 return emitError(loc,
"OpDecorate with ")
241 << decorationName <<
" needs a cache control integer literal and a "
242 << cacheControlKind <<
" cache control literal";
244 unsigned cacheLevel = words[2];
245 auto cacheControlAttr =
static_cast<EnumTy
>(words[3]);
246 auto value = opBuilder.
getAttr<AttrTy>(cacheLevel, cacheControlAttr);
249 dyn_cast_or_null<ArrayAttr>(decorations[words[0]].
get(symbol)))
250 llvm::append_range(attrs, attrList);
251 attrs.push_back(value);
252 decorations[words[0]].set(symbol, opBuilder.
getArrayAttr(attrs));
// Handles OpDecorate (and, per the error text, OpDecorateId). words[0] is the
// target <id> and words[1] is the decoration; any remaining words are
// decoration-specific literals. Decorations are handled as follows:
//  - Most decorations are stored in decorations[target] as attributes keyed by
//    the decoration's symbol name.
//  - ArrayStride is stored in typeDecorations.
//  - ID-valued decorations are queued in pendingIdDecorations and resolved
//    after the whole module has been read.
256LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
260 if (words.size() < 2) {
262 unknownLoc,
"OpDecorate must have at least result <id> and Decoration");
264 auto decorationName =
265 stringifyDecoration(
static_cast<spirv::Decoration
>(words[1]));
266 if (decorationName.empty()) {
267 return emitError(unknownLoc,
"invalid Decoration code : ") << words[1];
269 auto symbol = getSymbolDecoration(decorationName);
270 switch (
static_cast<spirv::Decoration
>(words[1])) {
271 case spirv::Decoration::FPFastMathMode:
272 if (words.size() != 3) {
273 return emitError(unknownLoc,
"OpDecorate with ")
274 << decorationName <<
" needs a single integer literal";
276 decorations[words[0]].set(
277 symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
278 static_cast<FPFastMathMode
>(words[2])));
280 case spirv::Decoration::FPRoundingMode:
281 if (words.size() != 3) {
282 return emitError(unknownLoc,
"OpDecorate with ")
283 << decorationName <<
" needs a single integer literal";
285 decorations[words[0]].set(
286 symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
287 static_cast<FPRoundingMode
>(words[2])));
// Single integer literal decorations, stored as i32 attributes.
289 case spirv::Decoration::DescriptorSet:
290 case spirv::Decoration::Binding:
291 case spirv::Decoration::Location:
292 case spirv::Decoration::SpecId:
293 case spirv::Decoration::Index:
294 case spirv::Decoration::Offset:
295 case spirv::Decoration::XfbBuffer:
296 case spirv::Decoration::XfbStride:
297 if (words.size() != 3) {
298 return emitError(unknownLoc,
"OpDecorate with ")
299 << decorationName <<
" needs a single integer literal";
301 decorations[words[0]].set(
302 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
304 case spirv::Decoration::BuiltIn:
305 if (words.size() != 3) {
306 return emitError(unknownLoc,
"OpDecorate with ")
307 << decorationName <<
" needs a single integer literal";
309 decorations[words[0]].set(
310 symbol, opBuilder.getStringAttr(
311 stringifyBuiltIn(
static_cast<spirv::BuiltIn
>(words[2]))));
// ArrayStride is kept per type <id> in typeDecorations rather than as an
// attribute; the array type builders look it up by result <id>.
313 case spirv::Decoration::ArrayStride:
314 if (words.size() != 3) {
315 return emitError(unknownLoc,
"OpDecorate with ")
316 << decorationName <<
" needs a single integer literal";
318 typeDecorations[words[0]] = words[2];
320 case spirv::Decoration::LinkageAttributes: {
321 if (words.size() < 4) {
322 return emitError(unknownLoc,
"OpDecorate with ")
324 <<
" needs at least 1 string and 1 integer literal";
332 unsigned wordIndex = 2;
334 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
335 static_cast<::mlir::spirv::LinkageType
>(words[wordIndex++]));
336 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
337 StringAttr::get(context, linkageName), linkageTypeAttr);
// NOTE(review): dyn_cast to the base Attribute class is a no-op upcast;
// passing linkageAttr directly would suffice.
338 decorations[words[0]].set(symbol, dyn_cast<Attribute>(linkageAttr));
// Decorations with no literal operands, stored as unit attributes.
341 case spirv::Decoration::Aliased:
342 case spirv::Decoration::AliasedPointer:
343 case spirv::Decoration::Block:
344 case spirv::Decoration::BufferBlock:
345 case spirv::Decoration::Flat:
346 case spirv::Decoration::NonReadable:
347 case spirv::Decoration::NonWritable:
348 case spirv::Decoration::NoPerspective:
349 case spirv::Decoration::NoSignedWrap:
350 case spirv::Decoration::NoUnsignedWrap:
351 case spirv::Decoration::RelaxedPrecision:
352 case spirv::Decoration::Restrict:
353 case spirv::Decoration::RestrictPointer:
354 case spirv::Decoration::NoContraction:
355 case spirv::Decoration::Constant:
356 case spirv::Decoration::Invariant:
357 case spirv::Decoration::Patch:
358 case spirv::Decoration::Coherent:
359 if (words.size() != 2) {
360 return emitError(unknownLoc,
"OpDecorate with ")
361 << decorationName <<
" needs a single target <id>";
363 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
365 case spirv::Decoration::CacheControlLoadINTEL: {
367 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
368 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
374 case spirv::Decoration::CacheControlStoreINTEL: {
376 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
377 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
// ID-valued decorations reference another <id> that may not be defined yet,
// so they are queued and resolved by resolveDeferredIdDecorations().
383 case spirv::Decoration::AlignmentId:
384 case spirv::Decoration::MaxByteOffsetId:
385 case spirv::Decoration::CounterBuffer:
386 if (words.size() != 3) {
387 return emitError(unknownLoc,
"OpDecorateId with ")
388 << decorationName <<
" needs a single <id> operand";
390 pendingIdDecorations.push_back({words[0],
391 static_cast<spirv::Decoration
>(words[1]),
392 words[2], unknownLoc});
// NOTE(review): this message opens a single quote before the decoration name
// that is never closed.
395 return emitError(unknownLoc,
"unhandled Decoration : '") << decorationName;
// Resolves the ID-valued decorations queued by processDecoration.
//  - The operand <id> must name a global variable or a specialization constant;
//    its symbol name is used for the attribute.
//  - The target <id> may be a global variable, spec constant, function, or SSA
//    value. For an SSA value, the attribute is set on the value's defining op.
// symRef is built in lines not shown, presumably from operandSymName.
400LogicalResult spirv::Deserializer::resolveDeferredIdDecorations() {
401 for (
const DeferredIdDecoration &entry : pendingIdDecorations) {
402 StringRef decorationName = stringifyDecoration(entry.decoration);
403 StringAttr symbol = getSymbolDecoration(decorationName);
407 StringRef operandSymName;
408 if (spirv::GlobalVariableOp varOp =
409 globalVariableMap.lookup(entry.operandID))
410 operandSymName = varOp.getSymName();
411 else if (spirv::SpecConstantOp specOp =
412 specConstMap.lookup(entry.operandID))
413 operandSymName = specOp.getSymName();
415 return emitError(entry.loc,
"OpDecorateId with ")
416 << decorationName <<
" references <id> " << entry.operandID
417 <<
" which is not a global variable or specialization constant";
// Locate the op that receives the decoration attribute.
// NOTE(review): for a value without a defining op (e.g. a block argument),
// getDefiningOp() returns null. Presumably this is what the "unknown target"
// error below reports; confirm.
424 Operation *targetOp =
nullptr;
425 if (spirv::GlobalVariableOp varOp =
426 globalVariableMap.lookup(entry.targetID))
428 else if (spirv::SpecConstantOp specOp = specConstMap.lookup(entry.targetID))
430 else if (spirv::FuncOp fnOp = funcMap.lookup(entry.targetID))
432 else if (Value v = valueMap.lookup(entry.targetID))
433 targetOp = v.getDefiningOp();
436 return emitError(entry.loc,
"OpDecorateId with ")
437 << decorationName <<
" references unknown target <id> "
440 targetOp->
setAttr(symbol, symRef);
// Handles OpMemberDecorate. words are (struct type <id>, member index,
// decoration, optional literals...). An Offset decoration must carry exactly
// one literal. The literal words are stored in
// memberDecorationMap[struct][member][decoration].
// NOTE(review): the stored ArrayRef is a view into `words`, so its validity
// depends on the caller's buffer outliving the map (presumably the input
// binary); confirm.
// NOTE(review): the Offset error message starts with a stray leading space.
446spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
448 if (words.size() < 3) {
450 "OpMemberDecorate must have at least 3 operands");
453 auto decoration =
static_cast<spirv::Decoration
>(words[2]);
454 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
456 " missing offset specification in OpMemberDecorate with "
457 "Offset decoration");
459 ArrayRef<uint32_t> decorationOperands;
460 if (words.size() > 3) {
461 decorationOperands = words.slice(3);
463 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
// Handles OpMemberName. words are (struct type <id>, member index, literal
// name). Trailing words are rejected, and the decoded name is stored in
// memberNameMap[struct][member].
467LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
468 if (words.size() < 3) {
469 return emitError(unknownLoc,
"OpMemberName must have at least 3 operands");
471 unsigned wordIndex = 2;
473 if (wordIndex != words.size()) {
475 "unexpected trailing words in OpMemberName instruction");
477 memberNameMap[words[0]][words[1]] = name;
// Builds the attribute dictionary for function argument `argIndex` from the
// decorations recorded for `argID` (signature not visible in this excerpt).
//  - An argument with no recorded decorations gets an empty dictionary.
//  - At most one of Aliased, Restrict, AliasedPointer or RestrictPointer is
//    accepted.
//  - RelaxedPrecision is accepted only if no other decoration was found.
//  - If none of these is found, the argument is rejected as unimplemented.
// The chosen decoration is stored under spirv::DecorationAttr::name.
483 if (!decorations.contains(argID)) {
484 argAttrs[argIndex] = DictionaryAttr::get(context, {});
488 spirv::DecorationAttr foundDecorationAttr;
490 for (
auto decoration :
491 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
492 spirv::Decoration::AliasedPointer,
493 spirv::Decoration::RestrictPointer}) {
495 if (decAttr.getName() !=
499 if (foundDecorationAttr)
501 "more than one Aliased/Restrict decorations for "
502 "function argument with result <id> ")
505 foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
510 spirv::Decoration::RelaxedPrecision))) {
515 if (foundDecorationAttr)
516 return emitError(unknownLoc,
"already found a decoration for function "
517 "argument with result <id> ")
520 foundDecorationAttr = spirv::DecorationAttr::get(
521 context, spirv::Decoration::RelaxedPrecision);
525 if (!foundDecorationAttr)
526 return emitError(unknownLoc,
"unimplemented decoration support for "
527 "function argument with result <id> ")
530 NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
531 foundDecorationAttr);
532 argAttrs[argIndex] = DictionaryAttr::get(context, attr);
// Handles OpFunction (the start of the definition is not visible). It proceeds
// as follows:
//  1. Rejects nesting and checks for 4 operands. operands[1] is the function
//     <id> and operands[2] is the function control.
//  2. Validates the function type against the result type.
//  3. Creates a spirv.func carrying the recorded decorations.
//  4. Binds each OpFunctionParameter to the matching entry-block argument and
//     attaches the argument attributes.
//  5. Deserializes the body blocks up to OpFunctionEnd.
// Functions with Import linkage take a separate path whose details are not
// visible here.
539 return emitError(unknownLoc,
"found function inside function");
543 if (operands.size() != 4) {
544 return emitError(unknownLoc,
"OpFunction must have 4 parameters");
548 return emitError(unknownLoc,
"undefined result type from <id> ")
552 uint32_t fnID = operands[1];
553 if (funcMap.count(fnID)) {
554 return emitError(unknownLoc,
"duplicate function definition/declaration");
557 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
559 return emitError(unknownLoc,
"unknown Function Control: ") << operands[2];
563 if (!fnType || !isa<FunctionType>(fnType)) {
564 return emitError(unknownLoc,
"unknown function type from <id> ")
567 auto functionType = cast<FunctionType>(fnType);
// The declared result type must agree with the function type.
// NOTE(review): this only catches (void result, non-empty results) and
// (one result != resultType). A non-void resultType paired with a function
// type that has zero results passes unnoticed.
569 if ((
isVoidType(resultType) && functionType.getNumResults() != 0) ||
570 (functionType.getNumResults() == 1 &&
571 functionType.getResult(0) != resultType)) {
572 return emitError(unknownLoc,
"mismatch in function type ")
573 << functionType <<
" and return type " << resultType <<
" specified";
577 auto funcOp = spirv::FuncOp::create(opBuilder, unknownLoc, fnName,
578 functionType, fnControl.value());
580 if (decorations.count(fnID)) {
581 for (
auto attr : decorations[fnID].getAttrs()) {
582 funcOp->setAttr(attr.getName(), attr.getValue());
585 curFunction = funcMap[fnID] = funcOp;
586 auto *entryBlock = funcOp.addEntryBlock();
589 <<
"//===-------------------------------------------===//\n";
590 logger.startLine() <<
"[fn] name: " << fnName <<
"\n";
591 logger.startLine() <<
"[fn] type: " << fnType <<
"\n";
592 logger.startLine() <<
"[fn] ID: " << fnID <<
"\n";
593 logger.startLine() <<
"[fn] entry block: " << entryBlock <<
"\n";
598 argAttrs.resize(functionType.getNumInputs());
// Each function input must be followed by an OpFunctionParameter whose type
// matches; its result <id> maps to the corresponding block argument.
601 if (functionType.getNumInputs()) {
602 for (
size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
603 auto argType = functionType.getInput(i);
604 spirv::Opcode opcode = spirv::Opcode::OpNop;
607 spirv::Opcode::OpFunctionParameter))) {
610 if (opcode != spirv::Opcode::OpFunctionParameter) {
613 "missing OpFunctionParameter instruction for argument ")
616 if (operands.size() != 2) {
619 "expected result type and result <id> for OpFunctionParameter");
621 auto argDefinedType =
getType(operands[0]);
622 if (!argDefinedType || argDefinedType != argType) {
624 "mismatch in argument type between function type "
626 << functionType <<
" and argument type definition "
627 << argDefinedType <<
" at argument " << i;
630 return emitError(unknownLoc,
"duplicate definition of result <id> ")
637 auto argValue = funcOp.getArgument(i);
638 valueMap[operands[1]] = argValue;
// Attach arg_attrs only when at least one argument has a non-empty dictionary.
642 if (llvm::any_of(argAttrs, [](
Attribute attr) {
643 auto argAttr = cast<DictionaryAttr>(attr);
644 return !argAttr.empty();
646 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
651 auto linkageAttr = funcOp.getLinkageAttributes();
652 auto hasImportLinkage =
653 linkageAttr && (linkageAttr.value().getLinkageType().
getValue() ==
654 spirv::LinkageType::Import);
655 if (hasImportLinkage)
// Body: the first block must begin with OpLabel and is mapped onto the entry
// block created above.
662 spirv::Opcode opcode = spirv::Opcode::OpNop;
671 spirv::Opcode::OpFunctionEnd))) {
674 if (opcode == spirv::Opcode::OpFunctionEnd) {
677 if (opcode != spirv::Opcode::OpLabel) {
678 return emitError(unknownLoc,
"a basic block must start with OpLabel");
680 if (instOperands.size() != 1) {
681 return emitError(unknownLoc,
"OpLabel should only have result <id>");
683 blockMap[instOperands[0]] = entryBlock;
691 spirv::Opcode::OpFunctionEnd)) &&
692 opcode != spirv::Opcode::OpFunctionEnd) {
697 if (opcode != spirv::Opcode::OpFunctionEnd) {
// Handles OpFunctionEnd. It requires no operands and clears curFunction once
// the function has been finished (intermediate lines not visible).
707 if (!operands.empty()) {
708 return emitError(unknownLoc,
"unexpected operands for OpFunctionEnd");
719 curFunction = std::nullopt;
724 <<
"//===-------------------------------------------===//\n";
// Handles OpGraphEntryPointARM. operands[0] is the graph <id>, which must
// already be defined, followed by the entry point name and interface <id>s.
// It renames the graph, marks it as an entry point, turns each interface <id>
// into a SymbolRefAttr, and creates a GraphEntryPointARMOp right before the
// graph.
// NOTE(review): the first message misspells "definition" ("defintion"). The
// "undefined result <id>" message says "OpGraphEntryPoint" without the ARM
// suffix used elsewhere.
731 if (operands.size() < 2) {
733 "missing graph defintion in OpGraphEntryPointARM");
736 unsigned wordIndex = 0;
737 uint32_t graphID = operands[wordIndex++];
738 if (!graphMap.contains(graphID)) {
740 "missing graph definition/declaration with id ")
744 spirv::GraphARMOp graphARM = graphMap[graphID];
746 graphARM.setSymName(name);
747 graphARM.setEntryPoint(
true);
750 for (
int64_t size = operands.size(); wordIndex < size; ++wordIndex) {
752 interface.push_back(SymbolRefAttr::get(arg.getOperation()));
754 return emitError(unknownLoc,
"undefined result <id> ")
755 << operands[wordIndex] <<
" while decoding OpGraphEntryPoint";
// The entry point op is inserted immediately before the graph it names.
761 opBuilder.setInsertionPoint(graphARM);
762 spirv::GraphEntryPointARMOp::create(
763 opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
764 opBuilder.getArrayAttr(interface));
// Handles OpGraphARM. It rejects nesting and requires at least 2 operands;
// operands[1] is the graph <id>. The graph type must be a GraphType with at
// least one result. The function then:
//  1. Creates a GraphARMOp and its entry block.
//  2. Binds each OpGraphInputARM to the block argument selected by its
//     constant input-index operand.
//  3. Sizes graphOutputs to the number of results.
//  4. Processes instructions until OpGraphEndARM.
772 return emitError(unknownLoc,
"found graph inside graph");
775 if (operands.size() < 2) {
776 return emitError(unknownLoc,
"OpGraphARM must have at least 2 parameters");
780 if (!type || !isa<GraphType>(type)) {
781 return emitError(unknownLoc,
"unknown graph type from <id> ")
784 auto graphType = cast<GraphType>(type);
// NOTE(review): if getNumResults() returns an unsigned count, `<= 0` is just
// `== 0`.
785 if (graphType.getNumResults() <= 0) {
786 return emitError(unknownLoc,
"expected at least one result");
789 uint32_t graphID = operands[1];
790 if (graphMap.count(graphID)) {
791 return emitError(unknownLoc,
"duplicate graph definition/declaration");
796 spirv::GraphARMOp::create(opBuilder, unknownLoc, graphName, graphType);
797 curGraph = graphMap[graphID] = graphOp;
798 Block *entryBlock = graphOp.addEntryBlock();
801 <<
"//===-------------------------------------------===//\n";
802 logger.startLine() <<
"[graph] name: " << graphName <<
"\n";
803 logger.startLine() <<
"[graph] type: " << graphType <<
"\n";
804 logger.startLine() <<
"[graph] ID: " << graphID <<
"\n";
805 logger.startLine() <<
"[graph] entry block: " << entryBlock <<
"\n";
810 for (
auto [
index, argType] : llvm::enumerate(graphType.getInputs())) {
// NOTE(review): `opcode` is left uninitialized here, whereas processFunction
// initializes its opcode locals to OpNop.
811 spirv::Opcode opcode;
814 spirv::Opcode::OpGraphInputARM))) {
817 if (operands.size() != 3) {
818 return emitError(unknownLoc,
"expected result type, result <id> and "
819 "input index for OpGraphInputARM");
823 if (!argDefinedType) {
824 return emitError(unknownLoc,
"unknown operand type <id> ") << operands[0];
827 if (argDefinedType != argType) {
829 "mismatch in argument type between graph type "
831 << graphType <<
" and argument type definition " << argDefinedType
832 <<
" at argument " <<
index;
835 return emitError(unknownLoc,
"duplicate definition of result <id> ")
840 if (!inputIndexAttr) {
842 "unable to read inputIndex value from constant op ")
// NOTE(review): the type check above compares against the input at loop
// position `index`, but the value is bound to the argument at the decoded
// input index. If OpGraphInputARM instructions arrive out of order, these
// disagree. The visible lines also show no bounds check on the decoded index.
845 BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
846 valueMap[operands[1]] = argValue;
849 graphOutputs.resize(graphType.getNumResults());
855 blockMap[graphID] = entryBlock;
// NOTE(review): uninitialized like the one above; it is read by the do-while
// condition below.
862 spirv::Opcode opcode;
872 }
while (opcode != spirv::Opcode::OpGraphEndARM);
// Handles OpGraphSetOutputARM. operands are (value <id>, output-index constant
// <id>). The value is stored in graphOutputs at the decoded index.
// NOTE(review): the visible lines do not bounds-check the decoded index
// against graphOutputs.size(), which processGraph sets to the number of graph
// results.
879 if (operands.size() != 2) {
882 "expected value id and output index for OpGraphSetOutputARM");
885 uint32_t
id = operands[0];
888 return emitError(unknownLoc,
"could not find result <id> ") << id;
892 if (!outputIndexAttr) {
894 "unable to read outputIndex value from constant op ")
897 graphOutputs[outputIndexAttr.getInt()] = value;
// Handles OpGraphEndARM. It first creates a GraphOutputsARMOp from the
// collected graphOutputs, then requires that the instruction has no operands,
// and finally resets the per-graph state (curGraph, graphOutputs).
904 spirv::GraphOutputsARMOp::create(opBuilder, unknownLoc, graphOutputs);
907 if (!operands.empty()) {
908 return emitError(unknownLoc,
"unexpected operands for OpGraphEndARM");
912 curGraph = std::nullopt;
913 graphOutputs.clear();
918 <<
"//===-------------------------------------------===//\n";
// Looks up the (attribute, type) pair recorded in constantMap for `id`.
// The miss branch is not visible; presumably it returns std::nullopt.
923std::optional<std::pair<Attribute, Type>>
925 auto constIt = constantMap.find(
id);
926 if (constIt == constantMap.end())
928 return constIt->getSecond();
// Looks up `id` in constantCompositeReplicateMap; the returned value on a hit
// or miss is not visible in this excerpt.
931std::optional<std::pair<Attribute, Type>>
933 if (
auto it = constantCompositeReplicateMap.find(
id);
934 it != constantCompositeReplicateMap.end())
// Looks up the materialization info recorded in specConstOperationMap for
// `id`. The miss branch is not visible; presumably it returns std::nullopt.
939std::optional<spirv::SpecConstOperationMaterializationInfo>
941 auto constIt = specConstOperationMap.find(
id);
942 if (constIt == specConstOperationMap.end())
944 return constIt->getSecond();
// Uses the OpName recorded for `id`; when there is none, a name of the form
// "spirv_fn_<id>" is synthesized.
948 auto funcName = nameMap.lookup(
id).str();
949 if (funcName.empty()) {
950 funcName =
"spirv_fn_" + std::to_string(
id);
// Uses the OpName recorded for `id`; when there is none, a name of the form
// "spirv_graph_<id>" is synthesized.
956 std::string graphName = nameMap.lookup(
id).str();
957 if (graphName.empty()) {
958 graphName =
"spirv_graph_" + std::to_string(
id);
// Uses the OpName recorded for `id`; when there is none, a name of the form
// "spirv_spec_const_<id>" is synthesized.
964 auto constName = nameMap.lookup(
id).str();
965 if (constName.empty()) {
966 constName =
"spirv_spec_const_" + std::to_string(
id);
// Creates a spirv.SpecConstant op named symName (remaining arguments not
// visible), copies any decorations recorded for `resultID` onto it, and
// registers it in specConstMap.
973 TypedAttr defaultValue) {
975 auto op = spirv::SpecConstantOp::create(opBuilder, unknownLoc, symName,
977 if (decorations.count(resultID)) {
978 for (
auto attr : decorations[resultID].getAttrs())
979 op->setAttr(attr.getName(), attr.getValue());
981 specConstMap[resultID] = op;
// Looks up the materialization info recorded in graphConstantMap for `id`.
// The miss branch is not visible; presumably it returns std::nullopt.
985std::optional<spirv::GraphConstantARMOpMaterializationInfo>
987 auto graphConstIt = graphConstantMap.find(
id);
988 if (graphConstIt == graphConstantMap.end())
990 return graphConstIt->getSecond();
// Handles a module-scope OpVariable. operands are (pointer type <id>, result
// <id>, storage class, optional initializer <id>). The function:
//  1. Requires the type to be a spirv.ptr whose storage class matches the
//     instruction.
//  2. Resolves the optional initializer to a SymbolRefAttr.
//  3. Rejects extra operands.
//  4. Creates a spirv.GlobalVariable named from OpName (or "spirv_var_<id>").
//  5. Copies the recorded decorations onto it and registers it in
//     globalVariableMap.
// NOTE(review): the initializer error lacks a space before "used", so the
// message reads "unknown <id> 5used as initializer".
995 unsigned wordIndex = 0;
996 if (operands.size() < 3) {
999 "OpVariable needs at least 3 operands, type, <id> and storage class");
1003 auto type =
getType(operands[wordIndex]);
1005 return emitError(unknownLoc,
"unknown result type <id> : ")
1006 << operands[wordIndex];
1008 auto ptrType = dyn_cast<spirv::PointerType>(type);
1011 "expected a result type <id> to be a spirv.ptr, found : ")
1017 auto variableID = operands[wordIndex];
1018 auto variableName = nameMap.lookup(variableID).str();
1019 if (variableName.empty()) {
1020 variableName =
"spirv_var_" + std::to_string(variableID);
1025 auto storageClass =
static_cast<spirv::StorageClass
>(operands[wordIndex]);
1026 if (ptrType.getStorageClass() != storageClass) {
1027 return emitError(unknownLoc,
"mismatch in storage class of pointer type ")
1028 << type <<
" and that specified in OpVariable instruction : "
1029 << stringifyStorageClass(storageClass);
// An optional trailing operand names the initializer.
1036 if (wordIndex < operands.size()) {
1046 return emitError(unknownLoc,
"unknown <id> ")
1047 << operands[wordIndex] <<
"used as initializer";
1049 initializer = SymbolRefAttr::get(op);
1052 if (wordIndex != operands.size()) {
1054 "found more operands than expected when deserializing "
1055 "OpVariable instruction, only ")
1056 << wordIndex <<
" of " << operands.size() <<
" processed";
1059 auto varOp = spirv::GlobalVariableOp::create(
1060 opBuilder, loc, TypeAttr::get(type),
1061 opBuilder.getStringAttr(variableName), initializer);
1064 if (decorations.count(variableID)) {
1065 for (
auto attr : decorations[variableID].getAttrs())
1066 varOp->setAttr(attr.getName(), attr.getValue());
1068 globalVariableMap[variableID] = varOp;
// Returns the recorded constant's attribute as an IntegerAttr, or a null
// attribute when it is not an integer attribute (dyn_cast semantics).
1077 return dyn_cast<IntegerAttr>(constInfo->first);
// Handles OpName. operands are (target <id>, literal name). Trailing words are
// rejected. The name is stored in nameMap and overwrites any earlier name for
// the same <id>.
1081 if (operands.size() < 2) {
1082 return emitError(unknownLoc,
"OpName needs at least 2 operands");
1085 unsigned wordIndex = 1;
1087 if (wordIndex != operands.size()) {
1089 "unexpected trailing words in OpName instruction");
1094 nameMap.emplace_or_assign(operands[0], name);
// Dispatches a type-declaring instruction by opcode. operands[0] is the result
// <id>, which must not already be in typeMap. Void, bool, int, float and
// vector types are built inline; every other recognized type opcode is handed
// to a dedicated helper. Unrecognized opcodes are reported as unhandled.
1105 if (operands.empty()) {
1106 return emitError(unknownLoc,
"type instruction with opcode ")
1107 << spirv::stringifyOpcode(opcode) <<
" needs at least one <id>";
1112 if (typeMap.count(operands[0])) {
1113 return emitError(unknownLoc,
"duplicate definition for result <id> ")
// OpTypeVoid is modeled as MLIR's NoneType.
1118 case spirv::Opcode::OpTypeVoid:
1119 if (operands.size() != 1)
1120 return emitError(unknownLoc,
"OpTypeVoid must have no parameters");
1121 typeMap[operands[0]] = opBuilder.getNoneType();
1123 case spirv::Opcode::OpTypeBool:
1124 if (operands.size() != 1)
1125 return emitError(unknownLoc,
"OpTypeBool must have no parameters");
1126 typeMap[operands[0]] = opBuilder.getI1Type();
1128 case spirv::Opcode::OpTypeInt: {
1129 if (operands.size() != 3)
1131 unknownLoc,
"OpTypeInt must have bitwidth and signedness parameters");
// A signedness literal of 1 yields a signed integer. Any other value yields a
// signless one, so SPIR-V "unsigned / no signedness" ints become signless.
1140 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
1141 : IntegerType::SignednessSemantics::Signless;
1142 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
// OpTypeFloat comes in two forms:
//  - 2 operands (type, width): maps to f16/f32/f64.
//  - 3 operands (type, width, encoding): supports BFloat16KHR, Float8E4M3EXT
//    and Float8E5M2EXT.
1144 case spirv::Opcode::OpTypeFloat: {
1145 if (operands.size() != 2 && operands.size() != 3)
1147 "OpTypeFloat expects either 2 operands (type, bitwidth) "
1148 "or 3 operands (type, bitwidth, encoding), but got ")
1150 uint32_t bitWidth = operands[1];
1153 if (operands.size() == 2) {
1156 floatTy = opBuilder.getF16Type();
1159 floatTy = opBuilder.getF32Type();
1162 floatTy = opBuilder.getF64Type();
1165 return emitError(unknownLoc,
"unsupported OpTypeFloat bitwidth: ")
1170 if (operands.size() == 3) {
1171 if (spirv::FPEncoding(operands[2]) == spirv::FPEncoding::BFloat16KHR &&
1173 floatTy = opBuilder.getBF16Type();
1174 else if (spirv::FPEncoding(operands[2]) ==
1175 spirv::FPEncoding::Float8E4M3EXT &&
1177 floatTy = opBuilder.getF8E4M3FNType();
1178 else if (spirv::FPEncoding(operands[2]) ==
1179 spirv::FPEncoding::Float8E5M2EXT &&
1181 floatTy = opBuilder.getF8E5M2Type();
1183 return emitError(unknownLoc,
"unsupported OpTypeFloat FP encoding: ")
1184 << operands[2] <<
" and bitWidth " << bitWidth;
1187 typeMap[operands[0]] = floatTy;
// OpTypeVector: operands are (result <id>, element type <id>, component count).
1189 case spirv::Opcode::OpTypeVector: {
1190 if (operands.size() != 3) {
1193 "OpTypeVector must have element type and count parameters");
1197 return emitError(unknownLoc,
"OpTypeVector references undefined <id> ")
1200 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
1202 case spirv::Opcode::OpTypePointer: {
// The remaining type opcodes go to dedicated process*Type helpers; the calls
// are not visible in this excerpt.
1205 case spirv::Opcode::OpTypeArray:
1207 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
1209 case spirv::Opcode::OpTypeFunction:
1211 case spirv::Opcode::OpTypeImage:
1213 case spirv::Opcode::OpTypeSampler:
1215 case spirv::Opcode::OpTypeNamedBarrier:
1217 case spirv::Opcode::OpTypeSampledImage:
1219 case spirv::Opcode::OpTypeRuntimeArray:
1221 case spirv::Opcode::OpTypeStruct:
1223 case spirv::Opcode::OpTypeMatrix:
1225 case spirv::Opcode::OpTypeTensorARM:
1227 case spirv::Opcode::OpTypeGraphARM:
1230 return emitError(unknownLoc,
"unhandled type instruction");
// Handles OpTypePointer. operands are (result <id>, storage class, pointee type
// <id>). After the pointer type is registered (lines not visible):
//  - Any deferred struct member waiting on this forward-declared pointer <id>
//    is patched in.
//  - Each deferred struct whose members are now all resolved gets its body set
//    and is removed from deferredStructTypesInfos.
// NOTE(review): the storage class is static_cast without validation, and the
// assert message below misspells "identified" ("indentified").
1237 if (operands.size() != 3)
1238 return emitError(unknownLoc,
"OpTypePointer must have two parameters");
1240 auto pointeeType =
getType(operands[2]);
1242 return emitError(unknownLoc,
"unknown OpTypePointer pointee type <id> ")
1245 uint32_t typePointerID = operands[0];
1246 auto storageClass =
static_cast<spirv::StorageClass
>(operands[1]);
// Patch forward references to this pointer in structs that are still pending.
1249 for (
auto *deferredStructIt = std::begin(deferredStructTypesInfos);
1250 deferredStructIt != std::end(deferredStructTypesInfos);) {
1251 for (
auto *unresolvedMemberIt =
1252 std::begin(deferredStructIt->unresolvedMemberTypes);
1253 unresolvedMemberIt !=
1254 std::end(deferredStructIt->unresolvedMemberTypes);) {
// erase() returns the next element, so the iterator is advanced explicitly
// only when nothing was erased.
1255 if (unresolvedMemberIt->first == typePointerID) {
1259 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
1260 typeMap[typePointerID];
1261 unresolvedMemberIt =
1262 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
1264 ++unresolvedMemberIt;
1268 if (deferredStructIt->unresolvedMemberTypes.empty()) {
1270 auto structType = deferredStructIt->deferredStructType;
1272 assert(structType &&
"expected a spirv::StructType");
1273 assert(structType.isIdentified() &&
"expected an indentified struct");
1275 if (failed(structType.trySetBody(
1276 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
1277 deferredStructIt->memberDecorationsInfo,
1278 deferredStructIt->structDecorationsInfo)))
1281 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
// Handles OpTypeArray. operands are (result <id>, element type <id>, length
// <id>). The length must come from a recorded (non-spec) scalar integer
// constant and is zero-extended. Any ArrayStride recorded in typeDecorations
// for the result <id> is applied.
// NOTE(review): the "normal constant" message lacks a space after the <id>
// ("<id> 7can only come ...").
1292 if (operands.size() != 3) {
1294 "OpTypeArray must have element type and count parameters");
1299 return emitError(unknownLoc,
"OpTypeArray references undefined <id> ")
1307 return emitError(unknownLoc,
"OpTypeArray count <id> ")
1308 << operands[2] <<
"can only come from normal constant right now";
1311 if (
auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1312 count = intVal.getValue().getZExtValue();
1314 return emitError(unknownLoc,
"OpTypeArray count must come from a "
1315 "scalar integer constant instruction");
1319 elementTy, count, typeDecorations.lookup(operands[0]));
// Handles OpTypeFunction. operands are (result <id>, return type <id>,
// parameter type <id>s...). Unknown return or parameter types are rejected.
// returnTypes is filled in lines not shown (presumably empty for a void
// return).
1325 assert(!operands.empty() &&
"No operands for processing function type");
1326 if (operands.size() == 1) {
1327 return emitError(unknownLoc,
"missing return type for OpTypeFunction");
1329 auto returnType =
getType(operands[1]);
1331 return emitError(unknownLoc,
"unknown return type in OpTypeFunction");
1334 for (
size_t i = 2, e = operands.size(); i < e; ++i) {
1335 auto ty =
getType(operands[i]);
1337 return emitError(unknownLoc,
"unknown argument type in OpTypeFunction");
1339 argTypes.push_back(ty);
1345 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
// Handles OpTypeCooperativeMatrixKHR. It expects 6 operands: result <id>,
// element type, scope, rows, columns and use. Rows, columns and use are read
// from recorded constants via getInt(). The use value must map to a valid
// CooperativeMatrixUseKHR; presumably scope is symbolized the same way.
// NOTE(review): getInt() results are stored in `unsigned` without a visible
// range check, so negative constants would wrap.
1351 if (operands.size() != 6) {
1353 "OpTypeCooperativeMatrixKHR must have element type, "
1354 "scope, row and column parameters, and use");
1360 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1364 std::optional<spirv::Scope> scope =
1369 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1378 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Rows` references "
1379 "undefined constant <id> ")
1383 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Columns` "
1384 "references undefined constant <id> ")
1388 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Use` references "
1389 "undefined constant <id> ")
1392 unsigned rows = rowsAttr.getInt();
1393 unsigned columns = columnsAttr.getInt();
1395 std::optional<spirv::CooperativeMatrixUseKHR> use =
1396 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1400 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1404 typeMap[operands[0]] =
// Handles OpTypeRuntimeArray. operands are (result <id>, element type <id>).
// Any ArrayStride recorded in typeDecorations for the result <id> is applied.
1411 if (operands.size() != 2) {
1412 return emitError(unknownLoc,
"OpTypeRuntimeArray must have two operands");
1417 "OpTypeRuntimeArray references undefined <id> ")
1421 memberType, typeDecorations.lookup(operands[0]));
// Handles OpTypeStruct. A struct with only a result <id> is created directly.
// For other structs:
//  - Each member type is looked up. A member that names a forward-declared
//    pointer (typeForwardPointerIDs) gets a placeholder and is recorded in
//    unresolvedMemberTypes; processOpTypePointer patches it in later.
//  - Member decorations are collected: Offset goes into offsetInfo, others into
//    memberDecorationsInfo. Struct-level decorations go into
//    structDecorationsInfo.
//  - Structs without an OpName take the path that must have no unresolved
//    members.
//  - Named structs have their body set immediately, or are deferred until
//    their forward pointers are resolved.
// NOTE(review): both asserts below guard conditions derived from the input
// binary (an unknown decoration name; an unnamed struct referencing a forward
// pointer). In release builds these checks vanish, so a malformed module is
// not diagnosed.
1429 if (operands.empty()) {
1430 return emitError(unknownLoc,
"OpTypeStruct must have at least result <id>");
1433 if (operands.size() == 1) {
1435 typeMap[operands[0]] =
1444 for (
auto op : llvm::drop_begin(operands, 1)) {
1446 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1448 if (!memberType && !typeForwardPtr)
1449 return emitError(unknownLoc,
"OpTypeStruct references undefined <id> ")
1453 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1455 memberTypes.push_back(memberType);
1460 if (memberDecorationMap.count(operands[0])) {
1461 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1462 for (
auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1463 if (allMemberDecorations.count(memberIndex)) {
1464 for (
auto &memberDecoration : allMemberDecorations[memberIndex]) {
// processMemberDecoration guarantees exactly one literal for Offset, so
// second[0] is valid here.
1466 if (memberDecoration.first == spirv::Decoration::Offset) {
1468 if (offsetInfo.empty()) {
1469 offsetInfo.resize(memberTypes.size());
1471 offsetInfo[memberIndex] = memberDecoration.second[0];
1473 auto intType = mlir::IntegerType::get(context, 32);
1474 if (!memberDecoration.second.empty()) {
1475 memberDecorationsInfo.emplace_back(
1476 memberIndex, memberDecoration.first,
1477 IntegerAttr::get(intType, memberDecoration.second[0]));
1479 memberDecorationsInfo.emplace_back(
1480 memberIndex, memberDecoration.first, UnitAttr::get(context));
1489 if (decorations.count(operands[0])) {
1492 std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
1493 llvm::convertToCamelFromSnakeCase(decorationAttr.getName(),
true));
1494 assert(decoration.has_value());
1495 structDecorationsInfo.emplace_back(decoration.value(),
1496 decorationAttr.getValue());
1500 uint32_t structID = operands[0];
1501 std::string structIdentifier = nameMap.lookup(structID).str();
1503 if (structIdentifier.empty()) {
1504 assert(unresolvedMemberTypes.empty() &&
1505 "didn't expect unresolved member types");
1507 memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
1510 typeMap[structID] = structTy;
1512 if (!unresolvedMemberTypes.empty())
1513 deferredStructTypesInfos.push_back(
1514 {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
1515 memberDecorationsInfo, structDecorationsInfo});
1516 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1517 memberDecorationsInfo,
1518 structDecorationsInfo)))
1529 if (operands.size() != 3) {
1531 return emitError(unknownLoc,
"OpTypeMatrix must have 3 operands"
1532 " (result_id, column_type, and column_count)");
1538 "OpTypeMatrix references undefined column type.")
1542 uint32_t colsCount = operands[2];
1549 unsigned size = operands.size();
1550 if (size < 2 || size > 4)
1551 return emitError(unknownLoc,
"OpTypeTensorARM must have 2-4 operands "
1552 "(result_id, element_type, (rank), (shape)) ")
1558 "OpTypeTensorARM references undefined element type ")
1568 return emitError(unknownLoc,
"OpTypeTensorARM rank must come from a "
1569 "scalar integer constant instruction");
1570 unsigned rank = rankAttr.getValue().getZExtValue();
1577 std::optional<std::pair<Attribute, Type>> shapeInfo =
1580 return emitError(unknownLoc,
"OpTypeTensorARM shape must come from a "
1581 "constant instruction of type OpTypeArray");
1583 ArrayAttr shapeArrayAttr = dyn_cast<ArrayAttr>(shapeInfo->first);
1585 for (
auto dimAttr : shapeArrayAttr.getValue()) {
1586 auto dimIntAttr = dyn_cast<IntegerAttr>(dimAttr);
1588 return emitError(unknownLoc,
"OpTypeTensorARM shape has an invalid "
1590 shape.push_back(dimIntAttr.getValue().getSExtValue());
1598 unsigned size = operands.size();
1600 return emitError(unknownLoc,
"OpTypeGraphARM must have at least 2 operands "
1601 "(result_id, num_inputs, (inout0_type, "
1602 "inout1_type, ...))")
1605 uint32_t numInputs = operands[1];
1608 for (
unsigned i = 2; i < size; ++i) {
1612 "OpTypeGraphARM references undefined element type.")
1615 if (i - 2 >= numInputs) {
1616 returnTypes.push_back(inOutTy);
1618 argTypes.push_back(inOutTy);
1621 typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
1627 if (operands.size() != 2)
1629 "OpTypeForwardPointer instruction must have two operands");
1631 typeForwardPointerIDs.insert(operands[0]);
1641 if (operands.size() != 8)
1644 "OpTypeImage with non-eight operands are not supported yet");
1648 return emitError(unknownLoc,
"OpTypeImage references undefined <id>: ")
1651 auto dim = spirv::symbolizeDim(operands[2]);
1653 return emitError(unknownLoc,
"unknown Dim for OpTypeImage: ")
1656 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1658 return emitError(unknownLoc,
"unknown Depth for OpTypeImage: ")
1661 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1663 return emitError(unknownLoc,
"unknown Arrayed for OpTypeImage: ")
1666 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1668 return emitError(unknownLoc,
"unknown MS for OpTypeImage: ") << operands[5];
1670 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1671 if (!samplerUseInfo)
1672 return emitError(unknownLoc,
"unknown Sampled for OpTypeImage: ")
1675 auto format = spirv::symbolizeImageFormat(operands[7]);
1677 return emitError(unknownLoc,
"unknown Format for OpTypeImage: ")
1681 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1682 samplingInfo.value(), samplerUseInfo.value(), format.value());
1688 if (operands.size() != 2)
1689 return emitError(unknownLoc,
"OpTypeSampledImage must have two operands");
1694 "OpTypeSampledImage references undefined <id>: ")
1703 if (operands.size() != 1)
1704 return emitError(unknownLoc,
"OpTypeSampler must have no parameters");
1712 if (operands.size() != 1)
1713 return emitError(unknownLoc,
"OpTypeNamedBarrier must have no parameters");
1725 StringRef opname = isSpec ?
"OpSpecConstant" :
"OpConstant";
1727 if (operands.size() < 2) {
1729 << opname <<
" must have type <id> and result <id>";
1731 if (operands.size() < 3) {
1733 << opname <<
" must have at least 1 more parameter";
1738 return emitError(unknownLoc,
"undefined result type from <id> ")
1742 auto checkOperandSizeForBitwidth = [&](
unsigned bitwidth) -> LogicalResult {
1743 if (bitwidth == 64) {
1744 if (operands.size() == 4) {
1748 << opname <<
" should have 2 parameters for 64-bit values";
1750 if (bitwidth <= 32) {
1751 if (operands.size() == 3) {
1757 <<
" should have 1 parameter for values with no more than 32 bits";
1759 return emitError(unknownLoc,
"unsupported OpConstant bitwidth: ")
1763 auto resultID = operands[1];
1765 if (
auto intType = dyn_cast<IntegerType>(resultType)) {
1766 auto bitwidth = intType.getWidth();
1767 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1772 if (bitwidth == 64) {
1779 } words = {operands[2], operands[3]};
1780 value = APInt(64, llvm::bit_cast<uint64_t>(words),
true);
1781 }
else if (bitwidth <= 32) {
1782 value = APInt(bitwidth, operands[2],
true,
1786 auto attr = opBuilder.getIntegerAttr(intType, value);
1793 constantMap.try_emplace(resultID, attr, intType);
1799 if (
auto floatType = dyn_cast<FloatType>(resultType)) {
1800 auto bitwidth = floatType.getWidth();
1801 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1806 if (floatType.isF64()) {
1813 } words = {operands[2], operands[3]};
1814 value = APFloat(llvm::bit_cast<double>(words));
1815 }
else if (floatType.isF32()) {
1816 value = APFloat(llvm::bit_cast<float>(operands[2]));
1817 }
else if (floatType.isF16()) {
1818 APInt data(16, operands[2]);
1819 value = APFloat(APFloat::IEEEhalf(), data);
1820 }
else if (floatType.isBF16()) {
1821 APInt data(16, operands[2]);
1822 value = APFloat(APFloat::BFloat(), data);
1823 }
else if (floatType.isF8E4M3FN()) {
1824 APInt data(8, operands[2]);
1825 value = APFloat(APFloat::Float8E4M3FN(), data);
1826 }
else if (floatType.isF8E5M2()) {
1827 APInt data(8, operands[2]);
1828 value = APFloat(APFloat::Float8E5M2(), data);
1831 auto attr = opBuilder.getFloatAttr(floatType, value);
1837 constantMap.try_emplace(resultID, attr, floatType);
1843 return emitError(unknownLoc,
"OpConstant can only generate values of "
1844 "scalar integer or floating-point type");
1849 if (operands.size() != 2) {
1851 << (isSpec ?
"Spec" :
"") <<
"Constant"
1852 << (isTrue ?
"True" :
"False")
1853 <<
" must have type <id> and result <id>";
1856 auto attr = opBuilder.getBoolAttr(isTrue);
1857 auto resultID = operands[1];
1863 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1871 if (operands.size() < 2) {
1873 "OpConstantComposite must have type <id> and result <id>");
1875 if (operands.size() < 3) {
1877 "OpConstantComposite must have at least 1 parameter");
1882 return emitError(unknownLoc,
"undefined result type from <id> ")
1887 elements.reserve(operands.size() - 2);
1888 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1891 return emitError(unknownLoc,
"OpConstantComposite component <id> ")
1892 << operands[i] <<
" must come from a normal constant";
1894 elements.push_back(elementInfo->first);
1897 auto resultID = operands[1];
1898 if (
auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1901 if (
auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
1902 for (
auto value : denseElemAttr.getValues<
Attribute>())
1903 flattenedElems.push_back(value);
1905 flattenedElems.push_back(element);
1909 constantMap.try_emplace(resultID, attr, tensorType);
1910 }
else if (
auto shapedType = dyn_cast<ShapedType>(resultType)) {
1914 constantMap.try_emplace(resultID, attr, shapedType);
1915 }
else if (
auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1916 auto attr = opBuilder.getArrayAttr(elements);
1917 constantMap.try_emplace(resultID, attr, resultType);
1919 return emitError(unknownLoc,
"unsupported OpConstantComposite type: ")
1928 if (operands.size() != 3) {
1931 "OpConstantCompositeReplicateEXT expects 3 operands but found ")
1937 return emitError(unknownLoc,
"undefined result type from <id> ")
1941 auto compositeType = dyn_cast<CompositeType>(resultType);
1942 if (!compositeType) {
1944 "result type from <id> is not a composite type")
1948 uint32_t resultID = operands[1];
1949 uint32_t constantID = operands[2];
1951 std::optional<std::pair<Attribute, Type>> constantInfo =
1953 if (constantInfo.has_value()) {
1954 constantCompositeReplicateMap.try_emplace(
1955 resultID, constantInfo.value().first, resultType);
1959 std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
1961 if (replicatedConstantCompositeInfo.has_value()) {
1962 constantCompositeReplicateMap.try_emplace(
1963 resultID, replicatedConstantCompositeInfo.value().first, resultType);
1967 return emitError(unknownLoc,
"OpConstantCompositeReplicateEXT operand <id> ")
1969 <<
" must come from a normal constant or a "
1970 "OpConstantCompositeReplicateEXT";
1975 if (operands.size() < 2) {
1978 "OpSpecConstantComposite must have type <id> and result <id>");
1980 if (operands.size() < 3) {
1982 "OpSpecConstantComposite must have at least 1 parameter");
1987 return emitError(unknownLoc,
"undefined result type from <id> ")
1991 auto resultID = operands[1];
1995 elements.reserve(operands.size() - 2);
1996 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1998 elements.push_back(SymbolRefAttr::get(elementInfo));
2001 auto op = spirv::SpecConstantCompositeOp::create(
2002 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
2003 opBuilder.getArrayAttr(elements));
2004 specConstCompositeMap[resultID] = op;
2011 if (operands.size() != 3) {
2012 return emitError(unknownLoc,
"OpSpecConstantCompositeReplicateEXT expects "
2013 "3 operands but found ")
2019 return emitError(unknownLoc,
"undefined result type from <id> ")
2023 auto compositeType = dyn_cast<CompositeType>(resultType);
2024 if (!compositeType) {
2026 "result type from <id> is not a composite type")
2030 uint32_t resultID = operands[1];
2033 spirv::SpecConstantOp constituentSpecConstantOp =
2035 auto op = spirv::EXTSpecConstantCompositeReplicateOp::create(
2036 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
2037 SymbolRefAttr::get(constituentSpecConstantOp));
2039 specConstCompositeReplicateMap[resultID] = op;
2046 if (operands.size() < 3)
2047 return emitError(unknownLoc,
"OpConstantOperation must have type <id>, "
2048 "result <id>, and operand opcode");
2050 uint32_t resultTypeID = operands[0];
2053 return emitError(unknownLoc,
"undefined result type from <id> ")
2056 uint32_t resultID = operands[1];
2057 spirv::Opcode enclosedOpcode =
static_cast<spirv::Opcode
>(operands[2]);
2058 auto emplaceResult = specConstOperationMap.try_emplace(
2061 enclosedOpcode, resultTypeID,
2064 if (!emplaceResult.second)
2065 return emitError(unknownLoc,
"value with <id>: ")
2066 << resultID <<
" is probably defined before.";
2072 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
2088 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
2089 constexpr uint32_t fakeID =
static_cast<uint32_t
>(-3);
2092 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
2093 enclosedOpResultTypeAndOperands.push_back(fakeID);
2094 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
2095 enclosedOpOperands.end());
2110 auto specConstOperationOp =
2111 spirv::SpecConstantOperationOp::create(opBuilder, loc, resultType);
2113 Region &body = specConstOperationOp.getBody();
2115 body.
getBlocks().splice(body.
end(), curBlock->getParent()->getBlocks(),
2122 opBuilder.setInsertionPointToEnd(&block);
2124 spirv::YieldOp::create(opBuilder, loc, block.
front().
getResult(0));
2125 return specConstOperationOp.getResult();
2130 if (operands.size() != 2) {
2132 "OpConstantNull must only have type <id> and result <id>");
2137 return emitError(unknownLoc,
"undefined result type from <id> ")
2141 auto resultID = operands[1];
2143 if (resultType.
isIntOrFloat() || isa<VectorType>(resultType)) {
2144 attr = opBuilder.getZeroAttr(resultType);
2145 }
else if (
auto tensorType = dyn_cast<TensorArmType>(resultType)) {
2146 if (
auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
2153 constantMap.try_emplace(resultID, attr, resultType);
2157 return emitError(unknownLoc,
"unsupported OpConstantNull type: ")
2163 if (operands.size() < 3) {
2165 <<
"OpGraphConstantARM must have at least 2 operands";
2170 return emitError(unknownLoc,
"undefined result type from <id> ")
2174 uint32_t resultID = operands[1];
2176 if (!dyn_cast<spirv::TensorArmType>(resultType)) {
2177 return emitError(unknownLoc,
"result must be of type OpTypeTensorARM");
2180 APInt graph_constant_id = APInt(32, operands[2],
true);
2181 Type i32Ty = opBuilder.getIntegerType(32);
2182 IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
2183 graphConstantMap.try_emplace(
2195 LLVM_DEBUG(logger.startLine() <<
"[block] got exiting block for id = " <<
id
2196 <<
" @ " << block <<
"\n");
2203 auto *block = curFunction->addBlock();
2204 LLVM_DEBUG(logger.startLine() <<
"[block] created block for id = " <<
id
2205 <<
" @ " << block <<
"\n");
2206 return blockMap[id] = block;
2211 return emitError(unknownLoc,
"OpBranch must appear inside a block");
2214 if (operands.size() != 1) {
2215 return emitError(unknownLoc,
"OpBranch must take exactly one target label");
2223 spirv::BranchOp::create(opBuilder, loc,
target);
2233 "OpBranchConditional must appear inside a block");
2236 if (operands.size() != 3 && operands.size() != 5) {
2238 "OpBranchConditional must have condition, true label, "
2239 "false label, and optionally two branch weights");
2242 auto condition =
getValue(operands[0]);
2246 std::optional<std::pair<uint32_t, uint32_t>> weights;
2247 if (operands.size() == 5) {
2248 weights = std::make_pair(operands[3], operands[4]);
2254 spirv::BranchConditionalOp::create(
2255 opBuilder, loc, condition, trueBlock,
2265 return emitError(unknownLoc,
"OpLabel must appear inside a function");
2268 if (operands.size() != 1) {
2269 return emitError(unknownLoc,
"OpLabel should only have result <id>");
2272 auto labelID = operands[0];
2275 LLVM_DEBUG(logger.startLine()
2276 <<
"[block] populating block " << block <<
"\n");
2278 assert(block->empty() &&
"re-deserialize the same block!");
2280 opBuilder.setInsertionPointToStart(block);
2281 blockMap[labelID] = curBlock = block;
2288 return emitError(unknownLoc,
"a graph block must appear inside a graph");
2293 LLVM_DEBUG(logger.startLine()
2294 <<
"[block] populating block " << block <<
"\n");
2296 assert(block->
empty() &&
"re-deserialize the same block!");
2298 opBuilder.setInsertionPointToStart(block);
2299 blockMap[graphID] = curBlock = block;
2307 return emitError(unknownLoc,
"OpSelectionMerge must appear in a block");
2310 if (operands.size() < 2) {
2313 "OpSelectionMerge must specify merge target and selection control");
2318 auto selectionControl = operands[1];
2320 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
2324 "a block cannot have more than one OpSelectionMerge instruction");
2333 return emitError(unknownLoc,
"OpLoopMerge must appear in a block");
2336 if (operands.size() < 3) {
2337 return emitError(unknownLoc,
"OpLoopMerge must specify merge target, "
2338 "continue target and loop control");
2344 uint32_t loopControl = operands[2];
2347 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
2351 "a block cannot have more than one OpLoopMerge instruction");
2359 return emitError(unknownLoc,
"OpPhi must appear in a block");
2362 if (operands.size() < 4) {
2363 return emitError(unknownLoc,
"OpPhi must specify result type, result <id>, "
2364 "and variable-parent pairs");
2369 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
2370 valueMap[operands[1]] = blockArg;
2371 LLVM_DEBUG(logger.startLine()
2372 <<
"[phi] created block argument " << blockArg
2373 <<
" id = " << operands[1] <<
" of type " << blockArgType <<
"\n");
2377 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
2378 uint32_t value = operands[i];
2380 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
2381 blockPhiInfo[predecessorTargetPair].push_back(value);
2382 LLVM_DEBUG(logger.startLine() <<
"[phi] predecessor @ " << predecessor
2383 <<
" with arg id = " << value <<
"\n");
2391 return emitError(unknownLoc,
"OpSwitch must appear in a block");
2393 if (operands.size() < 2)
2394 return emitError(unknownLoc,
"OpSwitch must at least specify selector and "
2395 "a default target");
2397 if (operands.size() % 2)
2399 "OpSwitch must at have an even number of operands: "
2400 "selector, default target and any number of literal and "
2401 "label <id> pairs");
2409 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
2410 literals.push_back(operands[i]);
2415 spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock,
2424class ControlFlowStructurizer {
2427 ControlFlowStructurizer(
Location loc, uint32_t control,
2430 llvm::ScopedPrinter &logger)
2431 : location(loc), control(control), blockMergeInfo(mergeInfo),
2432 headerBlock(header), mergeBlock(merge), continueBlock(cont),
2435 ControlFlowStructurizer(
Location loc, uint32_t control,
2438 : location(loc), control(control), blockMergeInfo(mergeInfo),
2439 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
2449 LogicalResult structurize();
2454 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
2457 spirv::LoopOp createLoopOp(uint32_t loopControl);
2460 void collectBlocksInConstruct();
2469 Block *continueBlock;
2475 llvm::ScopedPrinter &logger;
2481ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
2484 OpBuilder builder(&mergeBlock->front());
2486 auto control =
static_cast<spirv::SelectionControl
>(selectionControl);
2487 auto selectionOp = spirv::SelectionOp::create(builder, location, control);
2488 selectionOp.addMergeBlock(builder);
2493spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
2496 OpBuilder builder(&mergeBlock->front());
2498 auto control =
static_cast<spirv::LoopControl
>(loopControl);
2499 auto loopOp = spirv::LoopOp::create(builder, location, control);
2500 loopOp.addEntryAndMergeBlock(builder);
2505void ControlFlowStructurizer::collectBlocksInConstruct() {
2506 assert(constructBlocks.empty() &&
"expected empty constructBlocks");
2509 constructBlocks.insert(headerBlock);
2513 for (
unsigned i = 0; i < constructBlocks.size(); ++i) {
2514 for (
auto *successor : constructBlocks[i]->getSuccessors())
2515 if (successor != mergeBlock)
2516 constructBlocks.insert(successor);
2520LogicalResult ControlFlowStructurizer::structurize() {
2521 Operation *op =
nullptr;
2522 bool isLoop = continueBlock !=
nullptr;
2524 if (
auto loopOp = createLoopOp(control))
2525 op = loopOp.getOperation();
2527 if (
auto selectionOp = createSelectionOp(control))
2528 op = selectionOp.getOperation();
2537 mapper.
map(mergeBlock, &body.
back());
2539 collectBlocksInConstruct();
2560 OpBuilder builder(body);
2561 for (
auto *block : constructBlocks) {
2564 auto *newBlock = builder.createBlock(&body.
back());
2565 mapper.
map(block, newBlock);
2566 LLVM_DEBUG(logger.startLine() <<
"[cf] cloned block " << newBlock
2567 <<
" from block " << block <<
"\n");
2569 for (BlockArgument blockArg : block->getArguments()) {
2571 newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2572 mapper.
map(blockArg, newArg);
2573 LLVM_DEBUG(logger.startLine() <<
"[cf] remapped block argument "
2574 << blockArg <<
" to " << newArg <<
"\n");
2577 LLVM_DEBUG(logger.startLine()
2578 <<
"[cf] block " << block <<
" is a function entry block\n");
2581 for (
auto &op : *block)
2582 newBlock->push_back(op.
clone(mapper));
2586 auto remapOperands = [&](Operation *op) {
2588 if (Value mappedOp = mapper.
lookupOrNull(operand.get()))
2589 operand.set(mappedOp);
2592 succOp.set(mappedOp);
2594 for (
auto &block : body)
2595 block.walk(remapOperands);
2603 headerBlock->replaceAllUsesWith(mergeBlock);
2606 logger.startLine() <<
"[cf] after cloning and fixing references:\n";
2607 headerBlock->getParentOp()->print(logger.getOStream());
2608 logger.startLine() <<
"\n";
2612 if (!mergeBlock->args_empty()) {
2613 return mergeBlock->getParentOp()->emitError(
2614 "OpPhi in loop merge block unsupported");
2620 for (BlockArgument blockArg : headerBlock->getArguments())
2621 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2625 SmallVector<Value, 4> blockArgs;
2626 if (!headerBlock->args_empty())
2627 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2631 builder.setInsertionPointToEnd(&body.front());
2632 spirv::BranchOp::create(builder, location, mapper.
lookupOrNull(headerBlock),
2633 ArrayRef<Value>(blockArgs));
2638 SmallVector<Value> valuesToYield;
2641 SmallVector<Value> outsideUses;
2655 for (BlockArgument blockArg : mergeBlock->getArguments()) {
2660 body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2661 valuesToYield.push_back(body.back().getArguments().back());
2662 outsideUses.push_back(blockArg);
2667 LLVM_DEBUG(logger.startLine() <<
"[cf] cleaning up blocks after clone\n");
2670 for (
auto *block : constructBlocks)
2671 block->dropAllReferences();
2676 for (
Block *block : constructBlocks) {
2677 for (Operation &op : *block) {
2681 outsideUses.push_back(
result);
2684 for (BlockArgument &arg : block->getArguments()) {
2685 if (!arg.use_empty()) {
2687 outsideUses.push_back(arg);
2692 assert(valuesToYield.size() == outsideUses.size());
2696 if (!valuesToYield.empty()) {
2697 LLVM_DEBUG(logger.startLine()
2698 <<
"[cf] yielding values from the selection / loop region\n");
2701 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2702 Operation *merge = llvm::getSingleElement(mergeOps);
2704 merge->setOperands(valuesToYield);
2712 builder.setInsertionPoint(&mergeBlock->front());
2714 Operation *newOp =
nullptr;
2717 newOp = spirv::LoopOp::create(builder, location,
2719 static_cast<spirv::LoopControl
>(control));
2721 newOp = spirv::SelectionOp::create(
2723 static_cast<spirv::SelectionControl
>(control));
2733 for (
unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2734 outsideUses[i].replaceAllUsesWith(op->
getResult(i));
2740 mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2747 for (
auto *block : constructBlocks) {
2748 if (!block->use_empty())
2749 return emitError(block->getParent()->getLoc(),
2750 "failed control flow structurization: "
2751 "block has uses outside of the "
2752 "enclosing selection/loop construct");
2753 for (Operation &op : *block)
2755 return op.
emitOpError(
"failed control flow structurization: value has "
2756 "uses outside of the "
2757 "enclosing selection/loop construct");
2758 for (BlockArgument &arg : block->getArguments())
2759 if (!arg.use_empty())
2760 return emitError(arg.getLoc(),
"failed control flow structurization: "
2761 "block argument has uses outside of the "
2762 "enclosing selection/loop construct");
2766 for (
auto *block : constructBlocks) {
2806 auto updateMergeInfo = [&](
Block *block) -> WalkResult {
2807 auto it = blockMergeInfo.find(block);
2808 if (it != blockMergeInfo.end()) {
2810 Location loc = it->second.loc;
2814 return emitError(loc,
"failed control flow structurization: nested "
2815 "loop header block should be remapped!");
2817 Block *newContinue = it->second.continueBlock;
2821 return emitError(loc,
"failed control flow structurization: nested "
2822 "loop continue block should be remapped!");
2825 Block *newMerge = it->second.mergeBlock;
2827 newMerge = mappedTo;
2831 blockMergeInfo.
erase(it);
2832 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2839 if (block->walk(updateMergeInfo).wasInterrupted())
2847 LLVM_DEBUG(logger.startLine() <<
"[cf] changing entry block " << block
2848 <<
" to only contain a spirv.Branch op\n");
2852 builder.setInsertionPointToEnd(block);
2853 spirv::BranchOp::create(builder, location, mergeBlock);
2855 LLVM_DEBUG(logger.startLine() <<
"[cf] erasing block " << block <<
"\n");
2860 LLVM_DEBUG(logger.startLine()
2861 <<
"[cf] after structurizing construct with header block "
2862 << headerBlock <<
":\n"
2871 <<
"//----- [phi] start wiring up block arguments -----//\n";
2877 for (
const auto &info : blockPhiInfo) {
2878 Block *block = info.first.first;
2882 logger.startLine() <<
"[phi] block " << block <<
"\n";
2883 logger.startLine() <<
"[phi] before creating block argument:\n";
2885 logger.startLine() <<
"\n";
2891 opBuilder.setInsertionPoint(op);
2894 blockArgs.reserve(phiInfo.size());
2895 for (uint32_t valueId : phiInfo) {
2897 blockArgs.push_back(value);
2898 LLVM_DEBUG(logger.startLine() <<
"[phi] block argument " << value
2899 <<
" id = " << valueId <<
"\n");
2901 return emitError(unknownLoc,
"OpPhi references undefined value!");
2905 if (
auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2907 spirv::BranchOp::create(opBuilder, branchOp.getLoc(),
2908 branchOp.getTarget(), blockArgs);
2910 }
else if (
auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2911 assert((branchCondOp.getTrueBlock() ==
target ||
2912 branchCondOp.getFalseBlock() ==
target) &&
2913 "expected target to be either the true or false target");
2914 if (
target == branchCondOp.getTrueTarget())
2915 spirv::BranchConditionalOp::create(
2916 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2917 blockArgs, branchCondOp.getFalseBlockArguments(),
2918 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2919 branchCondOp.getFalseTarget());
2921 spirv::BranchConditionalOp::create(
2922 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2923 branchCondOp.getTrueBlockArguments(), blockArgs,
2924 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2925 branchCondOp.getFalseBlock());
2927 branchCondOp.erase();
2928 }
else if (
auto switchOp = dyn_cast<spirv::SwitchOp>(op)) {
2929 if (
target == switchOp.getDefaultTarget()) {
2933 spirv::SwitchOp::create(
2934 opBuilder, switchOp.getLoc(), switchOp.getSelector(),
2935 switchOp.getDefaultTarget(), blockArgs, literals,
2936 switchOp.getTargets(), targetOperands);
2940 auto it = llvm::find(targets,
target);
2941 assert(it != targets.end());
2942 size_t index = std::distance(targets.begin(), it);
2943 switchOp.getTargetOperandsMutable(
index).assign(blockArgs);
2946 return emitError(unknownLoc,
"unimplemented terminator for Phi creation");
2950 logger.startLine() <<
"[phi] after creating block argument:\n";
2952 logger.startLine() <<
"\n";
2955 blockPhiInfo.clear();
2960 <<
"//--- [phi] completed wiring up block arguments ---//\n";
2968 for (
auto [block, mergeInfo] : blockMergeInfoCopy) {
2970 if (mergeInfo.continueBlock)
2973 if (!block->mightHaveTerminator())
2976 Operation *terminator = block->getTerminator();
2979 if (!isa<spirv::BranchConditionalOp, spirv::SwitchOp>(terminator))
2983 bool splitHeaderMergeBlock =
false;
2984 for (
const auto &[_, mergeInfo] : blockMergeInfo) {
2985 if (mergeInfo.mergeBlock == block)
2986 splitHeaderMergeBlock =
true;
2993 if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2996 spirv::BranchOp::create(builder, block->getParent()->getLoc(), newBlock);
3000 blockMergeInfo.erase(block);
3001 blockMergeInfo.try_emplace(newBlock, mergeInfo);
3009 if (!options.enableControlFlowStructurization) {
3013 <<
"//----- [cf] skip structurizing control flow -----//\n";
3021 <<
"//----- [cf] start structurizing control flow -----//\n";
3026 logger.startLine() <<
"[cf] split conditional blocks\n";
3027 logger.startLine() <<
"\n";
3034 while (!blockMergeInfo.empty()) {
3035 Block *headerBlock = blockMergeInfo.
begin()->first;
3039 logger.startLine() <<
"[cf] header block " << headerBlock <<
":\n";
3040 headerBlock->
print(logger.getOStream());
3041 logger.startLine() <<
"\n";
3045 assert(mergeBlock &&
"merge block cannot be nullptr");
3047 return emitError(unknownLoc,
"OpPhi in loop merge block unimplemented");
3049 logger.startLine() <<
"[cf] merge block " << mergeBlock <<
":\n";
3050 mergeBlock->print(logger.getOStream());
3051 logger.startLine() <<
"\n";
3055 LLVM_DEBUG(
if (continueBlock) {
3056 logger.startLine() <<
"[cf] continue block " << continueBlock <<
":\n";
3057 continueBlock->print(logger.getOStream());
3058 logger.startLine() <<
"\n";
3062 blockMergeInfo.
erase(blockMergeInfo.begin());
3063 ControlFlowStructurizer structurizer(mergeInfo.
loc, mergeInfo.
control,
3064 blockMergeInfo, headerBlock,
3065 mergeBlock, continueBlock
3071 if (failed(structurizer.structurize()))
3078 <<
"//--- [cf] completed structurizing control flow ---//\n";
3091 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
3092 if (fileName.empty())
3093 fileName =
"<unknown>";
3105 if (operands.size() != 3)
3106 return emitError(unknownLoc,
"OpLine must have 3 operands");
3107 debugLine =
DebugLine{operands[0], operands[1], operands[2]};
3115 if (operands.size() < 2)
3116 return emitError(unknownLoc,
"OpString needs at least 2 operands");
3118 if (!debugInfoMap.lookup(operands[0]).empty())
3120 "duplicate debug string found for result <id> ")
3123 unsigned wordIndex = 1;
3125 if (wordIndex != operands.size())
3127 "unexpected trailing words in OpString instruction");
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Operation * getTerminator()
Get the terminator operation of this block.
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
MutableArrayRef< BlockOperand > getBlockOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool use_empty()
Returns true if this operation has no uses.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MutableArrayRef< OpOperand > getOpOperands()
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, PropertyRef properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
result_range getResults()
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class implements the successor iterators for Block.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static WalkResult advance()
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult wireUpBlockArgument()
Creates block arguments on predecessors previously recorded when handling OpPhi instructions.
Value materializeSpecConstantOperation(uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, ArrayRef< uint32_t > enclosedOpOperands)
Materializes/emits an OpSpecConstantOp instruction.
LogicalResult processOpTypePointer(ArrayRef< uint32_t > operands)
Value getValue(uint32_t id)
Get the Value associated with a result <id>.
LogicalResult processMatrixType(ArrayRef< uint32_t > operands)
LogicalResult processGlobalVariable(ArrayRef< uint32_t > operands)
Processes the OpVariable instructions at current offset into binary.
std::optional< SpecConstOperationMaterializationInfo > getSpecConstantOperation(uint32_t id)
Gets the info needed to materialize the spec constant operation op associated with the given <id>.
LogicalResult processConstantNull(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantNull instruction with the given operands.
LogicalResult processSpecConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantComposite instruction with the given operands.
LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef< uint32_t > operands, bool deferInstructions=true)
Processes a SPIR-V instruction with the given opcode and operands.
LogicalResult processBranchConditional(ArrayRef< uint32_t > operands)
spirv::GlobalVariableOp getGlobalVariable(uint32_t id)
Gets the global variable associated with a result <id> of OpVariable.
LogicalResult createGraphBlock(uint32_t graphID)
Creates a block for graph with the given graphID.
LogicalResult processStructType(ArrayRef< uint32_t > operands)
LogicalResult processGraphARM(ArrayRef< uint32_t > operands)
LogicalResult processSamplerType(ArrayRef< uint32_t > operands)
LogicalResult setFunctionArgAttrs(uint32_t argID, SmallVectorImpl< Attribute > &argAttrs, size_t argIndex)
Sets the function argument's attributes.
LogicalResult structurizeControlFlow()
Extracts blocks belonging to a structured selection/loop into a spirv.mlir.selection/spirv....
LogicalResult processLabel(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLabel instruction with the given operands.
LogicalResult processSampledImageType(ArrayRef< uint32_t > operands)
LogicalResult processTensorARMType(ArrayRef< uint32_t > operands)
std::optional< spirv::GraphConstantARMOpMaterializationInfo > getGraphConstantARM(uint32_t id)
Gets the GraphConstantARM ID attribute and result type with the given result <id>.
std::optional< std::pair< Attribute, Type > > getConstant(uint32_t id)
Gets the constant's attribute and type associated with the given <id>.
LogicalResult processType(spirv::Opcode opcode, ArrayRef< uint32_t > operands)
Processes a SPIR-V type instruction with given opcode and operands and registers the type into module...
LogicalResult processLoopMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLoopMerge instruction with the given operands.
LogicalResult processArrayType(ArrayRef< uint32_t > operands)
LogicalResult sliceInstruction(spirv::Opcode &opcode, ArrayRef< uint32_t > &operands, std::optional< spirv::Opcode > expectedOpcode=std::nullopt)
Slices the first instruction out of binary and returns its opcode and operands via opcode and operand...
spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id)
Gets the composite specialization constant with the given result <id>.
LogicalResult processNamedBarrierType(ArrayRef< uint32_t > operands)
SmallVector< uint32_t, 2 > BlockPhiInfo
For OpPhi instructions, we use block arguments to represent them.
LogicalResult processSpecConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processCooperativeMatrixTypeKHR(ArrayRef< uint32_t > operands)
LogicalResult processGraphEntryPointARM(ArrayRef< uint32_t > operands)
LogicalResult processFunction(ArrayRef< uint32_t > operands)
Creates a deserializer for the given SPIR-V binary module.
StringAttr getSymbolDecoration(StringRef decorationName)
Gets the symbol name from the name of decoration.
Block * getOrCreateBlock(uint32_t id)
Gets or creates the block corresponding to the given label <id>.
bool isVoidType(Type type) const
Returns true if the given type is for SPIR-V void type.
std::string getSpecConstantSymbol(uint32_t id)
Returns a symbol to be used for the specialization constant with the given result <id>.
LogicalResult processDebugString(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpString instruction with the given operands.
LogicalResult processPhi(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpPhi instruction with the given operands.
std::string getFunctionSymbol(uint32_t id)
Returns a symbol to be used for the function name with the given result <id>.
void clearDebugLine()
Discontinues any source-level location information that might be active from a previous OpLine instru...
LogicalResult processFunctionType(ArrayRef< uint32_t > operands)
IntegerAttr getConstantInt(uint32_t id)
Gets the constant's integer attribute with the given <id>.
LogicalResult processTypeForwardPointer(ArrayRef< uint32_t > operands)
LogicalResult processSwitch(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSwitch instruction with the given operands.
LogicalResult processGraphEndARM(ArrayRef< uint32_t > operands)
LogicalResult processImageType(ArrayRef< uint32_t > operands)
LogicalResult processConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantComposite instruction with the given operands.
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, TypedAttr defaultValue)
Creates a spirv::SpecConstantOp.
Block * getBlock(uint32_t id) const
Returns the block for the given label <id>.
LogicalResult processGraphTypeARM(ArrayRef< uint32_t > operands)
LogicalResult processBranch(ArrayRef< uint32_t > operands)
std::optional< std::pair< Attribute, Type > > getConstantCompositeReplicate(uint32_t id)
Gets the replicated composite constant's attribute and type associated with the given <id>.
LogicalResult processFunctionEnd(ArrayRef< uint32_t > operands)
Processes OpFunctionEnd and finalizes function.
LogicalResult processRuntimeArrayType(ArrayRef< uint32_t > operands)
LogicalResult processSpecConstantOperation(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantOp instruction with the given operands.
LogicalResult processConstant(ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant instruction with the given operands.
Location createFileLineColLoc(OpBuilder opBuilder)
Creates a FileLineColLoc with the OpLine location information.
LogicalResult processGraphConstantARM(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpGraphConstantARM instruction with the given operands.
LogicalResult processConstantBool(bool isTrue, ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the given operands.
spirv::SpecConstantOp getSpecConstant(uint32_t id)
Gets the specialization constant with the given result <id>.
LogicalResult processConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processSelectionMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSelectionMerge instruction with the given operands.
LogicalResult processOpGraphSetOutputARM(ArrayRef< uint32_t > operands)
LogicalResult processDebugLine(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLine instruction with the given operands.
LogicalResult splitSelectionHeader()
Move a conditional branch or a switch into a separate basic block to avoid unnecessary sinking of def...
std::string getGraphSymbol(uint32_t id)
Returns a symbol to be used for the graph name with the given result <id>.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static MatrixType get(Type columnType, uint32_t columnCount)
static NamedBarrierType get(MLIRContext *context)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static SamplerType get(MLIRContext *context)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr uint32_t kMagicNumber
SPIR-V magic number.
llvm::MapVector< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
llvm::SetVector< T, Vector, Set, N > SetVector
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A struct for containing a header block's merge and continue targets.
A struct for containing OpLine instruction information.
A struct that collects the info needed to materialize/emit a GraphConstantARMOp.
A struct that collects the info needed to materialize/emit a SpecConstantOperation op.