MLIR 23.0.0git
Deserializer.cpp
Go to the documentation of this file.
1//===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Deserializer.h"
14
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Location.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/Sequence.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/bit.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/SaveAndRestore.h"
30#include "llvm/Support/raw_ostream.h"
31#include <optional>
32
33using namespace mlir;
34
35#define DEBUG_TYPE "spirv-deserialization"
36
37//===----------------------------------------------------------------------===//
38// Utility Functions
39//===----------------------------------------------------------------------===//
40
41/// Returns true if the given `block` is a function entry block.
42static inline bool isFnEntryBlock(Block *block) {
43 return block->isEntryBlock() &&
44 isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45}
46
47//===----------------------------------------------------------------------===//
48// Deserializer Method Definitions
49//===----------------------------------------------------------------------===//
50
/// Constructs a deserializer over the SPIR-V words in `binary`. An empty
/// spirv.module op is created immediately and `opBuilder` is pointed at its
/// region, so module-level ops are inserted into it as instructions are
/// processed. In debug builds the logger writes to llvm::dbgs().
spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
                                  MLIRContext *context,
                                  const spirv::DeserializationOptions &options)
    // createModuleOp() reads `context` and `unknownLoc`, and `opBuilder` reads
    // `module`; this relies on the members being declared in this order.
    : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
      module(createModuleOp()), opBuilder(module->getRegion()), options(options)
#ifndef NDEBUG
      ,
      logger(llvm::dbgs())
#endif
{
}
62
63LogicalResult spirv::Deserializer::deserialize() {
64 LLVM_DEBUG({
65 logger.resetIndent();
66 logger.startLine()
67 << "//+++---------- start deserialization ----------+++//\n";
68 });
69
70 if (failed(processHeader()))
71 return failure();
72
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
74 ArrayRef<uint32_t> operands;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
77 // Slice the next instruction out and populate `opcode` and `operands`.
78 // Internally this also updates `curOffset`.
79 if (failed(sliceInstruction(opcode, operands)))
80 return failure();
81
82 if (failed(processInstruction(opcode, operands)))
83 return failure();
84 }
85
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
88
89 for (auto &deferred : deferredInstructions) {
90 if (failed(processInstruction(deferred.first, deferred.second, false))) {
91 return failure();
92 }
93 }
94
95 if (failed(resolveDeferredIdDecorations()))
96 return failure();
97
98 attachVCETriple();
99
100 LLVM_DEBUG(logger.startLine()
101 << "//+++-------- completed deserialization --------+++//\n");
102 return success();
103}
104
105OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
106 return std::move(module);
107}
108
109//===----------------------------------------------------------------------===//
110// Module structure
111//===----------------------------------------------------------------------===//
112
113OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
114 OpBuilder builder(context);
115 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
116 spirv::ModuleOp::build(builder, state);
117 return cast<spirv::ModuleOp>(Operation::create(state));
118}
119
120LogicalResult spirv::Deserializer::processHeader() {
121 if (binary.size() < spirv::kHeaderWordCount)
122 return emitError(unknownLoc,
123 "SPIR-V binary module must have a 5-word header");
124
125 if (binary[0] != spirv::kMagicNumber)
126 return emitError(unknownLoc, "incorrect magic number");
127
128 // Version number bytes: 0 | major number | minor number | 0
129 uint32_t majorVersion = (binary[1] << 8) >> 24;
130 uint32_t minorVersion = (binary[1] << 16) >> 24;
131 if (majorVersion == 1) {
132 switch (minorVersion) {
133#define MIN_VERSION_CASE(v) \
134 case v: \
135 version = spirv::Version::V_1_##v; \
136 break
137
145#undef MIN_VERSION_CASE
146 default:
147 return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
148 << minorVersion;
149 }
150 } else {
151 return emitError(unknownLoc, "unsupported SPIR-V major version: ")
152 << majorVersion;
153 }
154
155 // TODO: generator number, bound, schema
156 curOffset = spirv::kHeaderWordCount;
157 return success();
158}
159
160LogicalResult
161spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
162 if (operands.size() != 1)
163 return emitError(unknownLoc, "OpCapability must have one parameter");
164
165 auto cap = spirv::symbolizeCapability(operands[0]);
166 if (!cap)
167 return emitError(unknownLoc, "unknown capability: ") << operands[0];
168
169 capabilities.insert(*cap);
170 return success();
171}
172
173LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
174 if (words.empty()) {
175 return emitError(
176 unknownLoc,
177 "OpExtension must have a literal string for the extension name");
178 }
179
180 unsigned wordIndex = 0;
181 StringRef extName = decodeStringLiteral(words, wordIndex);
182 if (wordIndex != words.size())
183 return emitError(unknownLoc,
184 "unexpected trailing words in OpExtension instruction");
185 auto ext = spirv::symbolizeExtension(extName);
186 if (!ext)
187 return emitError(unknownLoc, "unknown extension: ") << extName;
188
189 extensions.insert(*ext);
190 return success();
191}
192
193LogicalResult
194spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
195 if (words.size() < 2) {
196 return emitError(unknownLoc,
197 "OpExtInstImport must have a result <id> and a literal "
198 "string for the extended instruction set name");
199 }
200
201 unsigned wordIndex = 1;
202 extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
203 if (wordIndex != words.size()) {
204 return emitError(unknownLoc,
205 "unexpected trailing words in OpExtInstImport");
206 }
207 return success();
208}
209
210void spirv::Deserializer::attachVCETriple() {
211 (*module)->setAttr(
212 spirv::ModuleOp::getVCETripleAttrName(),
213 spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
214 extensions.getArrayRef(), context));
215}
216
217LogicalResult
218spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
219 if (operands.size() != 2)
220 return emitError(unknownLoc, "OpMemoryModel must have two operands");
221
222 (*module)->setAttr(
223 module->getAddressingModelAttrName(),
224 opBuilder.getAttr<spirv::AddressingModelAttr>(
225 static_cast<spirv::AddressingModel>(operands.front())));
226
227 (*module)->setAttr(module->getMemoryModelAttrName(),
228 opBuilder.getAttr<spirv::MemoryModelAttr>(
229 static_cast<spirv::MemoryModel>(operands.back())));
230
231 return success();
232}
233
234template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
236 Location loc, OpBuilder &opBuilder,
238 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
239 if (words.size() != 4) {
240 return emitError(loc, "OpDecorate with ")
241 << decorationName << " needs a cache control integer literal and a "
242 << cacheControlKind << " cache control literal";
243 }
244 unsigned cacheLevel = words[2];
245 auto cacheControlAttr = static_cast<EnumTy>(words[3]);
246 auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
248 if (auto attrList =
249 dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
250 llvm::append_range(attrs, attrList);
251 attrs.push_back(value);
252 decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
253 return success();
254}
255
256LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
257 // TODO: This function should also be auto-generated. For now, since only a
258 // few decorations are processed/handled in a meaningful manner, going with a
259 // manual implementation.
260 if (words.size() < 2) {
261 return emitError(
262 unknownLoc, "OpDecorate must have at least result <id> and Decoration");
263 }
264 auto decorationName =
265 stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
266 if (decorationName.empty()) {
267 return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
268 }
269 auto symbol = getSymbolDecoration(decorationName);
270 switch (static_cast<spirv::Decoration>(words[1])) {
271 case spirv::Decoration::FPFastMathMode:
272 if (words.size() != 3) {
273 return emitError(unknownLoc, "OpDecorate with ")
274 << decorationName << " needs a single integer literal";
275 }
276 decorations[words[0]].set(
277 symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
278 static_cast<FPFastMathMode>(words[2])));
279 break;
280 case spirv::Decoration::FPRoundingMode:
281 if (words.size() != 3) {
282 return emitError(unknownLoc, "OpDecorate with ")
283 << decorationName << " needs a single integer literal";
284 }
285 decorations[words[0]].set(
286 symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
287 static_cast<FPRoundingMode>(words[2])));
288 break;
289 case spirv::Decoration::DescriptorSet:
290 case spirv::Decoration::Binding:
291 case spirv::Decoration::Location:
292 case spirv::Decoration::SpecId:
293 case spirv::Decoration::Index:
294 case spirv::Decoration::Offset:
295 case spirv::Decoration::XfbBuffer:
296 case spirv::Decoration::XfbStride:
297 if (words.size() != 3) {
298 return emitError(unknownLoc, "OpDecorate with ")
299 << decorationName << " needs a single integer literal";
300 }
301 decorations[words[0]].set(
302 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
303 break;
304 case spirv::Decoration::BuiltIn:
305 if (words.size() != 3) {
306 return emitError(unknownLoc, "OpDecorate with ")
307 << decorationName << " needs a single integer literal";
308 }
309 decorations[words[0]].set(
310 symbol, opBuilder.getStringAttr(
311 stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
312 break;
313 case spirv::Decoration::ArrayStride:
314 if (words.size() != 3) {
315 return emitError(unknownLoc, "OpDecorate with ")
316 << decorationName << " needs a single integer literal";
317 }
318 typeDecorations[words[0]] = words[2];
319 break;
320 case spirv::Decoration::LinkageAttributes: {
321 if (words.size() < 4) {
322 return emitError(unknownLoc, "OpDecorate with ")
323 << decorationName
324 << " needs at least 1 string and 1 integer literal";
325 }
326 // LinkageAttributes has two parameters ["linkageName", linkageType]
327 // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
328 // "linkageName" is a stringliteral encoded as uint32_t,
329 // hence the size of name is variable length which results in words.size()
330 // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
331 // 3 + ceildiv(strlen(name), 4).
332 unsigned wordIndex = 2;
333 auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
334 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
335 static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
336 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
337 StringAttr::get(context, linkageName), linkageTypeAttr);
338 decorations[words[0]].set(symbol, dyn_cast<Attribute>(linkageAttr));
339 break;
340 }
341 case spirv::Decoration::Aliased:
342 case spirv::Decoration::AliasedPointer:
343 case spirv::Decoration::Block:
344 case spirv::Decoration::BufferBlock:
345 case spirv::Decoration::Flat:
346 case spirv::Decoration::NonReadable:
347 case spirv::Decoration::NonWritable:
348 case spirv::Decoration::NoPerspective:
349 case spirv::Decoration::NoSignedWrap:
350 case spirv::Decoration::NoUnsignedWrap:
351 case spirv::Decoration::RelaxedPrecision:
352 case spirv::Decoration::Restrict:
353 case spirv::Decoration::RestrictPointer:
354 case spirv::Decoration::NoContraction:
355 case spirv::Decoration::Constant:
356 case spirv::Decoration::Invariant:
357 case spirv::Decoration::Patch:
358 case spirv::Decoration::Coherent:
359 if (words.size() != 2) {
360 return emitError(unknownLoc, "OpDecorate with ")
361 << decorationName << " needs a single target <id>";
362 }
363 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
364 break;
365 case spirv::Decoration::CacheControlLoadINTEL: {
366 LogicalResult res = deserializeCacheControlDecoration<
367 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
368 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
369 "load");
370 if (failed(res))
371 return res;
372 break;
373 }
374 case spirv::Decoration::CacheControlStoreINTEL: {
375 LogicalResult res = deserializeCacheControlDecoration<
376 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
377 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
378 "store");
379 if (failed(res))
380 return res;
381 break;
382 }
383 case spirv::Decoration::AlignmentId:
384 case spirv::Decoration::MaxByteOffsetId:
385 case spirv::Decoration::CounterBuffer:
386 if (words.size() != 3) {
387 return emitError(unknownLoc, "OpDecorateId with ")
388 << decorationName << " needs a single <id> operand";
389 }
390 pendingIdDecorations.push_back({words[0],
391 static_cast<spirv::Decoration>(words[1]),
392 words[2], unknownLoc});
393 break;
394 default:
395 return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
396 }
397 return success();
398}
399
400LogicalResult spirv::Deserializer::resolveDeferredIdDecorations() {
401 for (const DeferredIdDecoration &entry : pendingIdDecorations) {
402 StringRef decorationName = stringifyDecoration(entry.decoration);
403 StringAttr symbol = getSymbolDecoration(decorationName);
404
405 // Resolve the operand <id> to a symbol name. The operand must reference a
406 // module-scope symbol op (global variable or specialization constant).
407 StringRef operandSymName;
408 if (spirv::GlobalVariableOp varOp =
409 globalVariableMap.lookup(entry.operandID))
410 operandSymName = varOp.getSymName();
411 else if (spirv::SpecConstantOp specOp =
412 specConstMap.lookup(entry.operandID))
413 operandSymName = specOp.getSymName();
414 else
415 return emitError(entry.loc, "OpDecorateId with ")
416 << decorationName << " references <id> " << entry.operandID
417 << " which is not a global variable or specialization constant";
418
419 auto symRef = FlatSymbolRefAttr::get(context, operandSymName);
420
421 // Resolve the decoration target. By the time this method runs, all
422 // instructions have been processed, so every defined <id> must appear in
423 // one of these maps; an unresolved target indicates malformed input.
424 Operation *targetOp = nullptr;
425 if (spirv::GlobalVariableOp varOp =
426 globalVariableMap.lookup(entry.targetID))
427 targetOp = varOp;
428 else if (spirv::SpecConstantOp specOp = specConstMap.lookup(entry.targetID))
429 targetOp = specOp;
430 else if (spirv::FuncOp fnOp = funcMap.lookup(entry.targetID))
431 targetOp = fnOp;
432 else if (Value v = valueMap.lookup(entry.targetID))
433 targetOp = v.getDefiningOp();
434
435 if (!targetOp)
436 return emitError(entry.loc, "OpDecorateId with ")
437 << decorationName << " references unknown target <id> "
438 << entry.targetID;
439
440 targetOp->setAttr(symbol, symRef);
441 }
442 return success();
443}
444
445LogicalResult
446spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
447 // The binary layout of OpMemberDecorate is different comparing to OpDecorate
448 if (words.size() < 3) {
449 return emitError(unknownLoc,
450 "OpMemberDecorate must have at least 3 operands");
451 }
452
453 auto decoration = static_cast<spirv::Decoration>(words[2]);
454 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
455 return emitError(unknownLoc,
456 " missing offset specification in OpMemberDecorate with "
457 "Offset decoration");
458 }
459 ArrayRef<uint32_t> decorationOperands;
460 if (words.size() > 3) {
461 decorationOperands = words.slice(3);
462 }
463 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
464 return success();
465}
466
467LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
468 if (words.size() < 3) {
469 return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
470 }
471 unsigned wordIndex = 2;
472 auto name = decodeStringLiteral(words, wordIndex);
473 if (wordIndex != words.size()) {
474 return emitError(unknownLoc,
475 "unexpected trailing words in OpMemberName instruction");
476 }
477 memberNameMap[words[0]][words[1]] = name;
478 return success();
479}
480
482 uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
483 if (!decorations.contains(argID)) {
484 argAttrs[argIndex] = DictionaryAttr::get(context, {});
485 return success();
486 }
487
488 spirv::DecorationAttr foundDecorationAttr;
489 for (NamedAttribute decAttr : decorations[argID]) {
490 for (auto decoration :
491 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
492 spirv::Decoration::AliasedPointer,
493 spirv::Decoration::RestrictPointer}) {
494
495 if (decAttr.getName() !=
496 getSymbolDecoration(stringifyDecoration(decoration)))
497 continue;
498
499 if (foundDecorationAttr)
500 return emitError(unknownLoc,
501 "more than one Aliased/Restrict decorations for "
502 "function argument with result <id> ")
503 << argID;
504
505 foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
506 break;
507 }
508
509 if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
510 spirv::Decoration::RelaxedPrecision))) {
511 // TODO: Current implementation supports only one decoration per function
512 // parameter so RelaxedPrecision cannot be applied at the same time as,
513 // for example, Aliased/Restrict/etc. This should be relaxed to allow any
514 // combination of decoration allowed by the spec to be supported.
515 if (foundDecorationAttr)
516 return emitError(unknownLoc, "already found a decoration for function "
517 "argument with result <id> ")
518 << argID;
519
520 foundDecorationAttr = spirv::DecorationAttr::get(
521 context, spirv::Decoration::RelaxedPrecision);
522 }
523 }
524
525 if (!foundDecorationAttr)
526 return emitError(unknownLoc, "unimplemented decoration support for "
527 "function argument with result <id> ")
528 << argID;
529
530 NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
531 foundDecorationAttr);
532 argAttrs[argIndex] = DictionaryAttr::get(context, attr);
533 return success();
534}
535
536LogicalResult
538 if (curFunction) {
539 return emitError(unknownLoc, "found function inside function");
540 }
541
542 // Get the result type
543 if (operands.size() != 4) {
544 return emitError(unknownLoc, "OpFunction must have 4 parameters");
545 }
546 Type resultType = getType(operands[0]);
547 if (!resultType) {
548 return emitError(unknownLoc, "undefined result type from <id> ")
549 << operands[0];
550 }
551
552 uint32_t fnID = operands[1];
553 if (funcMap.count(fnID)) {
554 return emitError(unknownLoc, "duplicate function definition/declaration");
555 }
556
557 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
558 if (!fnControl) {
559 return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
560 }
561
562 Type fnType = getType(operands[3]);
563 if (!fnType || !isa<FunctionType>(fnType)) {
564 return emitError(unknownLoc, "unknown function type from <id> ")
565 << operands[3];
566 }
567 auto functionType = cast<FunctionType>(fnType);
568
569 if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
570 (functionType.getNumResults() == 1 &&
571 functionType.getResult(0) != resultType)) {
572 return emitError(unknownLoc, "mismatch in function type ")
573 << functionType << " and return type " << resultType << " specified";
574 }
575
576 std::string fnName = getFunctionSymbol(fnID);
577 auto funcOp = spirv::FuncOp::create(opBuilder, unknownLoc, fnName,
578 functionType, fnControl.value());
579 // Processing other function attributes.
580 if (decorations.count(fnID)) {
581 for (auto attr : decorations[fnID].getAttrs()) {
582 funcOp->setAttr(attr.getName(), attr.getValue());
583 }
584 }
585 curFunction = funcMap[fnID] = funcOp;
586 auto *entryBlock = funcOp.addEntryBlock();
587 LLVM_DEBUG({
588 logger.startLine()
589 << "//===-------------------------------------------===//\n";
590 logger.startLine() << "[fn] name: " << fnName << "\n";
591 logger.startLine() << "[fn] type: " << fnType << "\n";
592 logger.startLine() << "[fn] ID: " << fnID << "\n";
593 logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
594 logger.indent();
595 });
596
597 SmallVector<Attribute> argAttrs;
598 argAttrs.resize(functionType.getNumInputs());
599
600 // Parse the op argument instructions
601 if (functionType.getNumInputs()) {
602 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
603 auto argType = functionType.getInput(i);
604 spirv::Opcode opcode = spirv::Opcode::OpNop;
605 ArrayRef<uint32_t> operands;
606 if (failed(sliceInstruction(opcode, operands,
607 spirv::Opcode::OpFunctionParameter))) {
608 return failure();
609 }
610 if (opcode != spirv::Opcode::OpFunctionParameter) {
611 return emitError(
612 unknownLoc,
613 "missing OpFunctionParameter instruction for argument ")
614 << i;
615 }
616 if (operands.size() != 2) {
617 return emitError(
618 unknownLoc,
619 "expected result type and result <id> for OpFunctionParameter");
620 }
621 auto argDefinedType = getType(operands[0]);
622 if (!argDefinedType || argDefinedType != argType) {
623 return emitError(unknownLoc,
624 "mismatch in argument type between function type "
625 "definition ")
626 << functionType << " and argument type definition "
627 << argDefinedType << " at argument " << i;
628 }
629 if (getValue(operands[1])) {
630 return emitError(unknownLoc, "duplicate definition of result <id> ")
631 << operands[1];
632 }
633 if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
634 return failure();
635 }
636
637 auto argValue = funcOp.getArgument(i);
638 valueMap[operands[1]] = argValue;
639 }
640 }
641
642 if (llvm::any_of(argAttrs, [](Attribute attr) {
643 auto argAttr = cast<DictionaryAttr>(attr);
644 return !argAttr.empty();
645 }))
646 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
647
648 // entryBlock is needed to access the arguments, Once that is done, we can
649 // erase the block for functions with 'Import' LinkageAttributes, since these
650 // are essentially function declarations, so they have no body.
651 auto linkageAttr = funcOp.getLinkageAttributes();
652 auto hasImportLinkage =
653 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
654 spirv::LinkageType::Import);
655 if (hasImportLinkage)
656 funcOp.eraseBody();
657
658 // RAII guard to reset the insertion point to the module's region after
659 // deserializing the body of this function.
660 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
661
662 spirv::Opcode opcode = spirv::Opcode::OpNop;
663 ArrayRef<uint32_t> instOperands;
664
665 // Special handling for the entry block. We need to make sure it starts with
666 // an OpLabel instruction. The entry block takes the same parameters as the
667 // function. All other blocks do not take any parameter. We have already
668 // created the entry block, here we need to register it to the correct label
669 // <id>.
670 if (failed(sliceInstruction(opcode, instOperands,
671 spirv::Opcode::OpFunctionEnd))) {
672 return failure();
673 }
674 if (opcode == spirv::Opcode::OpFunctionEnd) {
675 return processFunctionEnd(instOperands);
676 }
677 if (opcode != spirv::Opcode::OpLabel) {
678 return emitError(unknownLoc, "a basic block must start with OpLabel");
679 }
680 if (instOperands.size() != 1) {
681 return emitError(unknownLoc, "OpLabel should only have result <id>");
682 }
683 blockMap[instOperands[0]] = entryBlock;
684 if (failed(processLabel(instOperands))) {
685 return failure();
686 }
687
688 // Then process all the other instructions in the function until we hit
689 // OpFunctionEnd.
690 while (succeeded(sliceInstruction(opcode, instOperands,
691 spirv::Opcode::OpFunctionEnd)) &&
692 opcode != spirv::Opcode::OpFunctionEnd) {
693 if (failed(processInstruction(opcode, instOperands))) {
694 return failure();
695 }
696 }
697 if (opcode != spirv::Opcode::OpFunctionEnd) {
698 return failure();
699 }
700
701 return processFunctionEnd(instOperands);
702}
703
704LogicalResult
706 // Process OpFunctionEnd.
707 if (!operands.empty()) {
708 return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
709 }
710
711 // Wire up block arguments from OpPhi instructions.
712 // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
713 // ops.
714 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
715 return failure();
716 }
717
718 curBlock = nullptr;
719 curFunction = std::nullopt;
720
721 LLVM_DEBUG({
722 logger.unindent();
723 logger.startLine()
724 << "//===-------------------------------------------===//\n";
725 });
726 return success();
727}
728
729LogicalResult
731 if (operands.size() < 2) {
732 return emitError(unknownLoc,
733 "missing graph defintion in OpGraphEntryPointARM");
734 }
735
736 unsigned wordIndex = 0;
737 uint32_t graphID = operands[wordIndex++];
738 if (!graphMap.contains(graphID)) {
739 return emitError(unknownLoc,
740 "missing graph definition/declaration with id ")
741 << graphID;
742 }
743
744 spirv::GraphARMOp graphARM = graphMap[graphID];
745 StringRef name = decodeStringLiteral(operands, wordIndex);
746 graphARM.setSymName(name);
747 graphARM.setEntryPoint(true);
748
750 for (int64_t size = operands.size(); wordIndex < size; ++wordIndex) {
751 if (spirv::GlobalVariableOp arg = getGlobalVariable(operands[wordIndex])) {
752 interface.push_back(SymbolRefAttr::get(arg.getOperation()));
753 } else {
754 return emitError(unknownLoc, "undefined result <id> ")
755 << operands[wordIndex] << " while decoding OpGraphEntryPoint";
756 }
757 }
758
759 // RAII guard to reset the insertion point to previous value when done.
760 OpBuilder::InsertionGuard insertionGuard(opBuilder);
761 opBuilder.setInsertionPoint(graphARM);
762 spirv::GraphEntryPointARMOp::create(
763 opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
764 opBuilder.getArrayAttr(interface));
765
766 return success();
767}
768
769LogicalResult
771 if (curGraph) {
772 return emitError(unknownLoc, "found graph inside graph");
773 }
774 // Get the result type.
775 if (operands.size() < 2) {
776 return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters");
777 }
778
779 Type type = getType(operands[0]);
780 if (!type || !isa<GraphType>(type)) {
781 return emitError(unknownLoc, "unknown graph type from <id> ")
782 << operands[0];
783 }
784 auto graphType = cast<GraphType>(type);
785 if (graphType.getNumResults() <= 0) {
786 return emitError(unknownLoc, "expected at least one result");
787 }
788
789 uint32_t graphID = operands[1];
790 if (graphMap.count(graphID)) {
791 return emitError(unknownLoc, "duplicate graph definition/declaration");
792 }
793
794 std::string graphName = getGraphSymbol(graphID);
795 auto graphOp =
796 spirv::GraphARMOp::create(opBuilder, unknownLoc, graphName, graphType);
797 curGraph = graphMap[graphID] = graphOp;
798 Block *entryBlock = graphOp.addEntryBlock();
799 LLVM_DEBUG({
800 logger.startLine()
801 << "//===-------------------------------------------===//\n";
802 logger.startLine() << "[graph] name: " << graphName << "\n";
803 logger.startLine() << "[graph] type: " << graphType << "\n";
804 logger.startLine() << "[graph] ID: " << graphID << "\n";
805 logger.startLine() << "[graph] entry block: " << entryBlock << "\n";
806 logger.indent();
807 });
808
809 // Parse the op argument instructions.
810 for (auto [index, argType] : llvm::enumerate(graphType.getInputs())) {
811 spirv::Opcode opcode;
812 ArrayRef<uint32_t> operands;
813 if (failed(sliceInstruction(opcode, operands,
814 spirv::Opcode::OpGraphInputARM))) {
815 return failure();
816 }
817 if (operands.size() != 3) {
818 return emitError(unknownLoc, "expected result type, result <id> and "
819 "input index for OpGraphInputARM");
820 }
821
822 Type argDefinedType = getType(operands[0]);
823 if (!argDefinedType) {
824 return emitError(unknownLoc, "unknown operand type <id> ") << operands[0];
825 }
826
827 if (argDefinedType != argType) {
828 return emitError(unknownLoc,
829 "mismatch in argument type between graph type "
830 "definition ")
831 << graphType << " and argument type definition " << argDefinedType
832 << " at argument " << index;
833 }
834 if (getValue(operands[1])) {
835 return emitError(unknownLoc, "duplicate definition of result <id> ")
836 << operands[1];
837 }
838
839 IntegerAttr inputIndexAttr = getConstantInt(operands[2]);
840 if (!inputIndexAttr) {
841 return emitError(unknownLoc,
842 "unable to read inputIndex value from constant op ")
843 << operands[2];
844 }
845 BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
846 valueMap[operands[1]] = argValue;
847 }
848
849 graphOutputs.resize(graphType.getNumResults());
850
851 // RAII guard to reset the insertion point to the module's region after
852 // deserializing the body of this function.
853 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
854
855 blockMap[graphID] = entryBlock;
856 if (failed(createGraphBlock(graphID))) {
857 return failure();
858 }
859
860 // Process all the instructions in the graph until and including
861 // OpGraphEndARM.
862 spirv::Opcode opcode;
863 ArrayRef<uint32_t> instOperands;
864 do {
865 if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) {
866 return failure();
867 }
868
869 if (failed(processInstruction(opcode, instOperands))) {
870 return failure();
871 }
872 } while (opcode != spirv::Opcode::OpGraphEndARM);
873
874 return success();
875}
876
877LogicalResult
879 if (operands.size() != 2) {
880 return emitError(
881 unknownLoc,
882 "expected value id and output index for OpGraphSetOutputARM");
883 }
884
885 uint32_t id = operands[0];
886 Value value = getValue(id);
887 if (!value) {
888 return emitError(unknownLoc, "could not find result <id> ") << id;
889 }
890
891 IntegerAttr outputIndexAttr = getConstantInt(operands[1]);
892 if (!outputIndexAttr) {
893 return emitError(unknownLoc,
894 "unable to read outputIndex value from constant op ")
895 << operands[1];
896 }
897 graphOutputs[outputIndexAttr.getInt()] = value;
898 return success();
899}
900
901LogicalResult
903 // Create GraphOutputsARM instruction.
904 spirv::GraphOutputsARMOp::create(opBuilder, unknownLoc, graphOutputs);
905
906 // Process OpGraphEndARM.
907 if (!operands.empty()) {
908 return emitError(unknownLoc, "unexpected operands for OpGraphEndARM");
909 }
910
911 curBlock = nullptr;
912 curGraph = std::nullopt;
913 graphOutputs.clear();
914
915 LLVM_DEBUG({
916 logger.unindent();
917 logger.startLine()
918 << "//===-------------------------------------------===//\n";
919 });
920 return success();
921}
922
923std::optional<std::pair<Attribute, Type>>
925 auto constIt = constantMap.find(id);
926 if (constIt == constantMap.end())
927 return std::nullopt;
928 return constIt->getSecond();
929}
930
931std::optional<std::pair<Attribute, Type>>
933 if (auto it = constantCompositeReplicateMap.find(id);
934 it != constantCompositeReplicateMap.end())
935 return it->second;
936 return std::nullopt;
937}
938
939std::optional<spirv::SpecConstOperationMaterializationInfo>
941 auto constIt = specConstOperationMap.find(id);
942 if (constIt == specConstOperationMap.end())
943 return std::nullopt;
944 return constIt->getSecond();
945}
946
948 auto funcName = nameMap.lookup(id).str();
949 if (funcName.empty()) {
950 funcName = "spirv_fn_" + std::to_string(id);
951 }
952 return funcName;
953}
954
955std::string spirv::Deserializer::getGraphSymbol(uint32_t id) {
956 std::string graphName = nameMap.lookup(id).str();
957 if (graphName.empty()) {
958 graphName = "spirv_graph_" + std::to_string(id);
959 }
960 return graphName;
961}
962
964 auto constName = nameMap.lookup(id).str();
965 if (constName.empty()) {
966 constName = "spirv_spec_const_" + std::to_string(id);
967 }
968 return constName;
969}
970
971spirv::SpecConstantOp
973 TypedAttr defaultValue) {
974 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
975 auto op = spirv::SpecConstantOp::create(opBuilder, unknownLoc, symName,
976 defaultValue);
977 if (decorations.count(resultID)) {
978 for (auto attr : decorations[resultID].getAttrs())
979 op->setAttr(attr.getName(), attr.getValue());
980 }
981 specConstMap[resultID] = op;
982 return op;
983}
984
985std::optional<spirv::GraphConstantARMOpMaterializationInfo>
987 auto graphConstIt = graphConstantMap.find(id);
988 if (graphConstIt == graphConstantMap.end())
989 return std::nullopt;
990 return graphConstIt->getSecond();
991}
992
993LogicalResult
995 unsigned wordIndex = 0;
996 if (operands.size() < 3) {
997 return emitError(
998 unknownLoc,
999 "OpVariable needs at least 3 operands, type, <id> and storage class");
1000 }
1001
1002 // Result Type.
1003 auto type = getType(operands[wordIndex]);
1004 if (!type) {
1005 return emitError(unknownLoc, "unknown result type <id> : ")
1006 << operands[wordIndex];
1007 }
1008 auto ptrType = dyn_cast<spirv::PointerType>(type);
1009 if (!ptrType) {
1010 return emitError(unknownLoc,
1011 "expected a result type <id> to be a spirv.ptr, found : ")
1012 << type;
1013 }
1014 wordIndex++;
1015
1016 // Result <id>.
1017 auto variableID = operands[wordIndex];
1018 auto variableName = nameMap.lookup(variableID).str();
1019 if (variableName.empty()) {
1020 variableName = "spirv_var_" + std::to_string(variableID);
1021 }
1022 wordIndex++;
1023
1024 // Storage class.
1025 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
1026 if (ptrType.getStorageClass() != storageClass) {
1027 return emitError(unknownLoc, "mismatch in storage class of pointer type ")
1028 << type << " and that specified in OpVariable instruction : "
1029 << stringifyStorageClass(storageClass);
1030 }
1031 wordIndex++;
1032
1033 // Initializer.
1034 FlatSymbolRefAttr initializer = nullptr;
1035
1036 if (wordIndex < operands.size()) {
1037 Operation *op = nullptr;
1038
1039 if (auto initOp = getGlobalVariable(operands[wordIndex]))
1040 op = initOp;
1041 else if (auto initOp = getSpecConstant(operands[wordIndex]))
1042 op = initOp;
1043 else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
1044 op = initOp;
1045 else
1046 return emitError(unknownLoc, "unknown <id> ")
1047 << operands[wordIndex] << "used as initializer";
1048
1049 initializer = SymbolRefAttr::get(op);
1050 wordIndex++;
1051 }
1052 if (wordIndex != operands.size()) {
1053 return emitError(unknownLoc,
1054 "found more operands than expected when deserializing "
1055 "OpVariable instruction, only ")
1056 << wordIndex << " of " << operands.size() << " processed";
1057 }
1058 auto loc = createFileLineColLoc(opBuilder);
1059 auto varOp = spirv::GlobalVariableOp::create(
1060 opBuilder, loc, TypeAttr::get(type),
1061 opBuilder.getStringAttr(variableName), initializer);
1062
1063 // Decorations.
1064 if (decorations.count(variableID)) {
1065 for (auto attr : decorations[variableID].getAttrs())
1066 varOp->setAttr(attr.getName(), attr.getValue());
1067 }
1068 globalVariableMap[variableID] = varOp;
1069 return success();
1070}
1071
1072IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
1073 auto constInfo = getConstant(id);
1074 if (!constInfo) {
1075 return nullptr;
1076 }
1077 return dyn_cast<IntegerAttr>(constInfo->first);
1078}
1079
1080LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
1081 if (operands.size() < 2) {
1082 return emitError(unknownLoc, "OpName needs at least 2 operands");
1083 }
1084
1085 unsigned wordIndex = 1;
1086 StringRef name = decodeStringLiteral(operands, wordIndex);
1087 if (wordIndex != operands.size()) {
1088 return emitError(unknownLoc,
1089 "unexpected trailing words in OpName instruction");
1090 }
1091
1092 // In SPIRV it's valid for multiple OpName instructions to refer to the same
1093 // <id>. Use a "last one wins" approach to resolve such cases.
1094 nameMap.emplace_or_assign(operands[0], name);
1095
1096 return success();
1097}
1098
1099//===----------------------------------------------------------------------===//
1100// Type
1101//===----------------------------------------------------------------------===//
1102
1103LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
1104 ArrayRef<uint32_t> operands) {
1105 if (operands.empty()) {
1106 return emitError(unknownLoc, "type instruction with opcode ")
1107 << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
1108 }
1109
1110 /// TODO: Types might be forward declared in some instructions and need to be
1111 /// handled appropriately.
1112 if (typeMap.count(operands[0])) {
1113 return emitError(unknownLoc, "duplicate definition for result <id> ")
1114 << operands[0];
1115 }
1116
1117 switch (opcode) {
1118 case spirv::Opcode::OpTypeVoid:
1119 if (operands.size() != 1)
1120 return emitError(unknownLoc, "OpTypeVoid must have no parameters");
1121 typeMap[operands[0]] = opBuilder.getNoneType();
1122 break;
1123 case spirv::Opcode::OpTypeBool:
1124 if (operands.size() != 1)
1125 return emitError(unknownLoc, "OpTypeBool must have no parameters");
1126 typeMap[operands[0]] = opBuilder.getI1Type();
1127 break;
1128 case spirv::Opcode::OpTypeInt: {
1129 if (operands.size() != 3)
1130 return emitError(
1131 unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
1132
1133 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
1134 // to preserve or validate.
1135 // 0 indicates unsigned, or no signedness semantics
1136 // 1 indicates signed semantics."
1137 //
1138 // So we cannot differentiate signless and unsigned integers; always use
1139 // signless semantics for such cases.
1140 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
1141 : IntegerType::SignednessSemantics::Signless;
1142 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
1143 } break;
1144 case spirv::Opcode::OpTypeFloat: {
1145 if (operands.size() != 2 && operands.size() != 3)
1146 return emitError(unknownLoc,
1147 "OpTypeFloat expects either 2 operands (type, bitwidth) "
1148 "or 3 operands (type, bitwidth, encoding), but got ")
1149 << operands.size();
1150 uint32_t bitWidth = operands[1];
1151
1152 Type floatTy;
1153 if (operands.size() == 2) {
1154 switch (bitWidth) {
1155 case 16:
1156 floatTy = opBuilder.getF16Type();
1157 break;
1158 case 32:
1159 floatTy = opBuilder.getF32Type();
1160 break;
1161 case 64:
1162 floatTy = opBuilder.getF64Type();
1163 break;
1164 default:
1165 return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
1166 << bitWidth;
1167 }
1168 }
1169
1170 if (operands.size() == 3) {
1171 if (spirv::FPEncoding(operands[2]) == spirv::FPEncoding::BFloat16KHR &&
1172 bitWidth == 16)
1173 floatTy = opBuilder.getBF16Type();
1174 else if (spirv::FPEncoding(operands[2]) ==
1175 spirv::FPEncoding::Float8E4M3EXT &&
1176 bitWidth == 8)
1177 floatTy = opBuilder.getF8E4M3FNType();
1178 else if (spirv::FPEncoding(operands[2]) ==
1179 spirv::FPEncoding::Float8E5M2EXT &&
1180 bitWidth == 8)
1181 floatTy = opBuilder.getF8E5M2Type();
1182 else
1183 return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
1184 << operands[2] << " and bitWidth " << bitWidth;
1185 }
1186
1187 typeMap[operands[0]] = floatTy;
1188 } break;
1189 case spirv::Opcode::OpTypeVector: {
1190 if (operands.size() != 3) {
1191 return emitError(
1192 unknownLoc,
1193 "OpTypeVector must have element type and count parameters");
1194 }
1195 Type elementTy = getType(operands[1]);
1196 if (!elementTy) {
1197 return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
1198 << operands[1];
1199 }
1200 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
1201 } break;
1202 case spirv::Opcode::OpTypePointer: {
1203 return processOpTypePointer(operands);
1204 } break;
1205 case spirv::Opcode::OpTypeArray:
1206 return processArrayType(operands);
1207 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
1208 return processCooperativeMatrixTypeKHR(operands);
1209 case spirv::Opcode::OpTypeFunction:
1210 return processFunctionType(operands);
1211 case spirv::Opcode::OpTypeImage:
1212 return processImageType(operands);
1213 case spirv::Opcode::OpTypeSampler:
1214 return processSamplerType(operands);
1215 case spirv::Opcode::OpTypeNamedBarrier:
1216 return processNamedBarrierType(operands);
1217 case spirv::Opcode::OpTypeSampledImage:
1218 return processSampledImageType(operands);
1219 case spirv::Opcode::OpTypeRuntimeArray:
1220 return processRuntimeArrayType(operands);
1221 case spirv::Opcode::OpTypeStruct:
1222 return processStructType(operands);
1223 case spirv::Opcode::OpTypeMatrix:
1224 return processMatrixType(operands);
1225 case spirv::Opcode::OpTypeTensorARM:
1226 return processTensorARMType(operands);
1227 case spirv::Opcode::OpTypeGraphARM:
1228 return processGraphTypeARM(operands);
1229 default:
1230 return emitError(unknownLoc, "unhandled type instruction");
1231 }
1232 return success();
1233}
1234
/// Handles OpTypePointer: operands are the result <id>, the storage class and
/// the pointee type <id>. The new spirv::PointerType is registered in typeMap.
/// Then deferredStructTypesInfos is scanned: every unresolved struct member
/// that referred to this <id> receives the new type. Each deferred struct
/// whose members are now all resolved gets its body set and is removed from
/// the deferred list. Fails if setting such a struct body fails.
LogicalResult
1237 if (operands.size() != 3)
1238 return emitError(unknownLoc, "OpTypePointer must have two parameters");
1239
1240 auto pointeeType = getType(operands[2]);
1241 if (!pointeeType)
1242 return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
1243 << operands[2];
1244
1245 uint32_t typePointerID = operands[0];
// NOTE(review): the storage class word is cast to spirv::StorageClass without
// checking that it is a known enumerant.
1246 auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
1247 typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
1248
// Both loops below erase in place. erase() returns the iterator to the next
// element, so an iterator is only advanced manually when nothing was erased.
1249 for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
1250 deferredStructIt != std::end(deferredStructTypesInfos);) {
1251 for (auto *unresolvedMemberIt =
1252 std::begin(deferredStructIt->unresolvedMemberTypes);
1253 unresolvedMemberIt !=
1254 std::end(deferredStructIt->unresolvedMemberTypes);) {
1255 if (unresolvedMemberIt->first == typePointerID) {
1256 // The newly constructed pointer type can resolve one of the
1257 // deferred struct type members; update the memberTypes list and
1258 // clean the unresolvedMemberTypes list accordingly.
1259 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
1260 typeMap[typePointerID];
1261 unresolvedMemberIt =
1262 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
1263 } else {
1264 ++unresolvedMemberIt;
1265 }
1266 }
1267
1268 if (deferredStructIt->unresolvedMemberTypes.empty()) {
1269 // All deferred struct type members are now resolved, set the struct body.
1270 auto structType = deferredStructIt->deferredStructType;
1271
1272 assert(structType && "expected a spirv::StructType");
1273 assert(structType.isIdentified() && "expected an indentified struct");
1274
1275 if (failed(structType.trySetBody(
1276 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
1277 deferredStructIt->memberDecorationsInfo,
1278 deferredStructIt->structDecorationsInfo)))
1279 return failure();
1280
1281 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
1282 } else {
1283 ++deferredStructIt;
1284 }
1285 }
1286
1287 return success();
1288}
1289
1290LogicalResult
1292 if (operands.size() != 3) {
1293 return emitError(unknownLoc,
1294 "OpTypeArray must have element type and count parameters");
1295 }
1296
1297 Type elementTy = getType(operands[1]);
1298 if (!elementTy) {
1299 return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
1300 << operands[1];
1301 }
1302
1303 unsigned count = 0;
1304 // TODO: The count can also come frome a specialization constant.
1305 auto countInfo = getConstant(operands[2]);
1306 if (!countInfo) {
1307 return emitError(unknownLoc, "OpTypeArray count <id> ")
1308 << operands[2] << "can only come from normal constant right now";
1309 }
1310
1311 if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1312 count = intVal.getValue().getZExtValue();
1313 } else {
1314 return emitError(unknownLoc, "OpTypeArray count must come from a "
1315 "scalar integer constant instruction");
1316 }
1317
1318 typeMap[operands[0]] = spirv::ArrayType::get(
1319 elementTy, count, typeDecorations.lookup(operands[0]));
1320 return success();
1321}
1322
1323LogicalResult
1325 assert(!operands.empty() && "No operands for processing function type");
1326 if (operands.size() == 1) {
1327 return emitError(unknownLoc, "missing return type for OpTypeFunction");
1328 }
1329 auto returnType = getType(operands[1]);
1330 if (!returnType) {
1331 return emitError(unknownLoc, "unknown return type in OpTypeFunction");
1332 }
1333 SmallVector<Type, 1> argTypes;
1334 for (size_t i = 2, e = operands.size(); i < e; ++i) {
1335 auto ty = getType(operands[i]);
1336 if (!ty) {
1337 return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
1338 }
1339 argTypes.push_back(ty);
1340 }
1341 ArrayRef<Type> returnTypes;
1342 if (!isVoidType(returnType)) {
1343 returnTypes = llvm::ArrayRef(returnType);
1344 }
1345 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1346 return success();
1347}
1348
1350 ArrayRef<uint32_t> operands) {
1351 if (operands.size() != 6) {
1352 return emitError(unknownLoc,
1353 "OpTypeCooperativeMatrixKHR must have element type, "
1354 "scope, row and column parameters, and use");
1355 }
1356
1357 Type elementTy = getType(operands[1]);
1358 if (!elementTy) {
1359 return emitError(unknownLoc,
1360 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1361 << operands[1];
1362 }
1363
1364 std::optional<spirv::Scope> scope =
1365 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1366 if (!scope) {
1367 return emitError(
1368 unknownLoc,
1369 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1370 << operands[2];
1371 }
1372
1373 IntegerAttr rowsAttr = getConstantInt(operands[3]);
1374 IntegerAttr columnsAttr = getConstantInt(operands[4]);
1375 IntegerAttr useAttr = getConstantInt(operands[5]);
1376
1377 if (!rowsAttr)
1378 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Rows` references "
1379 "undefined constant <id> ")
1380 << operands[3];
1381
1382 if (!columnsAttr)
1383 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Columns` "
1384 "references undefined constant <id> ")
1385 << operands[4];
1386
1387 if (!useAttr)
1388 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Use` references "
1389 "undefined constant <id> ")
1390 << operands[5];
1391
1392 unsigned rows = rowsAttr.getInt();
1393 unsigned columns = columnsAttr.getInt();
1394
1395 std::optional<spirv::CooperativeMatrixUseKHR> use =
1396 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1397 if (!use) {
1398 return emitError(
1399 unknownLoc,
1400 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1401 << operands[5];
1402 }
1403
1404 typeMap[operands[0]] =
1405 spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1406 return success();
1407}
1408
1409LogicalResult
1411 if (operands.size() != 2) {
1412 return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
1413 }
1414 Type memberType = getType(operands[1]);
1415 if (!memberType) {
1416 return emitError(unknownLoc,
1417 "OpTypeRuntimeArray references undefined <id> ")
1418 << operands[1];
1419 }
1420 typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1421 memberType, typeDecorations.lookup(operands[0]));
1422 return success();
1423}
1424
/// Handles OpTypeStruct: operands are the result <id> followed by the member
/// type <id>s. An undefined member type is accepted only when its <id> was
/// declared by OpTypeForwardPointer; such members are recorded as unresolved
/// (null) and filled in once the pointer type is processed. Offset member
/// decorations populate offsetInfo; all other member decorations, and the
/// struct's own decorations, are collected as decoration info. Structs named
/// by OpName become identified structs; unnamed ones are built with
/// spirv::StructType::get.
LogicalResult
1427 // TODO: Find a way to handle identified structs when debug info is stripped.
1428
1429 if (operands.empty()) {
1430 return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
1431 }
1432
1433 if (operands.size() == 1) {
1434 // Handle empty struct.
1435 typeMap[operands[0]] =
1436 spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
1437 return success();
1438 }
1439
1440 // First element is operand ID, second element is member index in the struct.
1441 SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1442 SmallVector<Type, 4> memberTypes;
1443
1444 for (auto op : llvm::drop_begin(operands, 1)) {
1445 Type memberType = getType(op);
1446 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1447
1448 if (!memberType && !typeForwardPtr)
1449 return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
1450 << op;
1451
// A forward-declared pointer member is kept as a null placeholder and its
// position remembered so it can be patched when OpTypePointer arrives.
1452 if (!memberType)
1453 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1454
1455 memberTypes.push_back(memberType);
1456 }
1457
// The declarations of offsetInfo and memberDecorationsInfo are on lines that
// are not shown in this listing.
1460 if (memberDecorationMap.count(operands[0])) {
1461 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1462 for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1463 if (allMemberDecorations.count(memberIndex)) {
1464 for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1465 // Check for offset.
1466 if (memberDecoration.first == spirv::Decoration::Offset) {
1467 // If offset info is empty, resize to the number of members;
1468 if (offsetInfo.empty()) {
1469 offsetInfo.resize(memberTypes.size());
1470 }
// NOTE(review): assumes every Offset decoration carries at least one operand.
1471 offsetInfo[memberIndex] = memberDecoration.second[0];
1472 } else {
// Non-offset member decorations keep only their first operand (as an i32
// attribute) or become a UnitAttr when they have no operands.
1473 auto intType = mlir::IntegerType::get(context, 32);
1474 if (!memberDecoration.second.empty()) {
1475 memberDecorationsInfo.emplace_back(
1476 memberIndex, memberDecoration.first,
1477 IntegerAttr::get(intType, memberDecoration.second[0]));
1478 } else {
1479 memberDecorationsInfo.emplace_back(
1480 memberIndex, memberDecoration.first, UnitAttr::get(context));
1481 }
1482 }
1483 }
1484 }
1485 }
1486 }
1487
// The declaration of structDecorationsInfo is on a line that is not shown in
// this listing. Struct-level decorations are stored as snake_case attribute
// names and converted back to the Decoration enum here.
1489 if (decorations.count(operands[0])) {
1490 NamedAttrList &allDecorations = decorations[operands[0]];
1491 for (NamedAttribute &decorationAttr : allDecorations) {
1492 std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
1493 llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true));
1494 assert(decoration.has_value());
1495 structDecorationsInfo.emplace_back(decoration.value(),
1496 decorationAttr.getValue());
1497 }
1498 }
1499
1500 uint32_t structID = operands[0];
1501 std::string structIdentifier = nameMap.lookup(structID).str();
1502
// NOTE(review): an unnamed struct with a forward-pointer member is rejected
// only by the assert below (see the TODO at the top of this function).
1503 if (structIdentifier.empty()) {
1504 assert(unresolvedMemberTypes.empty() &&
1505 "didn't expect unresolved member types");
1506 typeMap[structID] = spirv::StructType::get(
1507 memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
1508 } else {
// Identified structs are registered immediately; their body is set now, or
// deferred until every forward-declared member pointer has been defined.
1509 auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1510 typeMap[structID] = structTy;
1511
1512 if (!unresolvedMemberTypes.empty())
1513 deferredStructTypesInfos.push_back(
1514 {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
1515 memberDecorationsInfo, structDecorationsInfo});
1516 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1517 memberDecorationsInfo,
1518 structDecorationsInfo)))
1519 return failure();
1520 }
1521
1522 // TODO: Update StructType to have member name as attribute as
1523 // well.
1524 return success();
1525}
1526
1527LogicalResult
1529 if (operands.size() != 3) {
1530 // Three operands are needed: result_id, column_type, and column_count
1531 return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1532 " (result_id, column_type, and column_count)");
1533 }
1534 // Matrix columns must be of vector type
1535 Type elementTy = getType(operands[1]);
1536 if (!elementTy) {
1537 return emitError(unknownLoc,
1538 "OpTypeMatrix references undefined column type.")
1539 << operands[1];
1540 }
1541
1542 uint32_t colsCount = operands[2];
1543 typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1544 return success();
1545}
1546
1547LogicalResult
1549 unsigned size = operands.size();
1550 if (size < 2 || size > 4)
1551 return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
1552 "(result_id, element_type, (rank), (shape)) ")
1553 << size;
1554
1555 Type elementTy = getType(operands[1]);
1556 if (!elementTy)
1557 return emitError(unknownLoc,
1558 "OpTypeTensorARM references undefined element type ")
1559 << operands[1];
1560
1561 if (size == 2) {
1562 typeMap[operands[0]] = TensorArmType::get({}, elementTy);
1563 return success();
1564 }
1565
1566 IntegerAttr rankAttr = getConstantInt(operands[2]);
1567 if (!rankAttr)
1568 return emitError(unknownLoc, "OpTypeTensorARM rank must come from a "
1569 "scalar integer constant instruction");
1570 unsigned rank = rankAttr.getValue().getZExtValue();
1571 if (size == 3) {
1572 SmallVector<int64_t, 4> shape(rank, ShapedType::kDynamic);
1573 typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
1574 return success();
1575 }
1576
1577 std::optional<std::pair<Attribute, Type>> shapeInfo =
1578 getConstant(operands[3]);
1579 if (!shapeInfo)
1580 return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
1581 "constant instruction of type OpTypeArray");
1582
1583 ArrayAttr shapeArrayAttr = dyn_cast<ArrayAttr>(shapeInfo->first);
1585 for (auto dimAttr : shapeArrayAttr.getValue()) {
1586 auto dimIntAttr = dyn_cast<IntegerAttr>(dimAttr);
1587 if (!dimIntAttr)
1588 return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
1589 "dimension size");
1590 shape.push_back(dimIntAttr.getValue().getSExtValue());
1591 }
1592 typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
1593 return success();
1594}
1595
1596LogicalResult
1598 unsigned size = operands.size();
1599 if (size < 2) {
1600 return emitError(unknownLoc, "OpTypeGraphARM must have at least 2 operands "
1601 "(result_id, num_inputs, (inout0_type, "
1602 "inout1_type, ...))")
1603 << size;
1604 }
1605 uint32_t numInputs = operands[1];
1606 SmallVector<Type, 1> argTypes;
1607 SmallVector<Type, 1> returnTypes;
1608 for (unsigned i = 2; i < size; ++i) {
1609 Type inOutTy = getType(operands[i]);
1610 if (!inOutTy) {
1611 return emitError(unknownLoc,
1612 "OpTypeGraphARM references undefined element type.")
1613 << operands[i];
1614 }
1615 if (i - 2 >= numInputs) {
1616 returnTypes.push_back(inOutTy);
1617 } else {
1618 argTypes.push_back(inOutTy);
1619 }
1620 }
1621 typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
1622 return success();
1623}
1624
1625LogicalResult
1627 if (operands.size() != 2)
1628 return emitError(unknownLoc,
1629 "OpTypeForwardPointer instruction must have two operands");
1630
1631 typeForwardPointerIDs.insert(operands[0]);
1632 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1633 // instruction that defines the actual type.
1634
1635 return success();
1636}
1637
1638LogicalResult
1640 // TODO: Add support for Access Qualifier.
1641 if (operands.size() != 8)
1642 return emitError(
1643 unknownLoc,
1644 "OpTypeImage with non-eight operands are not supported yet");
1645
1646 Type elementTy = getType(operands[1]);
1647 if (!elementTy)
1648 return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1649 << operands[1];
1650
1651 auto dim = spirv::symbolizeDim(operands[2]);
1652 if (!dim)
1653 return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1654 << operands[2];
1655
1656 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1657 if (!depthInfo)
1658 return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1659 << operands[3];
1660
1661 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1662 if (!arrayedInfo)
1663 return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1664 << operands[4];
1665
1666 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1667 if (!samplingInfo)
1668 return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1669
1670 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1671 if (!samplerUseInfo)
1672 return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1673 << operands[6];
1674
1675 auto format = spirv::symbolizeImageFormat(operands[7]);
1676 if (!format)
1677 return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1678 << operands[7];
1679
1680 typeMap[operands[0]] = spirv::ImageType::get(
1681 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1682 samplingInfo.value(), samplerUseInfo.value(), format.value());
1683 return success();
1684}
1685
1686LogicalResult
1688 if (operands.size() != 2)
1689 return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1690
1691 Type elementTy = getType(operands[1]);
1692 if (!elementTy)
1693 return emitError(unknownLoc,
1694 "OpTypeSampledImage references undefined <id>: ")
1695 << operands[1];
1696
1697 typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1698 return success();
1699}
1700
1701LogicalResult
1703 if (operands.size() != 1)
1704 return emitError(unknownLoc, "OpTypeSampler must have no parameters");
1705
1706 typeMap[operands[0]] = spirv::SamplerType::get(context);
1707 return success();
1708}
1709
1710LogicalResult
1712 if (operands.size() != 1)
1713 return emitError(unknownLoc, "OpTypeNamedBarrier must have no parameters");
1714
1715 typeMap[operands[0]] = spirv::NamedBarrierType::get(context);
1716 return success();
1717}
1718
1719//===----------------------------------------------------------------------===//
1720// Constant
1721//===----------------------------------------------------------------------===//
1722
1724 bool isSpec) {
1725 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1726
1727 if (operands.size() < 2) {
1728 return emitError(unknownLoc)
1729 << opname << " must have type <id> and result <id>";
1730 }
1731 if (operands.size() < 3) {
1732 return emitError(unknownLoc)
1733 << opname << " must have at least 1 more parameter";
1734 }
1735
1736 Type resultType = getType(operands[0]);
1737 if (!resultType) {
1738 return emitError(unknownLoc, "undefined result type from <id> ")
1739 << operands[0];
1740 }
1741
1742 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1743 if (bitwidth == 64) {
1744 if (operands.size() == 4) {
1745 return success();
1746 }
1747 return emitError(unknownLoc)
1748 << opname << " should have 2 parameters for 64-bit values";
1749 }
1750 if (bitwidth <= 32) {
1751 if (operands.size() == 3) {
1752 return success();
1753 }
1754
1755 return emitError(unknownLoc)
1756 << opname
1757 << " should have 1 parameter for values with no more than 32 bits";
1758 }
1759 return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1760 << bitwidth;
1761 };
1762
1763 auto resultID = operands[1];
1764
1765 if (auto intType = dyn_cast<IntegerType>(resultType)) {
1766 auto bitwidth = intType.getWidth();
1767 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1768 return failure();
1769 }
1770
1771 APInt value;
1772 if (bitwidth == 64) {
1773 // 64-bit integers are represented with two SPIR-V words. According to
1774 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1775 // literal’s low-order words appear first."
1776 struct DoubleWord {
1777 uint32_t word1;
1778 uint32_t word2;
1779 } words = {operands[2], operands[3]};
1780 value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1781 } else if (bitwidth <= 32) {
1782 value = APInt(bitwidth, operands[2], /*isSigned=*/true,
1783 /*implicitTrunc=*/true);
1784 }
1785
1786 auto attr = opBuilder.getIntegerAttr(intType, value);
1787
1788 if (isSpec) {
1789 createSpecConstant(unknownLoc, resultID, attr);
1790 } else {
1791 // For normal constants, we just record the attribute (and its type) for
1792 // later materialization at use sites.
1793 constantMap.try_emplace(resultID, attr, intType);
1794 }
1795
1796 return success();
1797 }
1798
1799 if (auto floatType = dyn_cast<FloatType>(resultType)) {
1800 auto bitwidth = floatType.getWidth();
1801 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1802 return failure();
1803 }
1804
1805 APFloat value(0.f);
1806 if (floatType.isF64()) {
1807 // Double values are represented with two SPIR-V words. According to
1808 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1809 // literal’s low-order words appear first."
1810 struct DoubleWord {
1811 uint32_t word1;
1812 uint32_t word2;
1813 } words = {operands[2], operands[3]};
1814 value = APFloat(llvm::bit_cast<double>(words));
1815 } else if (floatType.isF32()) {
1816 value = APFloat(llvm::bit_cast<float>(operands[2]));
1817 } else if (floatType.isF16()) {
1818 APInt data(16, operands[2]);
1819 value = APFloat(APFloat::IEEEhalf(), data);
1820 } else if (floatType.isBF16()) {
1821 APInt data(16, operands[2]);
1822 value = APFloat(APFloat::BFloat(), data);
1823 } else if (floatType.isF8E4M3FN()) {
1824 APInt data(8, operands[2]);
1825 value = APFloat(APFloat::Float8E4M3FN(), data);
1826 } else if (floatType.isF8E5M2()) {
1827 APInt data(8, operands[2]);
1828 value = APFloat(APFloat::Float8E5M2(), data);
1829 }
1830
1831 auto attr = opBuilder.getFloatAttr(floatType, value);
1832 if (isSpec) {
1833 createSpecConstant(unknownLoc, resultID, attr);
1834 } else {
1835 // For normal constants, we just record the attribute (and its type) for
1836 // later materialization at use sites.
1837 constantMap.try_emplace(resultID, attr, floatType);
1838 }
1839
1840 return success();
1841 }
1842
1843 return emitError(unknownLoc, "OpConstant can only generate values of "
1844 "scalar integer or floating-point type");
1845}
1846
1848 bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1849 if (operands.size() != 2) {
1850 return emitError(unknownLoc, "Op")
1851 << (isSpec ? "Spec" : "") << "Constant"
1852 << (isTrue ? "True" : "False")
1853 << " must have type <id> and result <id>";
1854 }
1855
1856 auto attr = opBuilder.getBoolAttr(isTrue);
1857 auto resultID = operands[1];
1858 if (isSpec) {
1859 createSpecConstant(unknownLoc, resultID, attr);
1860 } else {
1861 // For normal constants, we just record the attribute (and its type) for
1862 // later materialization at use sites.
1863 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1864 }
1865
1866 return success();
1867}
1868
/// Deserializes an OpConstantComposite instruction. `operands` holds the
/// result type <id>, the result <id>, and one or more constituent <id>s, each
/// of which must refer to an already-deserialized normal constant. No op is
/// created here: the composite attribute and its type are recorded in
/// `constantMap` so they can be materialized later at use sites. Emits an
/// error for malformed operands, an unknown result type or constituent, or a
/// result type that is not a TensorArm, shaped, or spirv.array type.
1869LogicalResult
1871 if (operands.size() < 2) {
1872 return emitError(unknownLoc,
1873 "OpConstantComposite must have type <id> and result <id>");
1874 }
1875 if (operands.size() < 3) {
1876 return emitError(unknownLoc,
1877 "OpConstantComposite must have at least 1 parameter");
1878 }
1879
1880 Type resultType = getType(operands[0]);
1881 if (!resultType) {
1882 return emitError(unknownLoc, "undefined result type from <id> ")
1883 << operands[0];
1884 }
1885
// Collect the folded attribute of every constituent, in operand order.
1887 elements.reserve(operands.size() - 2);
1888 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1889 auto elementInfo = getConstant(operands[i]);
1890 if (!elementInfo) {
1891 return emitError(unknownLoc, "OpConstantComposite component <id> ")
1892 << operands[i] << " must come from a normal constant";
1893 }
1894 elements.push_back(elementInfo->first);
1895 }
1896
1897 auto resultID = operands[1];
// TensorArm results: constituents may themselves be dense attributes
// (sub-tensors), so they are flattened into one element list before a single
// DenseElementsAttr is built for the whole tensor. This case is tested before
// the generic ShapedType case below.
1898 if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1899 SmallVector<Attribute> flattenedElems;
1900 for (Attribute element : elements) {
1901 if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
1902 for (auto value : denseElemAttr.getValues<Attribute>())
1903 flattenedElems.push_back(value);
1904 } else {
1905 flattenedElems.push_back(element);
1906 }
1907 }
1908 auto attr = DenseElementsAttr::get(tensorType, flattenedElems);
1909 constantMap.try_emplace(resultID, attr, tensorType);
1910 } else if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
1911 auto attr = DenseElementsAttr::get(shapedType, elements);
1912 // For normal constants, we just record the attribute (and its type) for
1913 // later materialization at use sites.
1914 constantMap.try_emplace(resultID, attr, shapedType);
1915 } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
// spirv.array constants are represented as a plain ArrayAttr of the
// constituent attributes.
1916 auto attr = opBuilder.getArrayAttr(elements);
1917 constantMap.try_emplace(resultID, attr, resultType);
1918 } else {
1919 return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1920 << resultType;
1921 }
1922
1923 return success();
1924}
1925
1927 ArrayRef<uint32_t> operands) {
// Deserializes an OpConstantCompositeReplicateEXT instruction. `operands`
// must be exactly: result type <id>, result <id>, and the <id> of the value
// replicated into every element of the (composite) result type. Only
// bookkeeping happens here: the replicated attribute and the result type are
// recorded in `constantCompositeReplicateMap`.
1928 if (operands.size() != 3) {
1929 return emitError(
1930 unknownLoc,
1931 "OpConstantCompositeReplicateEXT expects 3 operands but found ")
1932 << operands.size();
1933 }
1934
1935 Type resultType = getType(operands[0]);
1936 if (!resultType) {
1937 return emitError(unknownLoc, "undefined result type from <id> ")
1938 << operands[0];
1939 }
1940
1941 auto compositeType = dyn_cast<CompositeType>(resultType);
1942 if (!compositeType) {
// NOTE(review): there is no separator before the streamed <id>, so the
// message renders as "...not a composite type<N>"; consider adding ": ".
1943 return emitError(unknownLoc,
1944 "result type from <id> is not a composite type")
1945 << operands[0];
1946 }
1947
1948 uint32_t resultID = operands[1];
1949 uint32_t constantID = operands[2];
1950
// The replicated value may be a normal constant...
1951 std::optional<std::pair<Attribute, Type>> constantInfo =
1952 getConstant(constantID);
1953 if (constantInfo.has_value()) {
1954 constantCompositeReplicateMap.try_emplace(
1955 resultID, constantInfo.value().first, resultType);
1956 return success();
1957 }
1958
// ...or the result of another OpConstantCompositeReplicateEXT, which allows
// nested replication (e.g. replicating a replicated vector).
1959 std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
1961 if (replicatedConstantCompositeInfo.has_value()) {
1962 constantCompositeReplicateMap.try_emplace(
1963 resultID, replicatedConstantCompositeInfo.value().first, resultType);
1964 return success();
1965 }
1966
1967 return emitError(unknownLoc, "OpConstantCompositeReplicateEXT operand <id> ")
1968 << constantID
1969 << " must come from a normal constant or a "
1970 "OpConstantCompositeReplicateEXT";
1971}
1972
/// Deserializes an OpSpecConstantComposite instruction into a
/// spirv::SpecConstantCompositeOp. `operands` holds the result type <id>, the
/// result <id>, and one or more constituent <id>s. Each constituent is
/// referenced by symbol rather than by value. The created op is recorded in
/// `specConstCompositeMap` under the result <id>.
1973LogicalResult
1975 if (operands.size() < 2) {
1976 return emitError(
1977 unknownLoc,
1978 "OpSpecConstantComposite must have type <id> and result <id>");
1979 }
1980 if (operands.size() < 3) {
1981 return emitError(unknownLoc,
1982 "OpSpecConstantComposite must have at least 1 parameter");
1983 }
1984
1985 Type resultType = getType(operands[0]);
1986 if (!resultType) {
1987 return emitError(unknownLoc, "undefined result type from <id> ")
1988 << operands[0];
1989 }
1990
1991 auto resultID = operands[1];
1992 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1993
1995 elements.reserve(operands.size() - 2);
1996 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
// NOTE(review): the constituent is looked up only through getSpecConstant(),
// and the result is handed to SymbolRefAttr::get() unchecked. If
// getSpecConstant() yields a null op for an <id> that is not a scalar spec
// constant (e.g. a nested spec constant composite or a normal constant),
// this would dereference null instead of emitting an error. Confirm, then
// validate and emitError here.
1997 auto elementInfo = getSpecConstant(operands[i]);
1998 elements.push_back(SymbolRefAttr::get(elementInfo));
1999 }
2000
2001 auto op = spirv::SpecConstantCompositeOp::create(
2002 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
2003 opBuilder.getArrayAttr(elements));
2004 specConstCompositeMap[resultID] = op;
2005
2006 return success();
2007}
2008
2010 ArrayRef<uint32_t> operands) {
// Deserializes an OpSpecConstantCompositeReplicateEXT instruction into a
// spirv::EXTSpecConstantCompositeReplicateOp. Its single constituent (the spec
// constant named by operands[2]) is referenced by symbol. `operands` must be
// exactly: result type <id>, result <id>, constituent <id>. The created op is
// recorded in `specConstCompositeReplicateMap` under the result <id>.
2011 if (operands.size() != 3) {
2012 return emitError(unknownLoc, "OpSpecConstantCompositeReplicateEXT expects "
2013 "3 operands but found ")
2014 << operands.size();
2015 }
2016
2017 Type resultType = getType(operands[0]);
2018 if (!resultType) {
2019 return emitError(unknownLoc, "undefined result type from <id> ")
2020 << operands[0];
2021 }
2022
2023 auto compositeType = dyn_cast<CompositeType>(resultType);
2024 if (!compositeType) {
2025 return emitError(unknownLoc,
2026 "result type from <id> is not a composite type")
2027 << operands[0];
2028 }
2029
2030 uint32_t resultID = operands[1];
2031
2032 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
// NOTE(review): unlike the non-spec replicate handler, nested replicated
// composites are not looked up here. The SpecConstantOp returned by
// getSpecConstant() is also passed to SymbolRefAttr::get() without a null
// check. Confirm whether an unknown <id> can reach this point.
2033 spirv::SpecConstantOp constituentSpecConstantOp =
2034 getSpecConstant(operands[2]);
2035 auto op = spirv::EXTSpecConstantCompositeReplicateOp::create(
2036 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
2037 SymbolRefAttr::get(constituentSpecConstantOp));
2038
2039 specConstCompositeReplicateMap[resultID] = op;
2040
2041 return success();
2042}
2043
/// Records an OpSpecConstantOp instruction for deferred materialization.
/// `operands` holds the result type <id>, the result <id>, the opcode of the
/// enclosed instruction, and that instruction's remaining operands. Nothing is
/// emitted here. The opcode, the result type <id>, and the enclosed operands
/// are stored in `specConstOperationMap` under the result <id>. This fails if
/// that <id> already has an entry.
/// NOTE(review): the arity error message says "OpConstantOperation"; consider
/// naming the SPIR-V instruction (OpSpecConstantOp) instead.
2044LogicalResult
2046 if (operands.size() < 3)
2047 return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
2048 "result <id>, and operand opcode");
2049
2050 uint32_t resultTypeID = operands[0];
2051
2052 if (!getType(resultTypeID))
2053 return emitError(unknownLoc, "undefined result type from <id> ")
2054 << resultTypeID;
2055
2056 uint32_t resultID = operands[1];
2057 spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
// Everything after the opcode belongs to the enclosed instruction; with
// exactly 3 operands that list is empty.
2058 auto emplaceResult = specConstOperationMap.try_emplace(
2059 resultID,
2061 enclosedOpcode, resultTypeID,
2062 SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
2063
2064 if (!emplaceResult.second)
2065 return emitError(unknownLoc, "value with <id>: ")
2066 << resultID << " is probably defined before.";
2067
2068 return success();
2069}
2070
2072 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
2073 ArrayRef<uint32_t> enclosedOpOperands) {
// Materializes a deferred OpSpecConstantOp in four steps:
//  1. Process the enclosed instruction with a fake result <id>.
//  2. Move the op it produced into the body of a new
//     spirv::SpecConstantOperationOp.
//  3. Terminate that body with a spirv::YieldOp of the moved op's result.
//  4. Return the SpecConstantOperationOp's result.
// If processing the enclosed instruction fails, a null Value is returned
// instead.
// NOTE(review): `resultID` is not used in this body, and the type returned
// by getType(resultTypeID) is not null-checked here. Confirm that callers
// have validated it.
2074
2075 Type resultType = getType(resultTypeID);
2076
2077 // Instructions wrapped by OpSpecConstantOp need an ID for their
2078 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
2079 // dialect wrapped op. For that purpose, a new value map is created and "fake"
2080 // ID in that map is assigned to the result of the enclosed instruction. Note
2081 // that there is no need to update this fake ID since we only need to
2082 // reference the created Value for the enclosed op from the spirv::YieldOp
2083 // created later in this method (both of which are the only values in their
2084 // region: the SpecConstantOperation's region). If we encounter another
2085 // SpecConstantOperation in the module, we simply re-use the fake ID since the
2086 // previous Value assigned to it isn't visible in the current scope anyway.
2087 DenseMap<uint32_t, Value> newValueMap;
2088 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
2089 constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
2090
2091 SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
2092 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
2093 enclosedOpResultTypeAndOperands.push_back(fakeID);
2094 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
2095 enclosedOpOperands.end());
2096
2097 // Process enclosed instruction before creating the enclosing
2098 // specConstantOperation (and its region). This way, references to constants,
2099 // global variables, and spec constants will be materialized outside the new
2100 // op's region. For more info, see Deserializer::getValue's implementation.
2101 if (failed(
2102 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
2103 return Value();
2104
2105 // Since the enclosed op is emitted in the current block, split it in a
2106 // separate new block.
// This relies on the enclosed instruction having appended exactly one op at
// the end of curBlock: only that last op is split off.
2107 Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
2108
2109 auto loc = createFileLineColLoc(opBuilder);
2110 auto specConstOperationOp =
2111 spirv::SpecConstantOperationOp::create(opBuilder, loc, resultType);
2112
2113 Region &body = specConstOperationOp.getBody();
2114 // Move the new block into SpecConstantOperation's body.
2115 body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
2116 Region::iterator(enclosedBlock));
2117 Block &block = body.back();
2118
2119 // RAII guard to reset the insertion point to the module's region after
2120 // deserializing the body of the specConstantOperation.
2121 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
2122 opBuilder.setInsertionPointToEnd(&block);
2123
// The moved op is the first op of the body; yield its first result.
2124 spirv::YieldOp::create(opBuilder, loc, block.front().getResult(0));
2125 return specConstOperationOp.getResult();
2126}
2127
/// Deserializes an OpConstantNull instruction. `operands` must be exactly the
/// result type <id> and the result <id>. A zero attribute is recorded in
/// `constantMap` for three kinds of result type: integer/float scalars,
/// vectors, and TensorArm types (as a splat of the element type's zero). Any
/// other result type is reported as an error.
2128LogicalResult
2130 if (operands.size() != 2) {
2131 return emitError(unknownLoc,
2132 "OpConstantNull must only have type <id> and result <id>");
2133 }
2134
2135 Type resultType = getType(operands[0]);
2136 if (!resultType) {
2137 return emitError(unknownLoc, "undefined result type from <id> ")
2138 << operands[0];
2139 }
2140
2141 auto resultID = operands[1];
2142 Attribute attr;
2143 if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
2144 attr = opBuilder.getZeroAttr(resultType);
2145 } else if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
// Splat the element type's zero across the tensor. If getZeroAttr returns
// null, `attr` stays unset and the error below is reported.
2146 if (auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
2147 attr = DenseElementsAttr::get(tensorType, element);
2148 }
2149
2150 if (attr) {
2151 // For normal constants, we just record the attribute (and its type) for
2152 // later materialization at use sites.
2153 constantMap.try_emplace(resultID, attr, resultType);
2154 return success();
2155 }
2156
2157 return emitError(unknownLoc, "unsupported OpConstantNull type: ")
2158 << resultType;
2159}
2160
/// Deserializes an OpGraphConstantARM instruction. Its operands are the result
/// type <id>, the result <id>, and a graph constant id literal. The result type
/// must be a TensorArm type. The literal is stored as an i32 IntegerAttr,
/// together with the result type, in `graphConstantMap` for later
/// materialization.
/// NOTE(review): the operand check requires at least 3 operands, but the error
/// message says "at least 2". Also, `dyn_cast` is used purely as a boolean
/// test where `isa` would state the intent.
2161LogicalResult
2163 if (operands.size() < 3) {
2164 return emitError(unknownLoc)
2165 << "OpGraphConstantARM must have at least 2 operands";
2166 }
2167
2168 Type resultType = getType(operands[0]);
2169 if (!resultType) {
2170 return emitError(unknownLoc, "undefined result type from <id> ")
2171 << operands[0];
2172 }
2173
2174 uint32_t resultID = operands[1];
2175
2176 if (!dyn_cast<spirv::TensorArmType>(resultType)) {
2177 return emitError(unknownLoc, "result must be of type OpTypeTensorARM");
2178 }
2179
// Wrap the 32-bit graph constant id literal in an i32 IntegerAttr.
2180 APInt graph_constant_id = APInt(32, operands[2], /*isSigned=*/true);
2181 Type i32Ty = opBuilder.getIntegerType(32);
2182 IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
2183 graphConstantMap.try_emplace(
2184 resultID, GraphConstantARMOpMaterializationInfo{resultType, attr});
2185
2186 return success();
2187}
2188
2189//===----------------------------------------------------------------------===//
2190// Control flow
2191//===----------------------------------------------------------------------===//
2192
// Returns the block registered for label <id> `id`. If none exists yet, a new
// block is appended to `curFunction` and registered in `blockMap` as a
// forward declaration. Its final placement is sorted out later.
// NOTE(review): `curFunction` is dereferenced unconditionally. Callers that
// only check other state (e.g. createGraphBlock checks `curGraph`) rely on it
// being non-null here; confirm this.
2194 if (auto *block = getBlock(id)) {
2195 LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
2196 << " @ " << block << "\n");
2197 return block;
2198 }
2199
2200 // We don't know where this block will be placed finally (in a
2201 // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
2202 // function for now and sort out the proper place later.
2203 auto *block = curFunction->addBlock();
2204 LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
2205 << " @ " << block << "\n");
2206 return blockMap[id] = block;
2207}
2208
// Deserializes OpBranch by emitting a spirv::BranchOp from the current block
// to the target label. The target may still be a forward-declared block.
2210 if (!curBlock) {
2211 return emitError(unknownLoc, "OpBranch must appear inside a block");
2212 }
2213
2214 if (operands.size() != 1) {
2215 return emitError(unknownLoc, "OpBranch must take exactly one target label");
2216 }
2217
2218 auto *target = getOrCreateBlock(operands[0]);
2219 auto loc = createFileLineColLoc(opBuilder);
2220 // The preceding instruction for the OpBranch instruction could be an
2221 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
2222 // the same OpLine information.
2223 spirv::BranchOp::create(opBuilder, loc, target);
2224
2226 return success();
2227}
2228
/// Deserializes OpBranchConditional. Its operands are the condition <id>, the
/// true label, the false label, and optionally a pair of branch weights
/// (5 operands in total). Emits a spirv::BranchConditionalOp without block
/// arguments. Targets that do not exist yet are forward-declared via
/// getOrCreateBlock.
2229LogicalResult
2231 if (!curBlock) {
2232 return emitError(unknownLoc,
2233 "OpBranchConditional must appear inside a block");
2234 }
2235
2236 if (operands.size() != 3 && operands.size() != 5) {
2237 return emitError(unknownLoc,
2238 "OpBranchConditional must have condition, true label, "
2239 "false label, and optionally two branch weights");
2240 }
2241
2242 auto condition = getValue(operands[0]);
2243 auto *trueBlock = getOrCreateBlock(operands[1]);
2244 auto *falseBlock = getOrCreateBlock(operands[2]);
2245
// Branch weights are only present in the 5-operand form.
2246 std::optional<std::pair<uint32_t, uint32_t>> weights;
2247 if (operands.size() == 5) {
2248 weights = std::make_pair(operands[3], operands[4]);
2249 }
2250 // The preceding instruction for the OpBranchConditional instruction could be
2251 // an OpSelectionMerge instruction, in this case they will have the same
2252 // OpLine information.
2253 auto loc = createFileLineColLoc(opBuilder);
2254 spirv::BranchConditionalOp::create(
2255 opBuilder, loc, condition, trueBlock,
2256 /*trueArguments=*/ArrayRef<Value>(), falseBlock,
2257 /*falseArguments=*/ArrayRef<Value>(), weights);
2258
2260 return success();
2261}
2262
// Deserializes OpLabel. It starts populating the block for the label <id>,
// reusing a forward-declared block if one exists. That block becomes the
// current block and the insertion point moves to its start.
2264 if (!curFunction) {
2265 return emitError(unknownLoc, "OpLabel must appear inside a function");
2266 }
2267
2268 if (operands.size() != 1) {
2269 return emitError(unknownLoc, "OpLabel should only have result <id>");
2270 }
2271
2272 auto labelID = operands[0];
2273 // We may have forward declared this block.
2274 auto *block = getOrCreateBlock(labelID);
2275 LLVM_DEBUG(logger.startLine()
2276 << "[block] populating block " << block << "\n");
2277 // If we have seen this block, make sure it was just a forward declaration.
2278 // NOTE(review): this check is an assert only. When NDEBUG is defined, a
2279 // repeated OpLabel in the input is not rejected here.
2278 assert(block->empty() && "re-deserialize the same block!");
2279
2280 opBuilder.setInsertionPointToStart(block);
2281 blockMap[labelID] = curBlock = block;
2282
2283 return success();
2284}
2285
2286LogicalResult spirv::Deserializer::createGraphBlock(uint32_t graphID) {
2287 if (!curGraph) {
2288 return emitError(unknownLoc, "a graph block must appear inside a graph");
2289 }
2290
2291 // We may have forward declared this block.
2292 Block *block = getOrCreateBlock(graphID);
2293 LLVM_DEBUG(logger.startLine()
2294 << "[block] populating block " << block << "\n");
2295 // If we have seen this block, make sure it was just a forward declaration.
2296 assert(block->empty() && "re-deserialize the same block!");
2297
2298 opBuilder.setInsertionPointToStart(block);
2299 blockMap[graphID] = curBlock = block;
2300
2301 return success();
2302}
2303
/// Deserializes OpSelectionMerge. For the current block, it records the merge
/// block and the selection control value in `blockMergeInfo`, to be used by
/// later structurization. A block can hold at most one `blockMergeInfo` entry.
2304LogicalResult
2306 if (!curBlock) {
2307 return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
2308 }
2309
2310 if (operands.size() < 2) {
2311 return emitError(
2312 unknownLoc,
2313 "OpSelectionMerge must specify merge target and selection control");
2314 }
2315
2316 auto *mergeBlock = getOrCreateBlock(operands[0]);
2317 auto loc = createFileLineColLoc(opBuilder);
2318 auto selectionControl = operands[1];
2319
// try_emplace leaves an existing entry untouched and reports it via `.second`.
2320 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
2321 .second) {
2322 return emitError(
2323 unknownLoc,
2324 "a block cannot have more than one OpSelectionMerge instruction");
2325 }
2326
2327 return success();
2328}
2329
/// Deserializes OpLoopMerge. For the current block, it records the merge
/// block, the continue block, and the loop control value in `blockMergeInfo`,
/// to be used by later structurization.
/// NOTE(review): only operands[2] of the loop control is read. Any further
/// loop-control parameter operands are silently ignored.
2330LogicalResult
2332 if (!curBlock) {
2333 return emitError(unknownLoc, "OpLoopMerge must appear in a block");
2334 }
2335
2336 if (operands.size() < 3) {
2337 return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
2338 "continue target and loop control");
2339 }
2340
2341 auto *mergeBlock = getOrCreateBlock(operands[0]);
2342 auto *continueBlock = getOrCreateBlock(operands[1]);
2343 auto loc = createFileLineColLoc(opBuilder);
2344 uint32_t loopControl = operands[2];
2345
// try_emplace fails if this block already has merge info recorded.
2346 if (!blockMergeInfo
2347 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
2348 .second) {
2349 return emitError(
2350 unknownLoc,
2351 "a block cannot have more than one OpLoopMerge instruction");
2352 }
2353
2354 return success();
2355}
2356
// Deserializes OpPhi by adding a block argument to the current block and
// mapping the result <id> to it. Each (value <id>, parent label) pair is
// recorded in `blockPhiInfo`, keyed by (predecessor, current block). This lets
// the predecessor's branch be given the matching operand later.
// NOTE(review): only `operands.size() < 4` is rejected. With an odd operand
// count, the pair loop below would read operands[i + 1] past the end. The
// type from getType(operands[0]) is also not null-checked before
// addArgument.
2358 if (!curBlock) {
2359 return emitError(unknownLoc, "OpPhi must appear in a block");
2360 }
2361
2362 if (operands.size() < 4) {
2363 return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
2364 "and variable-parent pairs");
2365 }
2366
2367 // Create a block argument for this OpPhi instruction.
2368 Type blockArgType = getType(operands[0]);
2369 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
2370 valueMap[operands[1]] = blockArg;
2371 LLVM_DEBUG(logger.startLine()
2372 << "[phi] created block argument " << blockArg
2373 << " id = " << operands[1] << " of type " << blockArgType << "\n");
2374
2375 // For each (value, predecessor) pair, insert the value to the predecessor's
2376 // blockPhiInfo entry so later we can fix the block argument there.
2377 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
2378 uint32_t value = operands[i];
2379 Block *predecessor = getOrCreateBlock(operands[i + 1]);
2380 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
2381 blockPhiInfo[predecessorTargetPair].push_back(value);
2382 LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
2383 << " with arg id = " << value << "\n");
2384 }
2385
2386 return success();
2387}
2388
// Deserializes OpSwitch. Its operands are the selector <id>, the default
// label, and then (literal, label) pairs. Emits a spirv::SwitchOp whose
// targets all receive empty operand lists. Targets that do not exist yet are
// forward-declared.
// NOTE(review): each case literal is read as a single 32-bit word and stored
// as int32_t. Selectors wider than 32 bits presumably use multi-word literals
// per the SPIR-V spec, which this parsing (and the even-operand-count check)
// does not handle. Confirm that such selectors cannot reach here. The
// operand-count error message also contains a typo ("must at have").
2390 if (!curBlock)
2391 return emitError(unknownLoc, "OpSwitch must appear in a block");
2392
2393 if (operands.size() < 2)
2394 return emitError(unknownLoc, "OpSwitch must at least specify selector and "
2395 "a default target");
2396
2397 if (operands.size() % 2)
2398 return emitError(unknownLoc,
2399 "OpSwitch must at have an even number of operands: "
2400 "selector, default target and any number of literal and "
2401 "label <id> pairs");
2402
2403 Value selector = getValue(operands[0]);
2404 Block *defaultBlock = getOrCreateBlock(operands[1]);
2405 Location loc = createFileLineColLoc(opBuilder);
2406
2407 SmallVector<int32_t> literals;
2408 SmallVector<Block *> blocks;
2409 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
2410 literals.push_back(operands[i]);
2411 blocks.push_back(getOrCreateBlock(operands[i + 1]));
2412 }
2413
// No OpPhi operands are attached at this point; every case target gets an
// empty operand range.
2414 SmallVector<ValueRange> targetOperands(blocks.size(), {});
2415 spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock,
2416 ArrayRef<Value>(), literals, blocks, targetOperands);
2417
2418 return success();
2419}
2420
2421namespace {
2422/// A class for putting all blocks in a structured selection/loop in a
2423/// spirv.mlir.selection/spirv.mlir.loop op.
/// A construct is described by its header block and its merge block. A
/// non-null continue block marks it as a loop; otherwise it is a selection.
2424class ControlFlowStructurizer {
2425public:
 /// Creates a structurizer for the construct headed by `header` that merges
 /// into `merge`. `cont` is the loop continue block, or nullptr for a
 /// selection. `control` is the raw SPIR-V selection/loop control value. When
 /// NDEBUG is not defined, the constructor also takes a `logger`.
2426#ifndef NDEBUG
2427 ControlFlowStructurizer(Location loc, uint32_t control,
2428 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
2429 Block *merge, Block *cont,
2430 llvm::ScopedPrinter &logger)
2431 : location(loc), control(control), blockMergeInfo(mergeInfo),
2432 headerBlock(header), mergeBlock(merge), continueBlock(cont),
2433 logger(logger) {}
2434#else
2435 ControlFlowStructurizer(Location loc, uint32_t control,
2436 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
2437 Block *merge, Block *cont)
2438 : location(loc), control(control), blockMergeInfo(mergeInfo),
2439 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
2440#endif
2441
2442 /// Structurizes the selection or loop construct at the given `headerBlock`.
2443 ///
2444 /// This method creates a spirv.mlir.loop op (when `continueBlock` is set) or
2445 /// a spirv.mlir.selection op in the `mergeBlock`. It then clones all blocks
2446 /// of the construct into that op's region. All branches to the
2447 /// `headerBlock` are redirected to the `mergeBlock`. This method also updates
2448 /// `mergeInfo` by remapping all blocks inside to the newly cloned ones inside
 /// the structured control flow op's regions. Returns failure if the op cannot
 /// be created, or if blocks/values of the construct are still used outside it.
2449 LogicalResult structurize();
2450
2451private:
2452 /// Creates a new spirv.mlir.selection op at the beginning of the
2453 /// `mergeBlock`.
2454 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
2455
2456 /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
2457 spirv::LoopOp createLoopOp(uint32_t loopControl);
2458
2459 /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
2460 void collectBlocksInConstruct();
2461
 /// Location given to the created op, and the raw SPIR-V control value that
 /// is cast to the selection or loop control enum.
2462 Location location;
2463 uint32_t control;
2464
 /// Merge information for structured headers. structurize() remaps the
 /// entries of blocks it clones.
2465 spirv::BlockMergeInfoMap &blockMergeInfo;
2466
2467 Block *headerBlock;
2468 Block *mergeBlock;
2469 Block *continueBlock; // nullptr for spirv.mlir.selection
2470
 /// Blocks belonging to the construct, header block first. Filled by
 /// collectBlocksInConstruct().
2471 SetVector<Block *> constructBlocks;
2472
2473#ifndef NDEBUG
2474 /// A logger used to emit information during the deserialization process.
2475 llvm::ScopedPrinter &logger;
2476#endif
2477};
2478} // namespace
2479
2480spirv::SelectionOp
2481ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
2482 // Create a builder and set the insertion point to the beginning of the
2483 // merge block so that the newly created SelectionOp will be inserted there.
2484 OpBuilder builder(&mergeBlock->front());
2485
2486 auto control = static_cast<spirv::SelectionControl>(selectionControl);
2487 auto selectionOp = spirv::SelectionOp::create(builder, location, control);
2488 selectionOp.addMergeBlock(builder);
2489
2490 return selectionOp;
2491}
2492
2493spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
2494 // Create a builder and set the insertion point to the beginning of the
2495 // merge block so that the newly created LoopOp will be inserted there.
2496 OpBuilder builder(&mergeBlock->front());
2497
2498 auto control = static_cast<spirv::LoopControl>(loopControl);
2499 auto loopOp = spirv::LoopOp::create(builder, location, control);
2500 loopOp.addEntryAndMergeBlock(builder);
2501
2502 return loopOp;
2503}
2504
2505void ControlFlowStructurizer::collectBlocksInConstruct() {
2506 assert(constructBlocks.empty() && "expected empty constructBlocks");
2507
2508 // Put the header block in the work list first.
2509 constructBlocks.insert(headerBlock);
2510
2511 // For each item in the work list, add its successors excluding the merge
2512 // block.
2513 for (unsigned i = 0; i < constructBlocks.size(); ++i) {
2514 for (auto *successor : constructBlocks[i]->getSuccessors())
2515 if (successor != mergeBlock)
2516 constructBlocks.insert(successor);
2517 }
2518}
2519
2520LogicalResult ControlFlowStructurizer::structurize() {
2521 Operation *op = nullptr;
2522 bool isLoop = continueBlock != nullptr;
2523 if (isLoop) {
2524 if (auto loopOp = createLoopOp(control))
2525 op = loopOp.getOperation();
2526 } else {
2527 if (auto selectionOp = createSelectionOp(control))
2528 op = selectionOp.getOperation();
2529 }
2530 if (!op)
2531 return failure();
2532 Region &body = op->getRegion(0);
2533
2534 IRMapping mapper;
2535 // All references to the old merge block should be directed to the
2536 // selection/loop merge block in the SelectionOp/LoopOp's region.
2537 mapper.map(mergeBlock, &body.back());
2538
2539 collectBlocksInConstruct();
2540
2541 // We've identified all blocks belonging to the selection/loop's region. Now
2542 // need to "move" them into the selection/loop. Instead of really moving the
2543 // blocks, in the following we copy them and remap all values and branches.
2544 // This is because:
2545 // * Inserting a block into a region requires the block not in any region
2546 // before. But selections/loops can nest so we can create selection/loop ops
2547 // in a nested manner, which means some blocks may already be in a
2548 // selection/loop region when to be moved again.
2549 // * It's much trickier to fix up the branches into and out of the loop's
2550 // region: we need to treat not-moved blocks and moved blocks differently:
2551 // Not-moved blocks jumping to the loop header block need to jump to the
2552 // merge point containing the new loop op but not the loop continue block's
2553 // back edge. Moved blocks jumping out of the loop need to jump to the
2554 // merge block inside the loop region but not other not-moved blocks.
2555 // We cannot use replaceAllUsesWith clearly and it's harder to follow the
2556 // logic.
2557
2558 // Create a corresponding block in the SelectionOp/LoopOp's region for each
2559 // block in this loop construct.
2560 OpBuilder builder(body);
2561 for (auto *block : constructBlocks) {
2562 // Create a block and insert it before the selection/loop merge block in the
2563 // SelectionOp/LoopOp's region.
2564 auto *newBlock = builder.createBlock(&body.back());
2565 mapper.map(block, newBlock);
2566 LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
2567 << " from block " << block << "\n");
2568 if (!isFnEntryBlock(block)) {
2569 for (BlockArgument blockArg : block->getArguments()) {
2570 auto newArg =
2571 newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2572 mapper.map(blockArg, newArg);
2573 LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
2574 << blockArg << " to " << newArg << "\n");
2575 }
2576 } else {
2577 LLVM_DEBUG(logger.startLine()
2578 << "[cf] block " << block << " is a function entry block\n");
2579 }
2580
2581 for (auto &op : *block)
2582 newBlock->push_back(op.clone(mapper));
2583 }
2584
2585 // Go through all ops and remap the operands.
2586 auto remapOperands = [&](Operation *op) {
2587 for (auto &operand : op->getOpOperands())
2588 if (Value mappedOp = mapper.lookupOrNull(operand.get()))
2589 operand.set(mappedOp);
2590 for (auto &succOp : op->getBlockOperands())
2591 if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
2592 succOp.set(mappedOp);
2593 };
2594 for (auto &block : body)
2595 block.walk(remapOperands);
2596
2597 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
2598 // the selection/loop construct into its region. Next we need to fix the
2599 // connections between this new SelectionOp/LoopOp with existing blocks.
2600
2601 // All existing incoming branches should go to the merge block, where the
2602 // SelectionOp/LoopOp resides right now.
2603 headerBlock->replaceAllUsesWith(mergeBlock);
2604
2605 LLVM_DEBUG({
2606 logger.startLine() << "[cf] after cloning and fixing references:\n";
2607 headerBlock->getParentOp()->print(logger.getOStream());
2608 logger.startLine() << "\n";
2609 });
2610
2611 if (isLoop) {
2612 if (!mergeBlock->args_empty()) {
2613 return mergeBlock->getParentOp()->emitError(
2614 "OpPhi in loop merge block unsupported");
2615 }
2616
2617 // The loop header block may have block arguments. Since now we place the
2618 // loop op inside the old merge block, we need to make sure the old merge
2619 // block has the same block argument list.
2620 for (BlockArgument blockArg : headerBlock->getArguments())
2621 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2622
2623 // If the loop header block has block arguments, make sure the spirv.Branch
2624 // op matches.
2625 SmallVector<Value, 4> blockArgs;
2626 if (!headerBlock->args_empty())
2627 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2628
2629 // The loop entry block should have a unconditional branch jumping to the
2630 // loop header block.
2631 builder.setInsertionPointToEnd(&body.front());
2632 spirv::BranchOp::create(builder, location, mapper.lookupOrNull(headerBlock),
2633 ArrayRef<Value>(blockArgs));
2634 }
2635
2636 // Values defined inside the selection region that need to be yielded outside
2637 // the region.
2638 SmallVector<Value> valuesToYield;
2639 // Outside uses of values that were sunk into the selection region. Those uses
2640 // will be replaced with values returned by the SelectionOp.
2641 SmallVector<Value> outsideUses;
2642
2643 // Move block arguments of the original block (`mergeBlock`) into the merge
2644 // block inside the selection (`body.back()`). Values produced by block
2645 // arguments will be yielded by the selection region. We do not update uses or
2646 // erase original block arguments yet. It will be done later in the code.
2647 //
2648 // Code below is not executed for loops as it would interfere with the logic
2649 // above. Currently block arguments in the merge block are not supported, but
2650 // instead, the code above copies those arguments from the header block into
2651 // the merge block. As such, running the code would yield those copied
2652 // arguments that is most likely not a desired behaviour. This may need to be
2653 // revisited in the future.
2654 if (!isLoop)
2655 for (BlockArgument blockArg : mergeBlock->getArguments()) {
2656 // Create new block arguments in the last block ("merge block") of the
2657 // selection region. We create one argument for each argument in
2658 // `mergeBlock`. This new value will need to be yielded, and the original
2659 // value replaced, so add them to appropriate vectors.
2660 body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2661 valuesToYield.push_back(body.back().getArguments().back());
2662 outsideUses.push_back(blockArg);
2663 }
2664
2665 // All the blocks cloned into the SelectionOp/LoopOp's region can now be
2666 // cleaned up.
2667 LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
2668 // First we need to drop all operands' references inside all blocks. This is
2669 // needed because we can have blocks referencing SSA values from one another.
2670 for (auto *block : constructBlocks)
2671 block->dropAllReferences();
2672
2673 // All internal uses should be removed from original blocks by now, so
2674 // whatever is left is an outside use and will need to be yielded from
2675 // the newly created selection / loop region.
2676 for (Block *block : constructBlocks) {
2677 for (Operation &op : *block) {
2678 if (!op.use_empty())
2679 for (Value result : op.getResults()) {
2680 valuesToYield.push_back(mapper.lookupOrNull(result));
2681 outsideUses.push_back(result);
2682 }
2683 }
2684 for (BlockArgument &arg : block->getArguments()) {
2685 if (!arg.use_empty()) {
2686 valuesToYield.push_back(mapper.lookupOrNull(arg));
2687 outsideUses.push_back(arg);
2688 }
2689 }
2690 }
2691
2692 assert(valuesToYield.size() == outsideUses.size());
2693
2694 // If we need to yield any values from the selection / loop region we will
2695 // take care of it here.
2696 if (!valuesToYield.empty()) {
2697 LLVM_DEBUG(logger.startLine()
2698 << "[cf] yielding values from the selection / loop region\n");
2699
2700 // Update `mlir.merge` with values to be yield.
2701 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2702 Operation *merge = llvm::getSingleElement(mergeOps);
2703 assert(merge);
2704 merge->setOperands(valuesToYield);
2705
2706 // MLIR does not allow changing the number of results of an operation, so
2707 // we create a new SelectionOp / LoopOp with required list of results and
2708 // move the region from the initial SelectionOp / LoopOp. The initial
2709 // operation is then removed. Since we move the region to the new op all
2710 // links between blocks and remapping we have previously done should be
2711 // preserved.
2712 builder.setInsertionPoint(&mergeBlock->front());
2713
2714 Operation *newOp = nullptr;
2715
2716 if (isLoop)
2717 newOp = spirv::LoopOp::create(builder, location,
2718 TypeRange(ValueRange(outsideUses)),
2719 static_cast<spirv::LoopControl>(control));
2720 else
2721 newOp = spirv::SelectionOp::create(
2722 builder, location, TypeRange(ValueRange(outsideUses)),
2723 static_cast<spirv::SelectionControl>(control));
2724
2725 newOp->getRegion(0).takeBody(body);
2726
2727 // Remove initial op and swap the pointer to the newly created one.
2728 op->erase();
2729 op = newOp;
2730
2731 // Update all outside uses to use results of the SelectionOp / LoopOp and
2732 // remove block arguments from the original merge block.
2733 for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2734 outsideUses[i].replaceAllUsesWith(op->getResult(i));
2735
2736 // We do not support block arguments in loop merge block. Also running this
2737 // function with loop would break some of the loop specific code above
2738 // dealing with block arguments.
2739 if (!isLoop)
2740 mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2741 }
2742
2743 // Check that whether some op in the to-be-erased blocks still has uses. Those
2744 // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
2745 // region. We cannot handle such cases given that once a value is sinked into
2746 // the SelectionOp/LoopOp's region, there is no escape for it.
2747 for (auto *block : constructBlocks) {
2748 if (!block->use_empty())
2749 return emitError(block->getParent()->getLoc(),
2750 "failed control flow structurization: "
2751 "block has uses outside of the "
2752 "enclosing selection/loop construct");
2753 for (Operation &op : *block)
2754 if (!op.use_empty())
2755 return op.emitOpError("failed control flow structurization: value has "
2756 "uses outside of the "
2757 "enclosing selection/loop construct");
2758 for (BlockArgument &arg : block->getArguments())
2759 if (!arg.use_empty())
2760 return emitError(arg.getLoc(), "failed control flow structurization: "
2761 "block argument has uses outside of the "
2762 "enclosing selection/loop construct");
2763 }
2764
2765 // Then erase all old blocks.
2766 for (auto *block : constructBlocks) {
2767 // We've cloned all blocks belonging to this construct into the structured
2768 // control flow op's region. Among these blocks, some may compose another
2769 // selection/loop. If so, they will be recorded within blockMergeInfo.
2770 // We need to update the pointers there to the newly remapped ones so we can
2771 // continue structurizing them later.
2772 //
2773 // We need to walk each block as constructBlocks do not include blocks
2774 // internal to ops already structured within those blocks. It is not
2775 // fully clear to me why the mergeInfo of blocks (yet to be structured)
2776 // inside already structured selections/loops get invalidated and needs
2777 // updating, however the following example code can cause a crash (depending
2778 // on the structuring order), when the most inner selection is being
2779 // structured after the outer selection and loop have been already
2780 // structured:
2781 //
2782 // spirv.mlir.for {
2783 // // ...
2784 // spirv.mlir.selection {
2785 // // ..
2786 // // A selection region that hasn't been yet structured!
2787 // // ..
2788 // }
2789 // // ...
2790 // }
2791 //
2792 // If the loop gets structured after the outer selection, but before the
2793 // inner selection. Moving the already structured selection inside the loop
2794 // will invalidate the mergeInfo of the region that is not yet structured.
2795 // Just going over constructBlocks will not check and updated header blocks
2796 // inside the already structured selection region. Walking block fixes that.
2797 //
2798 // TODO: If structuring was done in a fixed order starting with inner
2799 // most constructs this most likely not be an issue and the whole code
2800 // section could be removed. However, with the current non-deterministic
2801 // order this is not possible.
2802 //
2803 // TODO: The asserts in the following assumes input SPIR-V blob forms
2804 // correctly nested selection/loop constructs. We should relax this and
2805 // support error cases better.
2806 auto updateMergeInfo = [&](Block *block) -> WalkResult {
2807 auto it = blockMergeInfo.find(block);
2808 if (it != blockMergeInfo.end()) {
2809 // Use the original location for nested selection/loop ops.
2810 Location loc = it->second.loc;
2811
2812 Block *newHeader = mapper.lookupOrNull(block);
2813 if (!newHeader)
2814 return emitError(loc, "failed control flow structurization: nested "
2815 "loop header block should be remapped!");
2816
2817 Block *newContinue = it->second.continueBlock;
2818 if (newContinue) {
2819 newContinue = mapper.lookupOrNull(newContinue);
2820 if (!newContinue)
2821 return emitError(loc, "failed control flow structurization: nested "
2822 "loop continue block should be remapped!");
2823 }
2824
2825 Block *newMerge = it->second.mergeBlock;
2826 if (Block *mappedTo = mapper.lookupOrNull(newMerge))
2827 newMerge = mappedTo;
2828
2829 // The iterator should be erased before adding a new entry into
2830 // blockMergeInfo to avoid iterator invalidation.
2831 blockMergeInfo.erase(it);
2832 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2833 newContinue);
2834 }
2835
2836 return WalkResult::advance();
2837 };
2838
2839 if (block->walk(updateMergeInfo).wasInterrupted())
2840 return failure();
2841
2842 // The structured selection/loop's entry block does not have arguments.
2843 // If the function's header block is also part of the structured control
2844 // flow, we cannot just simply erase it because it may contain arguments
2845 // matching the function signature and used by the cloned blocks.
2846 if (isFnEntryBlock(block)) {
2847 LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2848 << " to only contain a spirv.Branch op\n");
2849 // Still keep the function entry block for the potential block arguments,
2850 // but replace all ops inside with a branch to the merge block.
2851 block->clear();
2852 builder.setInsertionPointToEnd(block);
2853 spirv::BranchOp::create(builder, location, mergeBlock);
2854 } else {
2855 LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2856 block->erase();
2857 }
2858 }
2859
2860 LLVM_DEBUG(logger.startLine()
2861 << "[cf] after structurizing construct with header block "
2862 << headerBlock << ":\n"
2863 << *op << "\n");
2864
2865 return success();
2866}
2867
2869 LLVM_DEBUG({
2870 logger.startLine()
2871 << "//----- [phi] start wiring up block arguments -----//\n";
2872 logger.indent();
2873 });
2874
2875 OpBuilder::InsertionGuard guard(opBuilder);
2876
2877 for (const auto &info : blockPhiInfo) {
2878 Block *block = info.first.first;
2879 Block *target = info.first.second;
2880 const BlockPhiInfo &phiInfo = info.second;
2881 LLVM_DEBUG({
2882 logger.startLine() << "[phi] block " << block << "\n";
2883 logger.startLine() << "[phi] before creating block argument:\n";
2884 block->getParentOp()->print(logger.getOStream());
2885 logger.startLine() << "\n";
2886 });
2887
2888 // Set insertion point to before this block's terminator early because we
2889 // may materialize ops via getValue() call.
2890 auto *op = block->getTerminator();
2891 opBuilder.setInsertionPoint(op);
2892
2893 SmallVector<Value, 4> blockArgs;
2894 blockArgs.reserve(phiInfo.size());
2895 for (uint32_t valueId : phiInfo) {
2896 if (Value value = getValue(valueId)) {
2897 blockArgs.push_back(value);
2898 LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2899 << " id = " << valueId << "\n");
2900 } else {
2901 return emitError(unknownLoc, "OpPhi references undefined value!");
2902 }
2903 }
2904
2905 if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2906 // Replace the previous branch op with a new one with block arguments.
2907 spirv::BranchOp::create(opBuilder, branchOp.getLoc(),
2908 branchOp.getTarget(), blockArgs);
2909 branchOp.erase();
2910 } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2911 assert((branchCondOp.getTrueBlock() == target ||
2912 branchCondOp.getFalseBlock() == target) &&
2913 "expected target to be either the true or false target");
2914 if (target == branchCondOp.getTrueTarget())
2915 spirv::BranchConditionalOp::create(
2916 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2917 blockArgs, branchCondOp.getFalseBlockArguments(),
2918 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2919 branchCondOp.getFalseTarget());
2920 else
2921 spirv::BranchConditionalOp::create(
2922 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2923 branchCondOp.getTrueBlockArguments(), blockArgs,
2924 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2925 branchCondOp.getFalseBlock());
2926
2927 branchCondOp.erase();
2928 } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(op)) {
2929 if (target == switchOp.getDefaultTarget()) {
2930 SmallVector<ValueRange> targetOperands(switchOp.getTargetOperands());
2931 DenseIntElementsAttr literals =
2932 switchOp.getLiterals().value_or(DenseIntElementsAttr());
2933 spirv::SwitchOp::create(
2934 opBuilder, switchOp.getLoc(), switchOp.getSelector(),
2935 switchOp.getDefaultTarget(), blockArgs, literals,
2936 switchOp.getTargets(), targetOperands);
2937 switchOp.erase();
2938 } else {
2939 SuccessorRange targets = switchOp.getTargets();
2940 auto it = llvm::find(targets, target);
2941 assert(it != targets.end());
2942 size_t index = std::distance(targets.begin(), it);
2943 switchOp.getTargetOperandsMutable(index).assign(blockArgs);
2944 }
2945 } else {
2946 return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2947 }
2948
2949 LLVM_DEBUG({
2950 logger.startLine() << "[phi] after creating block argument:\n";
2951 block->getParentOp()->print(logger.getOStream());
2952 logger.startLine() << "\n";
2953 });
2954 }
2955 blockPhiInfo.clear();
2956
2957 LLVM_DEBUG({
2958 logger.unindent();
2959 logger.startLine()
2960 << "//--- [phi] completed wiring up block arguments ---//\n";
2961 });
2962 return success();
2963}
2964
2966 // Create a copy, so we can modify keys in the original.
2967 BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
2968 for (auto [block, mergeInfo] : blockMergeInfoCopy) {
2969 // Skip processing loop regions. For loop regions continueBlock is non-null.
2970 if (mergeInfo.continueBlock)
2971 continue;
2972
2973 if (!block->mightHaveTerminator())
2974 continue;
2975
2976 Operation *terminator = block->getTerminator();
2977 assert(terminator);
2978
2979 if (!isa<spirv::BranchConditionalOp, spirv::SwitchOp>(terminator))
2980 continue;
2981
2982 // Check if the current header block is a merge block of another construct.
2983 bool splitHeaderMergeBlock = false;
2984 for (const auto &[_, mergeInfo] : blockMergeInfo) {
2985 if (mergeInfo.mergeBlock == block)
2986 splitHeaderMergeBlock = true;
2987 }
2988
2989 // Do not split a block that only contains a conditional branch / switch,
2990 // unless it is also a merge block of another construct - in that case we
2991 // want to split the block. We do not want two constructs to share header /
2992 // merge block.
2993 if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2994 Block *newBlock = block->splitBlock(terminator);
2995 OpBuilder builder(block, block->end());
2996 spirv::BranchOp::create(builder, block->getParent()->getLoc(), newBlock);
2997
2998 // After splitting we need to update the map to use the new block as a
2999 // header.
3000 blockMergeInfo.erase(block);
3001 blockMergeInfo.try_emplace(newBlock, mergeInfo);
3002 }
3003 }
3004
3005 return success();
3006}
3007
3009 if (!options.enableControlFlowStructurization) {
3010 LLVM_DEBUG(
3011 {
3012 logger.startLine()
3013 << "//----- [cf] skip structurizing control flow -----//\n";
3014 logger.indent();
3015 });
3016 return success();
3017 }
3018
3019 LLVM_DEBUG({
3020 logger.startLine()
3021 << "//----- [cf] start structurizing control flow -----//\n";
3022 logger.indent();
3023 });
3024
3025 LLVM_DEBUG({
3026 logger.startLine() << "[cf] split conditional blocks\n";
3027 logger.startLine() << "\n";
3028 });
3029
3030 if (failed(splitSelectionHeader())) {
3031 return failure();
3032 }
3033
3034 while (!blockMergeInfo.empty()) {
3035 Block *headerBlock = blockMergeInfo.begin()->first;
3036 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
3037
3038 LLVM_DEBUG({
3039 logger.startLine() << "[cf] header block " << headerBlock << ":\n";
3040 headerBlock->print(logger.getOStream());
3041 logger.startLine() << "\n";
3042 });
3043
3044 auto *mergeBlock = mergeInfo.mergeBlock;
3045 assert(mergeBlock && "merge block cannot be nullptr");
3046 if (mergeInfo.continueBlock && !mergeBlock->args_empty())
3047 return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
3048 LLVM_DEBUG({
3049 logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
3050 mergeBlock->print(logger.getOStream());
3051 logger.startLine() << "\n";
3052 });
3053
3054 auto *continueBlock = mergeInfo.continueBlock;
3055 LLVM_DEBUG(if (continueBlock) {
3056 logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
3057 continueBlock->print(logger.getOStream());
3058 logger.startLine() << "\n";
3059 });
3060 // Erase this case before calling into structurizer, who will update
3061 // blockMergeInfo.
3062 blockMergeInfo.erase(blockMergeInfo.begin());
3063 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
3064 blockMergeInfo, headerBlock,
3065 mergeBlock, continueBlock
3066#ifndef NDEBUG
3067 ,
3068 logger
3069#endif
3070 );
3071 if (failed(structurizer.structurize()))
3072 return failure();
3073 }
3074
3075 LLVM_DEBUG({
3076 logger.unindent();
3077 logger.startLine()
3078 << "//--- [cf] completed structurizing control flow ---//\n";
3079 });
3080 return success();
3081}
3082
3083//===----------------------------------------------------------------------===//
3084// Debug
3085//===----------------------------------------------------------------------===//
3086
3088 if (!debugLine)
3089 return unknownLoc;
3090
3091 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
3092 if (fileName.empty())
3093 fileName = "<unknown>";
3094 return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
3095 debugLine->column);
3096}
3097
3098LogicalResult
3100 // According to SPIR-V spec:
3101 // "This location information applies to the instructions physically
3102 // following this instruction, up to the first occurrence of any of the
3103 // following: the next end of block, the next OpLine instruction, or the next
3104 // OpNoLine instruction."
3105 if (operands.size() != 3)
3106 return emitError(unknownLoc, "OpLine must have 3 operands");
3107 debugLine = DebugLine{operands[0], operands[1], operands[2]};
3108 return success();
3109}
3110
3111void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
3112
3113LogicalResult
3115 if (operands.size() < 2)
3116 return emitError(unknownLoc, "OpString needs at least 2 operands");
3117
3118 if (!debugInfoMap.lookup(operands[0]).empty())
3119 return emitError(unknownLoc,
3120 "duplicate debug string found for result <id> ")
3121 << operands[0];
3122
3123 unsigned wordIndex = 1;
3124 StringRef debugString = decodeStringLiteral(operands, wordIndex);
3125 if (wordIndex != operands.size())
3126 return emitError(unknownLoc,
3127 "unexpected trailing words in OpString instruction");
3128
3129 debugInfoMap[operands[0]] = debugString;
3130 return success();
3131}
return success()
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
void erase()
Unlink this Block from its parent region and delete it.
Definition Block.cpp:66
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition Block.cpp:323
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
void print(raw_ostream &os)
bool args_empty()
Definition Block.h:109
iterator begin()
Definition Block.h:153
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition Block.cpp:36
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition Location.cpp:157
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:58
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MutableArrayRef< BlockOperand > getBlockOperands()
Definition Operation.h:721
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:878
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:409
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition Operation.h:608
void print(raw_ostream &os, const OpPrintingFlags &flags={})
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, PropertyRef properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:66
result_range getResults()
Definition Operation.h:441
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & back()
Definition Region.h:64
iterator end()
Definition Region.h:56
BlockListType & getBlocks()
Definition Region.h:45
BlockListType::iterator iterator
Definition Region.h:52
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition Region.h:252
This class implements the successor iterators for Block.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
static WalkResult advance()
Definition WalkResult.h:47
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult wireUpBlockArgument()
Creates block arguments on predecessors previously recorded when handling OpPhi instructions.
Value materializeSpecConstantOperation(uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, ArrayRef< uint32_t > enclosedOpOperands)
Materializes/emits an OpSpecConstantOp instruction.
LogicalResult processOpTypePointer(ArrayRef< uint32_t > operands)
Value getValue(uint32_t id)
Get the Value associated with a result <id>.
LogicalResult processMatrixType(ArrayRef< uint32_t > operands)
LogicalResult processGlobalVariable(ArrayRef< uint32_t > operands)
Processes the OpVariable instructions at current offset into binary.
std::optional< SpecConstOperationMaterializationInfo > getSpecConstantOperation(uint32_t id)
Gets the info needed to materialize the spec constant operation op associated with the given <id>.
LogicalResult processConstantNull(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantNull instruction with the given operands.
LogicalResult processSpecConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantComposite instruction with the given operands.
LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef< uint32_t > operands, bool deferInstructions=true)
Processes a SPIR-V instruction with the given opcode and operands.
LogicalResult processBranchConditional(ArrayRef< uint32_t > operands)
spirv::GlobalVariableOp getGlobalVariable(uint32_t id)
Gets the global variable associated with a result <id> of OpVariable.
LogicalResult createGraphBlock(uint32_t graphID)
Creates a block for graph with the given graphID.
LogicalResult processStructType(ArrayRef< uint32_t > operands)
LogicalResult processGraphARM(ArrayRef< uint32_t > operands)
LogicalResult processSamplerType(ArrayRef< uint32_t > operands)
LogicalResult setFunctionArgAttrs(uint32_t argID, SmallVectorImpl< Attribute > &argAttrs, size_t argIndex)
Sets the function argument's attributes.
LogicalResult structurizeControlFlow()
Extracts blocks belonging to a structured selection/loop into a spirv.mlir.selection/spirv....
LogicalResult processLabel(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLabel instruction with the given operands.
LogicalResult processSampledImageType(ArrayRef< uint32_t > operands)
LogicalResult processTensorARMType(ArrayRef< uint32_t > operands)
std::optional< spirv::GraphConstantARMOpMaterializationInfo > getGraphConstantARM(uint32_t id)
Gets the GraphConstantARM ID attribute and result type with the given result <id>.
std::optional< std::pair< Attribute, Type > > getConstant(uint32_t id)
Gets the constant's attribute and type associated with the given <id>.
LogicalResult processType(spirv::Opcode opcode, ArrayRef< uint32_t > operands)
Processes a SPIR-V type instruction with given opcode and operands and registers the type into module...
LogicalResult processLoopMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLoopMerge instruction with the given operands.
LogicalResult processArrayType(ArrayRef< uint32_t > operands)
LogicalResult sliceInstruction(spirv::Opcode &opcode, ArrayRef< uint32_t > &operands, std::optional< spirv::Opcode > expectedOpcode=std::nullopt)
Slices the first instruction out of binary and returns its opcode and operands via opcode and operand...
spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id)
Gets the composite specialization constant with the given result <id>.
LogicalResult processNamedBarrierType(ArrayRef< uint32_t > operands)
SmallVector< uint32_t, 2 > BlockPhiInfo
For OpPhi instructions, we use block arguments to represent them.
LogicalResult processSpecConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processCooperativeMatrixTypeKHR(ArrayRef< uint32_t > operands)
LogicalResult processGraphEntryPointARM(ArrayRef< uint32_t > operands)
LogicalResult processFunction(ArrayRef< uint32_t > operands)
Creates a deserializer for the given SPIR-V binary module.
StringAttr getSymbolDecoration(StringRef decorationName)
Gets the symbol name from the name of decoration.
Block * getOrCreateBlock(uint32_t id)
Gets or creates the block corresponding to the given label <id>.
bool isVoidType(Type type) const
Returns true if the given type is for SPIR-V void type.
std::string getSpecConstantSymbol(uint32_t id)
Returns a symbol to be used for the specialization constant with the given result <id>.
LogicalResult processDebugString(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpString instruction with the given operands.
LogicalResult processPhi(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpPhi instruction with the given operands.
std::string getFunctionSymbol(uint32_t id)
Returns a symbol to be used for the function name with the given result <id>.
void clearDebugLine()
Discontinues any source-level location information that might be active from a previous OpLine instru...
LogicalResult processFunctionType(ArrayRef< uint32_t > operands)
IntegerAttr getConstantInt(uint32_t id)
Gets the constant's integer attribute with the given <id>.
LogicalResult processTypeForwardPointer(ArrayRef< uint32_t > operands)
LogicalResult processSwitch(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSwitch instruction with the given operands.
LogicalResult processGraphEndARM(ArrayRef< uint32_t > operands)
LogicalResult processImageType(ArrayRef< uint32_t > operands)
LogicalResult processConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantComposite instruction with the given operands.
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, TypedAttr defaultValue)
Creates a spirv::SpecConstantOp.
Block * getBlock(uint32_t id) const
Returns the block for the given label <id>.
LogicalResult processGraphTypeARM(ArrayRef< uint32_t > operands)
LogicalResult processBranch(ArrayRef< uint32_t > operands)
std::optional< std::pair< Attribute, Type > > getConstantCompositeReplicate(uint32_t id)
Gets the replicated composite constant's attribute and type associated with the given <id>.
LogicalResult processFunctionEnd(ArrayRef< uint32_t > operands)
Processes OpFunctionEnd and finalizes function.
LogicalResult processRuntimeArrayType(ArrayRef< uint32_t > operands)
LogicalResult processSpecConstantOperation(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantOp instruction with the given operands.
LogicalResult processConstant(ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant instruction with the given operands.
Location createFileLineColLoc(OpBuilder opBuilder)
Creates a FileLineColLoc with the OpLine location information.
LogicalResult processGraphConstantARM(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpGraphConstantARM instruction with the given operands.
LogicalResult processConstantBool(bool isTrue, ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the given operands.
spirv::SpecConstantOp getSpecConstant(uint32_t id)
Gets the specialization constant with the given result <id>.
LogicalResult processConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processSelectionMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSelectionMerge instruction with the given operands.
LogicalResult processOpGraphSetOutputARM(ArrayRef< uint32_t > operands)
LogicalResult processDebugLine(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLine instruction with the given operands.
LogicalResult splitSelectionHeader()
Move a conditional branch or a switch into a separate basic block to avoid unnecessary sinking of def...
std::string getGraphSymbol(uint32_t id)
Returns a symbol to be used for the graph name with the given result <id>.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition SPIRVTypes.h:148
static MatrixType get(Type columnType, uint32_t columnCount)
static NamedBarrierType get(MLIRContext *context)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static SamplerType get(MLIRContext *context)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
constexpr uint32_t kMagicNumber
SPIR-V magic number.
llvm::MapVector< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
A struct for containing a header block's merge and continue targets.
A struct for containing OpLine instruction information.
A struct that collects the info needed to materialize/emit a GraphConstantARMOp.
A struct that collects the info needed to materialize/emit a SpecConstantOperation op.