MLIR 22.0.0git
Deserializer.cpp
Go to the documentation of this file.
1//===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Deserializer.h"
14
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Location.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/Sequence.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/bit.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/SaveAndRestore.h"
30#include "llvm/Support/raw_ostream.h"
31#include <optional>
32
33using namespace mlir;
34
35#define DEBUG_TYPE "spirv-deserialization"
36
37//===----------------------------------------------------------------------===//
38// Utility Functions
39//===----------------------------------------------------------------------===//
40
41/// Returns true if the given `block` is a function entry block.
42static inline bool isFnEntryBlock(Block *block) {
43 return block->isEntryBlock() &&
44 isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45}
46
47//===----------------------------------------------------------------------===//
48// Deserializer Method Definitions
49//===----------------------------------------------------------------------===//
50
51spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
52 MLIRContext *context,
54 : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion()), options(options)
56#ifndef NDEBUG
57 ,
58 logger(llvm::dbgs())
59#endif
60{
61}
62
63LogicalResult spirv::Deserializer::deserialize() {
64 LLVM_DEBUG({
65 logger.resetIndent();
66 logger.startLine()
67 << "//+++---------- start deserialization ----------+++//\n";
68 });
69
70 if (failed(processHeader()))
71 return failure();
72
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
74 ArrayRef<uint32_t> operands;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
77 // Slice the next instruction out and populate `opcode` and `operands`.
78 // Internally this also updates `curOffset`.
79 if (failed(sliceInstruction(opcode, operands)))
80 return failure();
81
82 if (failed(processInstruction(opcode, operands)))
83 return failure();
84 }
85
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
88
89 for (auto &deferred : deferredInstructions) {
90 if (failed(processInstruction(deferred.first, deferred.second, false))) {
91 return failure();
92 }
93 }
94
95 attachVCETriple();
96
97 LLVM_DEBUG(logger.startLine()
98 << "//+++-------- completed deserialization --------+++//\n");
99 return success();
100}
101
102OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
103 return std::move(module);
104}
105
106//===----------------------------------------------------------------------===//
107// Module structure
108//===----------------------------------------------------------------------===//
109
110OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111 OpBuilder builder(context);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
114 return cast<spirv::ModuleOp>(Operation::create(state));
115}
116
117LogicalResult spirv::Deserializer::processHeader() {
118 if (binary.size() < spirv::kHeaderWordCount)
119 return emitError(unknownLoc,
120 "SPIR-V binary module must have a 5-word header");
121
122 if (binary[0] != spirv::kMagicNumber)
123 return emitError(unknownLoc, "incorrect magic number");
124
125 // Version number bytes: 0 | major number | minor number | 0
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130#define MIN_VERSION_CASE(v) \
131 case v: \
132 version = spirv::Version::V_1_##v; \
133 break
134
142#undef MIN_VERSION_CASE
143 default:
144 return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
145 << minorVersion;
146 }
147 } else {
148 return emitError(unknownLoc, "unsupported SPIR-V major version: ")
149 << majorVersion;
150 }
151
152 // TODO: generator number, bound, schema
153 curOffset = spirv::kHeaderWordCount;
154 return success();
155}
156
157LogicalResult
158spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
159 if (operands.size() != 1)
160 return emitError(unknownLoc, "OpCapability must have one parameter");
161
162 auto cap = spirv::symbolizeCapability(operands[0]);
163 if (!cap)
164 return emitError(unknownLoc, "unknown capability: ") << operands[0];
165
166 capabilities.insert(*cap);
167 return success();
168}
169
170LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
171 if (words.empty()) {
172 return emitError(
173 unknownLoc,
174 "OpExtension must have a literal string for the extension name");
175 }
176
177 unsigned wordIndex = 0;
178 StringRef extName = decodeStringLiteral(words, wordIndex);
179 if (wordIndex != words.size())
180 return emitError(unknownLoc,
181 "unexpected trailing words in OpExtension instruction");
182 auto ext = spirv::symbolizeExtension(extName);
183 if (!ext)
184 return emitError(unknownLoc, "unknown extension: ") << extName;
185
186 extensions.insert(*ext);
187 return success();
188}
189
190LogicalResult
191spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
192 if (words.size() < 2) {
193 return emitError(unknownLoc,
194 "OpExtInstImport must have a result <id> and a literal "
195 "string for the extended instruction set name");
196 }
197
198 unsigned wordIndex = 1;
199 extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
200 if (wordIndex != words.size()) {
201 return emitError(unknownLoc,
202 "unexpected trailing words in OpExtInstImport");
203 }
204 return success();
205}
206
207void spirv::Deserializer::attachVCETriple() {
208 (*module)->setAttr(
209 spirv::ModuleOp::getVCETripleAttrName(),
210 spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
211 extensions.getArrayRef(), context));
212}
213
214LogicalResult
215spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
216 if (operands.size() != 2)
217 return emitError(unknownLoc, "OpMemoryModel must have two operands");
218
219 (*module)->setAttr(
220 module->getAddressingModelAttrName(),
221 opBuilder.getAttr<spirv::AddressingModelAttr>(
222 static_cast<spirv::AddressingModel>(operands.front())));
223
224 (*module)->setAttr(module->getMemoryModelAttrName(),
225 opBuilder.getAttr<spirv::MemoryModelAttr>(
226 static_cast<spirv::MemoryModel>(operands.back())));
227
228 return success();
229}
230
231template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
233 Location loc, OpBuilder &opBuilder,
235 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
236 if (words.size() != 4) {
237 return emitError(loc, "OpDecoration with ")
238 << decorationName << "needs a cache control integer literal and a "
239 << cacheControlKind << " cache control literal";
240 }
241 unsigned cacheLevel = words[2];
242 auto cacheControlAttr = static_cast<EnumTy>(words[3]);
243 auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
245 if (auto attrList =
246 llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
247 llvm::append_range(attrs, attrList);
248 attrs.push_back(value);
249 decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
250 return success();
251}
252
253LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
254 // TODO: This function should also be auto-generated. For now, since only a
255 // few decorations are processed/handled in a meaningful manner, going with a
256 // manual implementation.
257 if (words.size() < 2) {
258 return emitError(
259 unknownLoc, "OpDecorate must have at least result <id> and Decoration");
260 }
261 auto decorationName =
262 stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
263 if (decorationName.empty()) {
264 return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
265 }
266 auto symbol = getSymbolDecoration(decorationName);
267 switch (static_cast<spirv::Decoration>(words[1])) {
268 case spirv::Decoration::FPFastMathMode:
269 if (words.size() != 3) {
270 return emitError(unknownLoc, "OpDecorate with ")
271 << decorationName << " needs a single integer literal";
272 }
273 decorations[words[0]].set(
274 symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
275 static_cast<FPFastMathMode>(words[2])));
276 break;
277 case spirv::Decoration::FPRoundingMode:
278 if (words.size() != 3) {
279 return emitError(unknownLoc, "OpDecorate with ")
280 << decorationName << " needs a single integer literal";
281 }
282 decorations[words[0]].set(
283 symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
284 static_cast<FPRoundingMode>(words[2])));
285 break;
286 case spirv::Decoration::DescriptorSet:
287 case spirv::Decoration::Binding:
288 if (words.size() != 3) {
289 return emitError(unknownLoc, "OpDecorate with ")
290 << decorationName << " needs a single integer literal";
291 }
292 decorations[words[0]].set(
293 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
294 break;
295 case spirv::Decoration::BuiltIn:
296 if (words.size() != 3) {
297 return emitError(unknownLoc, "OpDecorate with ")
298 << decorationName << " needs a single integer literal";
299 }
300 decorations[words[0]].set(
301 symbol, opBuilder.getStringAttr(
302 stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
303 break;
304 case spirv::Decoration::ArrayStride:
305 if (words.size() != 3) {
306 return emitError(unknownLoc, "OpDecorate with ")
307 << decorationName << " needs a single integer literal";
308 }
309 typeDecorations[words[0]] = words[2];
310 break;
311 case spirv::Decoration::LinkageAttributes: {
312 if (words.size() < 4) {
313 return emitError(unknownLoc, "OpDecorate with ")
314 << decorationName
315 << " needs at least 1 string and 1 integer literal";
316 }
317 // LinkageAttributes has two parameters ["linkageName", linkageType]
318 // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
319 // "linkageName" is a stringliteral encoded as uint32_t,
320 // hence the size of name is variable length which results in words.size()
321 // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
322 // 3 + ceildiv(strlen(name), 4).
323 unsigned wordIndex = 2;
324 auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
325 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
326 static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
327 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
328 StringAttr::get(context, linkageName), linkageTypeAttr);
329 decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
330 break;
331 }
332 case spirv::Decoration::Aliased:
333 case spirv::Decoration::AliasedPointer:
334 case spirv::Decoration::Block:
335 case spirv::Decoration::BufferBlock:
336 case spirv::Decoration::Flat:
337 case spirv::Decoration::NonReadable:
338 case spirv::Decoration::NonWritable:
339 case spirv::Decoration::NoPerspective:
340 case spirv::Decoration::NoSignedWrap:
341 case spirv::Decoration::NoUnsignedWrap:
342 case spirv::Decoration::RelaxedPrecision:
343 case spirv::Decoration::Restrict:
344 case spirv::Decoration::RestrictPointer:
345 case spirv::Decoration::NoContraction:
346 case spirv::Decoration::Constant:
347 case spirv::Decoration::Invariant:
348 case spirv::Decoration::Patch:
349 if (words.size() != 2) {
350 return emitError(unknownLoc, "OpDecoration with ")
351 << decorationName << "needs a single target <id>";
352 }
353 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
354 break;
355 case spirv::Decoration::Location:
356 case spirv::Decoration::SpecId:
357 if (words.size() != 3) {
358 return emitError(unknownLoc, "OpDecoration with ")
359 << decorationName << "needs a single integer literal";
360 }
361 decorations[words[0]].set(
362 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
363 break;
364 case spirv::Decoration::CacheControlLoadINTEL: {
365 LogicalResult res = deserializeCacheControlDecoration<
366 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
367 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
368 "load");
369 if (failed(res))
370 return res;
371 break;
372 }
373 case spirv::Decoration::CacheControlStoreINTEL: {
374 LogicalResult res = deserializeCacheControlDecoration<
375 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
376 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
377 "store");
378 if (failed(res))
379 return res;
380 break;
381 }
382 default:
383 return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
384 }
385 return success();
386}
387
388LogicalResult
389spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
390 // The binary layout of OpMemberDecorate is different comparing to OpDecorate
391 if (words.size() < 3) {
392 return emitError(unknownLoc,
393 "OpMemberDecorate must have at least 3 operands");
394 }
395
396 auto decoration = static_cast<spirv::Decoration>(words[2]);
397 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
398 return emitError(unknownLoc,
399 " missing offset specification in OpMemberDecorate with "
400 "Offset decoration");
401 }
402 ArrayRef<uint32_t> decorationOperands;
403 if (words.size() > 3) {
404 decorationOperands = words.slice(3);
405 }
406 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
407 return success();
408}
409
410LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
411 if (words.size() < 3) {
412 return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
413 }
414 unsigned wordIndex = 2;
415 auto name = decodeStringLiteral(words, wordIndex);
416 if (wordIndex != words.size()) {
417 return emitError(unknownLoc,
418 "unexpected trailing words in OpMemberName instruction");
419 }
420 memberNameMap[words[0]][words[1]] = name;
421 return success();
422}
423
425 uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
426 if (!decorations.contains(argID)) {
427 argAttrs[argIndex] = DictionaryAttr::get(context, {});
428 return success();
429 }
430
431 spirv::DecorationAttr foundDecorationAttr;
432 for (NamedAttribute decAttr : decorations[argID]) {
433 for (auto decoration :
434 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
435 spirv::Decoration::AliasedPointer,
436 spirv::Decoration::RestrictPointer}) {
437
438 if (decAttr.getName() !=
439 getSymbolDecoration(stringifyDecoration(decoration)))
440 continue;
441
442 if (foundDecorationAttr)
443 return emitError(unknownLoc,
444 "more than one Aliased/Restrict decorations for "
445 "function argument with result <id> ")
446 << argID;
447
448 foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
449 break;
450 }
451
452 if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
453 spirv::Decoration::RelaxedPrecision))) {
454 // TODO: Current implementation supports only one decoration per function
455 // parameter so RelaxedPrecision cannot be applied at the same time as,
456 // for example, Aliased/Restrict/etc. This should be relaxed to allow any
457 // combination of decoration allowed by the spec to be supported.
458 if (foundDecorationAttr)
459 return emitError(unknownLoc, "already found a decoration for function "
460 "argument with result <id> ")
461 << argID;
462
463 foundDecorationAttr = spirv::DecorationAttr::get(
464 context, spirv::Decoration::RelaxedPrecision);
465 }
466 }
467
468 if (!foundDecorationAttr)
469 return emitError(unknownLoc, "unimplemented decoration support for "
470 "function argument with result <id> ")
471 << argID;
472
473 NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
474 foundDecorationAttr);
475 argAttrs[argIndex] = DictionaryAttr::get(context, attr);
476 return success();
477}
478
479LogicalResult
481 if (curFunction) {
482 return emitError(unknownLoc, "found function inside function");
483 }
484
485 // Get the result type
486 if (operands.size() != 4) {
487 return emitError(unknownLoc, "OpFunction must have 4 parameters");
488 }
489 Type resultType = getType(operands[0]);
490 if (!resultType) {
491 return emitError(unknownLoc, "undefined result type from <id> ")
492 << operands[0];
493 }
494
495 uint32_t fnID = operands[1];
496 if (funcMap.count(fnID)) {
497 return emitError(unknownLoc, "duplicate function definition/declaration");
498 }
499
500 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
501 if (!fnControl) {
502 return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
503 }
504
505 Type fnType = getType(operands[3]);
506 if (!fnType || !isa<FunctionType>(fnType)) {
507 return emitError(unknownLoc, "unknown function type from <id> ")
508 << operands[3];
509 }
510 auto functionType = cast<FunctionType>(fnType);
511
512 if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
513 (functionType.getNumResults() == 1 &&
514 functionType.getResult(0) != resultType)) {
515 return emitError(unknownLoc, "mismatch in function type ")
516 << functionType << " and return type " << resultType << " specified";
517 }
518
519 std::string fnName = getFunctionSymbol(fnID);
520 auto funcOp = spirv::FuncOp::create(opBuilder, unknownLoc, fnName,
521 functionType, fnControl.value());
522 // Processing other function attributes.
523 if (decorations.count(fnID)) {
524 for (auto attr : decorations[fnID].getAttrs()) {
525 funcOp->setAttr(attr.getName(), attr.getValue());
526 }
527 }
528 curFunction = funcMap[fnID] = funcOp;
529 auto *entryBlock = funcOp.addEntryBlock();
530 LLVM_DEBUG({
531 logger.startLine()
532 << "//===-------------------------------------------===//\n";
533 logger.startLine() << "[fn] name: " << fnName << "\n";
534 logger.startLine() << "[fn] type: " << fnType << "\n";
535 logger.startLine() << "[fn] ID: " << fnID << "\n";
536 logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
537 logger.indent();
538 });
539
540 SmallVector<Attribute> argAttrs;
541 argAttrs.resize(functionType.getNumInputs());
542
543 // Parse the op argument instructions
544 if (functionType.getNumInputs()) {
545 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
546 auto argType = functionType.getInput(i);
547 spirv::Opcode opcode = spirv::Opcode::OpNop;
548 ArrayRef<uint32_t> operands;
549 if (failed(sliceInstruction(opcode, operands,
550 spirv::Opcode::OpFunctionParameter))) {
551 return failure();
552 }
553 if (opcode != spirv::Opcode::OpFunctionParameter) {
554 return emitError(
555 unknownLoc,
556 "missing OpFunctionParameter instruction for argument ")
557 << i;
558 }
559 if (operands.size() != 2) {
560 return emitError(
561 unknownLoc,
562 "expected result type and result <id> for OpFunctionParameter");
563 }
564 auto argDefinedType = getType(operands[0]);
565 if (!argDefinedType || argDefinedType != argType) {
566 return emitError(unknownLoc,
567 "mismatch in argument type between function type "
568 "definition ")
569 << functionType << " and argument type definition "
570 << argDefinedType << " at argument " << i;
571 }
572 if (getValue(operands[1])) {
573 return emitError(unknownLoc, "duplicate definition of result <id> ")
574 << operands[1];
575 }
576 if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
577 return failure();
578 }
579
580 auto argValue = funcOp.getArgument(i);
581 valueMap[operands[1]] = argValue;
582 }
583 }
584
585 if (llvm::any_of(argAttrs, [](Attribute attr) {
586 auto argAttr = cast<DictionaryAttr>(attr);
587 return !argAttr.empty();
588 }))
589 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
590
591 // entryBlock is needed to access the arguments, Once that is done, we can
592 // erase the block for functions with 'Import' LinkageAttributes, since these
593 // are essentially function declarations, so they have no body.
594 auto linkageAttr = funcOp.getLinkageAttributes();
595 auto hasImportLinkage =
596 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
597 spirv::LinkageType::Import);
598 if (hasImportLinkage)
599 funcOp.eraseBody();
600
601 // RAII guard to reset the insertion point to the module's region after
602 // deserializing the body of this function.
603 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
604
605 spirv::Opcode opcode = spirv::Opcode::OpNop;
606 ArrayRef<uint32_t> instOperands;
607
608 // Special handling for the entry block. We need to make sure it starts with
609 // an OpLabel instruction. The entry block takes the same parameters as the
610 // function. All other blocks do not take any parameter. We have already
611 // created the entry block, here we need to register it to the correct label
612 // <id>.
613 if (failed(sliceInstruction(opcode, instOperands,
614 spirv::Opcode::OpFunctionEnd))) {
615 return failure();
616 }
617 if (opcode == spirv::Opcode::OpFunctionEnd) {
618 return processFunctionEnd(instOperands);
619 }
620 if (opcode != spirv::Opcode::OpLabel) {
621 return emitError(unknownLoc, "a basic block must start with OpLabel");
622 }
623 if (instOperands.size() != 1) {
624 return emitError(unknownLoc, "OpLabel should only have result <id>");
625 }
626 blockMap[instOperands[0]] = entryBlock;
627 if (failed(processLabel(instOperands))) {
628 return failure();
629 }
630
631 // Then process all the other instructions in the function until we hit
632 // OpFunctionEnd.
633 while (succeeded(sliceInstruction(opcode, instOperands,
634 spirv::Opcode::OpFunctionEnd)) &&
635 opcode != spirv::Opcode::OpFunctionEnd) {
636 if (failed(processInstruction(opcode, instOperands))) {
637 return failure();
638 }
639 }
640 if (opcode != spirv::Opcode::OpFunctionEnd) {
641 return failure();
642 }
643
644 return processFunctionEnd(instOperands);
645}
646
647LogicalResult
649 // Process OpFunctionEnd.
650 if (!operands.empty()) {
651 return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
652 }
653
654 // Wire up block arguments from OpPhi instructions.
655 // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
656 // ops.
657 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
658 return failure();
659 }
660
661 curBlock = nullptr;
662 curFunction = std::nullopt;
663
664 LLVM_DEBUG({
665 logger.unindent();
666 logger.startLine()
667 << "//===-------------------------------------------===//\n";
668 });
669 return success();
670}
671
672LogicalResult
674 if (operands.size() < 2) {
675 return emitError(unknownLoc,
676 "missing graph defintion in OpGraphEntryPointARM");
677 }
678
679 unsigned wordIndex = 0;
680 uint32_t graphID = operands[wordIndex++];
681 if (!graphMap.contains(graphID)) {
682 return emitError(unknownLoc,
683 "missing graph definition/declaration with id ")
684 << graphID;
685 }
686
687 spirv::GraphARMOp graphARM = graphMap[graphID];
688 StringRef name = decodeStringLiteral(operands, wordIndex);
689 graphARM.setSymName(name);
690 graphARM.setEntryPoint(true);
691
693 for (int64_t size = operands.size(); wordIndex < size; ++wordIndex) {
694 if (spirv::GlobalVariableOp arg = getGlobalVariable(operands[wordIndex])) {
695 interface.push_back(SymbolRefAttr::get(arg.getOperation()));
696 } else {
697 return emitError(unknownLoc, "undefined result <id> ")
698 << operands[wordIndex] << " while decoding OpGraphEntryPoint";
699 }
700 }
701
702 // RAII guard to reset the insertion point to previous value when done.
703 OpBuilder::InsertionGuard insertionGuard(opBuilder);
704 opBuilder.setInsertionPoint(graphARM);
705 spirv::GraphEntryPointARMOp::create(
706 opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
707 opBuilder.getArrayAttr(interface));
708
709 return success();
710}
711
712LogicalResult
714 if (curGraph) {
715 return emitError(unknownLoc, "found graph inside graph");
716 }
717 // Get the result type.
718 if (operands.size() < 2) {
719 return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters");
720 }
721
722 Type type = getType(operands[0]);
723 if (!type || !isa<GraphType>(type)) {
724 return emitError(unknownLoc, "unknown graph type from <id> ")
725 << operands[0];
726 }
727 auto graphType = cast<GraphType>(type);
728 if (graphType.getNumResults() <= 0) {
729 return emitError(unknownLoc, "expected at least one result");
730 }
731
732 uint32_t graphID = operands[1];
733 if (graphMap.count(graphID)) {
734 return emitError(unknownLoc, "duplicate graph definition/declaration");
735 }
736
737 std::string graphName = getGraphSymbol(graphID);
738 auto graphOp =
739 spirv::GraphARMOp::create(opBuilder, unknownLoc, graphName, graphType);
740 curGraph = graphMap[graphID] = graphOp;
741 Block *entryBlock = graphOp.addEntryBlock();
742 LLVM_DEBUG({
743 logger.startLine()
744 << "//===-------------------------------------------===//\n";
745 logger.startLine() << "[graph] name: " << graphName << "\n";
746 logger.startLine() << "[graph] type: " << graphType << "\n";
747 logger.startLine() << "[graph] ID: " << graphID << "\n";
748 logger.startLine() << "[graph] entry block: " << entryBlock << "\n";
749 logger.indent();
750 });
751
752 // Parse the op argument instructions.
753 for (auto [index, argType] : llvm::enumerate(graphType.getInputs())) {
754 spirv::Opcode opcode;
755 ArrayRef<uint32_t> operands;
756 if (failed(sliceInstruction(opcode, operands,
757 spirv::Opcode::OpGraphInputARM))) {
758 return failure();
759 }
760 if (operands.size() != 3) {
761 return emitError(unknownLoc, "expected result type, result <id> and "
762 "input index for OpGraphInputARM");
763 }
764
765 Type argDefinedType = getType(operands[0]);
766 if (!argDefinedType) {
767 return emitError(unknownLoc, "unknown operand type <id> ") << operands[0];
768 }
769
770 if (argDefinedType != argType) {
771 return emitError(unknownLoc,
772 "mismatch in argument type between graph type "
773 "definition ")
774 << graphType << " and argument type definition " << argDefinedType
775 << " at argument " << index;
776 }
777 if (getValue(operands[1])) {
778 return emitError(unknownLoc, "duplicate definition of result <id> ")
779 << operands[1];
780 }
781
782 IntegerAttr inputIndexAttr = getConstantInt(operands[2]);
783 if (!inputIndexAttr) {
784 return emitError(unknownLoc,
785 "unable to read inputIndex value from constant op ")
786 << operands[2];
787 }
788 BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
789 valueMap[operands[1]] = argValue;
790 }
791
792 graphOutputs.resize(graphType.getNumResults());
793
794 // RAII guard to reset the insertion point to the module's region after
795 // deserializing the body of this function.
796 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
797
798 blockMap[graphID] = entryBlock;
799 if (failed(createGraphBlock(graphID))) {
800 return failure();
801 }
802
803 // Process all the instructions in the graph until and including
804 // OpGraphEndARM.
805 spirv::Opcode opcode;
806 ArrayRef<uint32_t> instOperands;
807 do {
808 if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) {
809 return failure();
810 }
811
812 if (failed(processInstruction(opcode, instOperands))) {
813 return failure();
814 }
815 } while (opcode != spirv::Opcode::OpGraphEndARM);
816
817 return success();
818}
819
820LogicalResult
822 if (operands.size() != 2) {
823 return emitError(
824 unknownLoc,
825 "expected value id and output index for OpGraphSetOutputARM");
826 }
827
828 uint32_t id = operands[0];
829 Value value = getValue(id);
830 if (!value) {
831 return emitError(unknownLoc, "could not find result <id> ") << id;
832 }
833
834 IntegerAttr outputIndexAttr = getConstantInt(operands[1]);
835 if (!outputIndexAttr) {
836 return emitError(unknownLoc,
837 "unable to read outputIndex value from constant op ")
838 << operands[1];
839 }
840 graphOutputs[outputIndexAttr.getInt()] = value;
841 return success();
842}
843
844LogicalResult
846 // Create GraphOutputsARM instruction.
847 spirv::GraphOutputsARMOp::create(opBuilder, unknownLoc, graphOutputs);
848
849 // Process OpGraphEndARM.
850 if (!operands.empty()) {
851 return emitError(unknownLoc, "unexpected operands for OpGraphEndARM");
852 }
853
854 curBlock = nullptr;
855 curGraph = std::nullopt;
856 graphOutputs.clear();
857
858 LLVM_DEBUG({
859 logger.unindent();
860 logger.startLine()
861 << "//===-------------------------------------------===//\n";
862 });
863 return success();
864}
865
866std::optional<std::pair<Attribute, Type>>
868 auto constIt = constantMap.find(id);
869 if (constIt == constantMap.end())
870 return std::nullopt;
871 return constIt->getSecond();
872}
873
874std::optional<std::pair<Attribute, Type>>
876 if (auto it = constantCompositeReplicateMap.find(id);
877 it != constantCompositeReplicateMap.end())
878 return it->second;
879 return std::nullopt;
880}
881
882std::optional<spirv::SpecConstOperationMaterializationInfo>
884 auto constIt = specConstOperationMap.find(id);
885 if (constIt == specConstOperationMap.end())
886 return std::nullopt;
887 return constIt->getSecond();
888}
889
891 auto funcName = nameMap.lookup(id).str();
892 if (funcName.empty()) {
893 funcName = "spirv_fn_" + std::to_string(id);
894 }
895 return funcName;
896}
897
898std::string spirv::Deserializer::getGraphSymbol(uint32_t id) {
899 std::string graphName = nameMap.lookup(id).str();
900 if (graphName.empty()) {
901 graphName = "spirv_graph_" + std::to_string(id);
902 }
903 return graphName;
904}
905
907 auto constName = nameMap.lookup(id).str();
908 if (constName.empty()) {
909 constName = "spirv_spec_const_" + std::to_string(id);
910 }
911 return constName;
912}
913
914spirv::SpecConstantOp
916 TypedAttr defaultValue) {
917 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
918 auto op = spirv::SpecConstantOp::create(opBuilder, unknownLoc, symName,
919 defaultValue);
920 if (decorations.count(resultID)) {
921 for (auto attr : decorations[resultID].getAttrs())
922 op->setAttr(attr.getName(), attr.getValue());
923 }
924 specConstMap[resultID] = op;
925 return op;
926}
927
928std::optional<spirv::GraphConstantARMOpMaterializationInfo>
930 auto graphConstIt = graphConstantMap.find(id);
931 if (graphConstIt == graphConstantMap.end())
932 return std::nullopt;
933 return graphConstIt->getSecond();
934}
935
936LogicalResult
938 unsigned wordIndex = 0;
939 if (operands.size() < 3) {
940 return emitError(
941 unknownLoc,
942 "OpVariable needs at least 3 operands, type, <id> and storage class");
943 }
944
945 // Result Type.
946 auto type = getType(operands[wordIndex]);
947 if (!type) {
948 return emitError(unknownLoc, "unknown result type <id> : ")
949 << operands[wordIndex];
950 }
951 auto ptrType = dyn_cast<spirv::PointerType>(type);
952 if (!ptrType) {
953 return emitError(unknownLoc,
954 "expected a result type <id> to be a spirv.ptr, found : ")
955 << type;
956 }
957 wordIndex++;
958
959 // Result <id>.
960 auto variableID = operands[wordIndex];
961 auto variableName = nameMap.lookup(variableID).str();
962 if (variableName.empty()) {
963 variableName = "spirv_var_" + std::to_string(variableID);
964 }
965 wordIndex++;
966
967 // Storage class.
968 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
969 if (ptrType.getStorageClass() != storageClass) {
970 return emitError(unknownLoc, "mismatch in storage class of pointer type ")
971 << type << " and that specified in OpVariable instruction : "
972 << stringifyStorageClass(storageClass);
973 }
974 wordIndex++;
975
976 // Initializer.
977 FlatSymbolRefAttr initializer = nullptr;
978
979 if (wordIndex < operands.size()) {
980 Operation *op = nullptr;
981
982 if (auto initOp = getGlobalVariable(operands[wordIndex]))
983 op = initOp;
984 else if (auto initOp = getSpecConstant(operands[wordIndex]))
985 op = initOp;
986 else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
987 op = initOp;
988 else
989 return emitError(unknownLoc, "unknown <id> ")
990 << operands[wordIndex] << "used as initializer";
991
992 initializer = SymbolRefAttr::get(op);
993 wordIndex++;
994 }
995 if (wordIndex != operands.size()) {
996 return emitError(unknownLoc,
997 "found more operands than expected when deserializing "
998 "OpVariable instruction, only ")
999 << wordIndex << " of " << operands.size() << " processed";
1000 }
1001 auto loc = createFileLineColLoc(opBuilder);
1002 auto varOp = spirv::GlobalVariableOp::create(
1003 opBuilder, loc, TypeAttr::get(type),
1004 opBuilder.getStringAttr(variableName), initializer);
1005
1006 // Decorations.
1007 if (decorations.count(variableID)) {
1008 for (auto attr : decorations[variableID].getAttrs())
1009 varOp->setAttr(attr.getName(), attr.getValue());
1010 }
1011 globalVariableMap[variableID] = varOp;
1012 return success();
1013}
1014
1015IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
1016 auto constInfo = getConstant(id);
1017 if (!constInfo) {
1018 return nullptr;
1019 }
1020 return dyn_cast<IntegerAttr>(constInfo->first);
1021}
1022
1023LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
1024 if (operands.size() < 2) {
1025 return emitError(unknownLoc, "OpName needs at least 2 operands");
1026 }
1027 if (!nameMap.lookup(operands[0]).empty()) {
1028 return emitError(unknownLoc, "duplicate name found for result <id> ")
1029 << operands[0];
1030 }
1031 unsigned wordIndex = 1;
1032 StringRef name = decodeStringLiteral(operands, wordIndex);
1033 if (wordIndex != operands.size()) {
1034 return emitError(unknownLoc,
1035 "unexpected trailing words in OpName instruction");
1036 }
1037 nameMap[operands[0]] = name;
1038 return success();
1039}
1040
1041//===----------------------------------------------------------------------===//
1042// Type
1043//===----------------------------------------------------------------------===//
1044
1045LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
1046 ArrayRef<uint32_t> operands) {
1047 if (operands.empty()) {
1048 return emitError(unknownLoc, "type instruction with opcode ")
1049 << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
1050 }
1051
1052 /// TODO: Types might be forward declared in some instructions and need to be
1053 /// handled appropriately.
1054 if (typeMap.count(operands[0])) {
1055 return emitError(unknownLoc, "duplicate definition for result <id> ")
1056 << operands[0];
1057 }
1058
1059 switch (opcode) {
1060 case spirv::Opcode::OpTypeVoid:
1061 if (operands.size() != 1)
1062 return emitError(unknownLoc, "OpTypeVoid must have no parameters");
1063 typeMap[operands[0]] = opBuilder.getNoneType();
1064 break;
1065 case spirv::Opcode::OpTypeBool:
1066 if (operands.size() != 1)
1067 return emitError(unknownLoc, "OpTypeBool must have no parameters");
1068 typeMap[operands[0]] = opBuilder.getI1Type();
1069 break;
1070 case spirv::Opcode::OpTypeInt: {
1071 if (operands.size() != 3)
1072 return emitError(
1073 unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
1074
1075 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
1076 // to preserve or validate.
1077 // 0 indicates unsigned, or no signedness semantics
1078 // 1 indicates signed semantics."
1079 //
1080 // So we cannot differentiate signless and unsigned integers; always use
1081 // signless semantics for such cases.
1082 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
1083 : IntegerType::SignednessSemantics::Signless;
1084 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
1085 } break;
1086 case spirv::Opcode::OpTypeFloat: {
1087 if (operands.size() != 2 && operands.size() != 3)
1088 return emitError(unknownLoc,
1089 "OpTypeFloat expects either 2 operands (type, bitwidth) "
1090 "or 3 operands (type, bitwidth, encoding), but got ")
1091 << operands.size();
1092 uint32_t bitWidth = operands[1];
1093
1094 Type floatTy;
1095 switch (bitWidth) {
1096 case 16:
1097 floatTy = opBuilder.getF16Type();
1098 break;
1099 case 32:
1100 floatTy = opBuilder.getF32Type();
1101 break;
1102 case 64:
1103 floatTy = opBuilder.getF64Type();
1104 break;
1105 default:
1106 return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
1107 << bitWidth;
1108 }
1109
1110 if (operands.size() == 3) {
1111 if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
1112 return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
1113 << operands[2];
1114 if (bitWidth != 16)
1115 return emitError(unknownLoc,
1116 "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
1117 << bitWidth << " (expected 16)";
1118 floatTy = opBuilder.getBF16Type();
1119 }
1120
1121 typeMap[operands[0]] = floatTy;
1122 } break;
1123 case spirv::Opcode::OpTypeVector: {
1124 if (operands.size() != 3) {
1125 return emitError(
1126 unknownLoc,
1127 "OpTypeVector must have element type and count parameters");
1128 }
1129 Type elementTy = getType(operands[1]);
1130 if (!elementTy) {
1131 return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
1132 << operands[1];
1133 }
1134 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
1135 } break;
1136 case spirv::Opcode::OpTypePointer: {
1137 return processOpTypePointer(operands);
1138 } break;
1139 case spirv::Opcode::OpTypeArray:
1140 return processArrayType(operands);
1141 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
1142 return processCooperativeMatrixTypeKHR(operands);
1143 case spirv::Opcode::OpTypeFunction:
1144 return processFunctionType(operands);
1145 case spirv::Opcode::OpTypeImage:
1146 return processImageType(operands);
1147 case spirv::Opcode::OpTypeSampledImage:
1148 return processSampledImageType(operands);
1149 case spirv::Opcode::OpTypeRuntimeArray:
1150 return processRuntimeArrayType(operands);
1151 case spirv::Opcode::OpTypeStruct:
1152 return processStructType(operands);
1153 case spirv::Opcode::OpTypeMatrix:
1154 return processMatrixType(operands);
1155 case spirv::Opcode::OpTypeTensorARM:
1156 return processTensorARMType(operands);
1157 case spirv::Opcode::OpTypeGraphARM:
1158 return processGraphTypeARM(operands);
1159 default:
1160 return emitError(unknownLoc, "unhandled type instruction");
1161 }
1162 return success();
1163}
1164
1165LogicalResult
1167 if (operands.size() != 3)
1168 return emitError(unknownLoc, "OpTypePointer must have two parameters");
1169
1170 auto pointeeType = getType(operands[2]);
1171 if (!pointeeType)
1172 return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
1173 << operands[2];
1174
1175 uint32_t typePointerID = operands[0];
1176 auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
1177 typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
1178
1179 for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
1180 deferredStructIt != std::end(deferredStructTypesInfos);) {
1181 for (auto *unresolvedMemberIt =
1182 std::begin(deferredStructIt->unresolvedMemberTypes);
1183 unresolvedMemberIt !=
1184 std::end(deferredStructIt->unresolvedMemberTypes);) {
1185 if (unresolvedMemberIt->first == typePointerID) {
1186 // The newly constructed pointer type can resolve one of the
1187 // deferred struct type members; update the memberTypes list and
1188 // clean the unresolvedMemberTypes list accordingly.
1189 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
1190 typeMap[typePointerID];
1191 unresolvedMemberIt =
1192 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
1193 } else {
1194 ++unresolvedMemberIt;
1195 }
1196 }
1197
1198 if (deferredStructIt->unresolvedMemberTypes.empty()) {
1199 // All deferred struct type members are now resolved, set the struct body.
1200 auto structType = deferredStructIt->deferredStructType;
1201
1202 assert(structType && "expected a spirv::StructType");
1203 assert(structType.isIdentified() && "expected an indentified struct");
1204
1205 if (failed(structType.trySetBody(
1206 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
1207 deferredStructIt->memberDecorationsInfo,
1208 deferredStructIt->structDecorationsInfo)))
1209 return failure();
1210
1211 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
1212 } else {
1213 ++deferredStructIt;
1214 }
1215 }
1216
1217 return success();
1218}
1219
1220LogicalResult
1222 if (operands.size() != 3) {
1223 return emitError(unknownLoc,
1224 "OpTypeArray must have element type and count parameters");
1225 }
1226
1227 Type elementTy = getType(operands[1]);
1228 if (!elementTy) {
1229 return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
1230 << operands[1];
1231 }
1232
1233 unsigned count = 0;
1234 // TODO: The count can also come frome a specialization constant.
1235 auto countInfo = getConstant(operands[2]);
1236 if (!countInfo) {
1237 return emitError(unknownLoc, "OpTypeArray count <id> ")
1238 << operands[2] << "can only come from normal constant right now";
1239 }
1240
1241 if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1242 count = intVal.getValue().getZExtValue();
1243 } else {
1244 return emitError(unknownLoc, "OpTypeArray count must come from a "
1245 "scalar integer constant instruction");
1246 }
1247
1248 typeMap[operands[0]] = spirv::ArrayType::get(
1249 elementTy, count, typeDecorations.lookup(operands[0]));
1250 return success();
1251}
1252
1253LogicalResult
1255 assert(!operands.empty() && "No operands for processing function type");
1256 if (operands.size() == 1) {
1257 return emitError(unknownLoc, "missing return type for OpTypeFunction");
1258 }
1259 auto returnType = getType(operands[1]);
1260 if (!returnType) {
1261 return emitError(unknownLoc, "unknown return type in OpTypeFunction");
1262 }
1263 SmallVector<Type, 1> argTypes;
1264 for (size_t i = 2, e = operands.size(); i < e; ++i) {
1265 auto ty = getType(operands[i]);
1266 if (!ty) {
1267 return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
1268 }
1269 argTypes.push_back(ty);
1270 }
1271 ArrayRef<Type> returnTypes;
1272 if (!isVoidType(returnType)) {
1273 returnTypes = llvm::ArrayRef(returnType);
1274 }
1275 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1276 return success();
1277}
1278
1280 ArrayRef<uint32_t> operands) {
1281 if (operands.size() != 6) {
1282 return emitError(unknownLoc,
1283 "OpTypeCooperativeMatrixKHR must have element type, "
1284 "scope, row and column parameters, and use");
1285 }
1286
1287 Type elementTy = getType(operands[1]);
1288 if (!elementTy) {
1289 return emitError(unknownLoc,
1290 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1291 << operands[1];
1292 }
1293
1294 std::optional<spirv::Scope> scope =
1295 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1296 if (!scope) {
1297 return emitError(
1298 unknownLoc,
1299 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1300 << operands[2];
1301 }
1302
1303 IntegerAttr rowsAttr = getConstantInt(operands[3]);
1304 IntegerAttr columnsAttr = getConstantInt(operands[4]);
1305 IntegerAttr useAttr = getConstantInt(operands[5]);
1306
1307 if (!rowsAttr)
1308 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Rows` references "
1309 "undefined constant <id> ")
1310 << operands[3];
1311
1312 if (!columnsAttr)
1313 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Columns` "
1314 "references undefined constant <id> ")
1315 << operands[4];
1316
1317 if (!useAttr)
1318 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Use` references "
1319 "undefined constant <id> ")
1320 << operands[5];
1321
1322 unsigned rows = rowsAttr.getInt();
1323 unsigned columns = columnsAttr.getInt();
1324
1325 std::optional<spirv::CooperativeMatrixUseKHR> use =
1326 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1327 if (!use) {
1328 return emitError(
1329 unknownLoc,
1330 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1331 << operands[5];
1332 }
1333
1334 typeMap[operands[0]] =
1335 spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1336 return success();
1337}
1338
1339LogicalResult
1341 if (operands.size() != 2) {
1342 return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
1343 }
1344 Type memberType = getType(operands[1]);
1345 if (!memberType) {
1346 return emitError(unknownLoc,
1347 "OpTypeRuntimeArray references undefined <id> ")
1348 << operands[1];
1349 }
1350 typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1351 memberType, typeDecorations.lookup(operands[0]));
1352 return success();
1353}
1354
1355LogicalResult
1357 // TODO: Find a way to handle identified structs when debug info is stripped.
1358
1359 if (operands.empty()) {
1360 return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
1361 }
1362
1363 if (operands.size() == 1) {
1364 // Handle empty struct.
1365 typeMap[operands[0]] =
1366 spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
1367 return success();
1368 }
1369
1370 // First element is operand ID, second element is member index in the struct.
1371 SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1372 SmallVector<Type, 4> memberTypes;
1373
1374 for (auto op : llvm::drop_begin(operands, 1)) {
1375 Type memberType = getType(op);
1376 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1377
1378 if (!memberType && !typeForwardPtr)
1379 return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
1380 << op;
1381
1382 if (!memberType)
1383 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1384
1385 memberTypes.push_back(memberType);
1386 }
1387
1390 if (memberDecorationMap.count(operands[0])) {
1391 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1392 for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1393 if (allMemberDecorations.count(memberIndex)) {
1394 for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1395 // Check for offset.
1396 if (memberDecoration.first == spirv::Decoration::Offset) {
1397 // If offset info is empty, resize to the number of members;
1398 if (offsetInfo.empty()) {
1399 offsetInfo.resize(memberTypes.size());
1400 }
1401 offsetInfo[memberIndex] = memberDecoration.second[0];
1402 } else {
1403 auto intType = mlir::IntegerType::get(context, 32);
1404 if (!memberDecoration.second.empty()) {
1405 memberDecorationsInfo.emplace_back(
1406 memberIndex, memberDecoration.first,
1407 IntegerAttr::get(intType, memberDecoration.second[0]));
1408 } else {
1409 memberDecorationsInfo.emplace_back(
1410 memberIndex, memberDecoration.first, UnitAttr::get(context));
1411 }
1412 }
1413 }
1414 }
1415 }
1416 }
1417
1419 if (decorations.count(operands[0])) {
1420 NamedAttrList &allDecorations = decorations[operands[0]];
1421 for (NamedAttribute &decorationAttr : allDecorations) {
1422 std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
1423 llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true));
1424 assert(decoration.has_value());
1425 structDecorationsInfo.emplace_back(decoration.value(),
1426 decorationAttr.getValue());
1427 }
1428 }
1429
1430 uint32_t structID = operands[0];
1431 std::string structIdentifier = nameMap.lookup(structID).str();
1432
1433 if (structIdentifier.empty()) {
1434 assert(unresolvedMemberTypes.empty() &&
1435 "didn't expect unresolved member types");
1436 typeMap[structID] = spirv::StructType::get(
1437 memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
1438 } else {
1439 auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1440 typeMap[structID] = structTy;
1441
1442 if (!unresolvedMemberTypes.empty())
1443 deferredStructTypesInfos.push_back(
1444 {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
1445 memberDecorationsInfo, structDecorationsInfo});
1446 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1447 memberDecorationsInfo,
1448 structDecorationsInfo)))
1449 return failure();
1450 }
1451
1452 // TODO: Update StructType to have member name as attribute as
1453 // well.
1454 return success();
1455}
1456
1457LogicalResult
1459 if (operands.size() != 3) {
1460 // Three operands are needed: result_id, column_type, and column_count
1461 return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1462 " (result_id, column_type, and column_count)");
1463 }
1464 // Matrix columns must be of vector type
1465 Type elementTy = getType(operands[1]);
1466 if (!elementTy) {
1467 return emitError(unknownLoc,
1468 "OpTypeMatrix references undefined column type.")
1469 << operands[1];
1470 }
1471
1472 uint32_t colsCount = operands[2];
1473 typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1474 return success();
1475}
1476
1477LogicalResult
1479 unsigned size = operands.size();
1480 if (size < 2 || size > 4)
1481 return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
1482 "(result_id, element_type, (rank), (shape)) ")
1483 << size;
1484
1485 Type elementTy = getType(operands[1]);
1486 if (!elementTy)
1487 return emitError(unknownLoc,
1488 "OpTypeTensorARM references undefined element type ")
1489 << operands[1];
1490
1491 if (size == 2) {
1492 typeMap[operands[0]] = TensorArmType::get({}, elementTy);
1493 return success();
1494 }
1495
1496 IntegerAttr rankAttr = getConstantInt(operands[2]);
1497 if (!rankAttr)
1498 return emitError(unknownLoc, "OpTypeTensorARM rank must come from a "
1499 "scalar integer constant instruction");
1500 unsigned rank = rankAttr.getValue().getZExtValue();
1501 if (size == 3) {
1502 SmallVector<int64_t, 4> shape(rank, ShapedType::kDynamic);
1503 typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
1504 return success();
1505 }
1506
1507 std::optional<std::pair<Attribute, Type>> shapeInfo =
1508 getConstant(operands[3]);
1509 if (!shapeInfo)
1510 return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
1511 "constant instruction of type OpTypeArray");
1512
1513 ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
1515 for (auto dimAttr : shapeArrayAttr.getValue()) {
1516 auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
1517 if (!dimIntAttr)
1518 return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
1519 "dimension size");
1520 shape.push_back(dimIntAttr.getValue().getSExtValue());
1521 }
1522 typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
1523 return success();
1524}
1525
1526LogicalResult
1528 unsigned size = operands.size();
1529 if (size < 2) {
1530 return emitError(unknownLoc, "OpTypeGraphARM must have at least 2 operands "
1531 "(result_id, num_inputs, (inout0_type, "
1532 "inout1_type, ...))")
1533 << size;
1534 }
1535 uint32_t numInputs = operands[1];
1536 SmallVector<Type, 1> argTypes;
1537 SmallVector<Type, 1> returnTypes;
1538 for (unsigned i = 2; i < size; ++i) {
1539 Type inOutTy = getType(operands[i]);
1540 if (!inOutTy) {
1541 return emitError(unknownLoc,
1542 "OpTypeGraphARM references undefined element type.")
1543 << operands[i];
1544 }
1545 if (i - 2 >= numInputs) {
1546 returnTypes.push_back(inOutTy);
1547 } else {
1548 argTypes.push_back(inOutTy);
1549 }
1550 }
1551 typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
1552 return success();
1553}
1554
1555LogicalResult
1557 if (operands.size() != 2)
1558 return emitError(unknownLoc,
1559 "OpTypeForwardPointer instruction must have two operands");
1560
1561 typeForwardPointerIDs.insert(operands[0]);
1562 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1563 // instruction that defines the actual type.
1564
1565 return success();
1566}
1567
1568LogicalResult
1570 // TODO: Add support for Access Qualifier.
1571 if (operands.size() != 8)
1572 return emitError(
1573 unknownLoc,
1574 "OpTypeImage with non-eight operands are not supported yet");
1575
1576 Type elementTy = getType(operands[1]);
1577 if (!elementTy)
1578 return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1579 << operands[1];
1580
1581 auto dim = spirv::symbolizeDim(operands[2]);
1582 if (!dim)
1583 return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1584 << operands[2];
1585
1586 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1587 if (!depthInfo)
1588 return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1589 << operands[3];
1590
1591 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1592 if (!arrayedInfo)
1593 return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1594 << operands[4];
1595
1596 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1597 if (!samplingInfo)
1598 return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1599
1600 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1601 if (!samplerUseInfo)
1602 return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1603 << operands[6];
1604
1605 auto format = spirv::symbolizeImageFormat(operands[7]);
1606 if (!format)
1607 return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1608 << operands[7];
1609
1610 typeMap[operands[0]] = spirv::ImageType::get(
1611 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1612 samplingInfo.value(), samplerUseInfo.value(), format.value());
1613 return success();
1614}
1615
1616LogicalResult
1618 if (operands.size() != 2)
1619 return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1620
1621 Type elementTy = getType(operands[1]);
1622 if (!elementTy)
1623 return emitError(unknownLoc,
1624 "OpTypeSampledImage references undefined <id>: ")
1625 << operands[1];
1626
1627 typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1628 return success();
1629}
1630
1631//===----------------------------------------------------------------------===//
1632// Constant
1633//===----------------------------------------------------------------------===//
1634
1636 bool isSpec) {
1637 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1638
1639 if (operands.size() < 2) {
1640 return emitError(unknownLoc)
1641 << opname << " must have type <id> and result <id>";
1642 }
1643 if (operands.size() < 3) {
1644 return emitError(unknownLoc)
1645 << opname << " must have at least 1 more parameter";
1646 }
1647
1648 Type resultType = getType(operands[0]);
1649 if (!resultType) {
1650 return emitError(unknownLoc, "undefined result type from <id> ")
1651 << operands[0];
1652 }
1653
1654 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1655 if (bitwidth == 64) {
1656 if (operands.size() == 4) {
1657 return success();
1658 }
1659 return emitError(unknownLoc)
1660 << opname << " should have 2 parameters for 64-bit values";
1661 }
1662 if (bitwidth <= 32) {
1663 if (operands.size() == 3) {
1664 return success();
1665 }
1666
1667 return emitError(unknownLoc)
1668 << opname
1669 << " should have 1 parameter for values with no more than 32 bits";
1670 }
1671 return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1672 << bitwidth;
1673 };
1674
1675 auto resultID = operands[1];
1676
1677 if (auto intType = dyn_cast<IntegerType>(resultType)) {
1678 auto bitwidth = intType.getWidth();
1679 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1680 return failure();
1681 }
1682
1683 APInt value;
1684 if (bitwidth == 64) {
1685 // 64-bit integers are represented with two SPIR-V words. According to
1686 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1687 // literal’s low-order words appear first."
1688 struct DoubleWord {
1689 uint32_t word1;
1690 uint32_t word2;
1691 } words = {operands[2], operands[3]};
1692 value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1693 } else if (bitwidth <= 32) {
1694 value = APInt(bitwidth, operands[2], /*isSigned=*/true,
1695 /*implicitTrunc=*/true);
1696 }
1697
1698 auto attr = opBuilder.getIntegerAttr(intType, value);
1699
1700 if (isSpec) {
1701 createSpecConstant(unknownLoc, resultID, attr);
1702 } else {
1703 // For normal constants, we just record the attribute (and its type) for
1704 // later materialization at use sites.
1705 constantMap.try_emplace(resultID, attr, intType);
1706 }
1707
1708 return success();
1709 }
1710
1711 if (auto floatType = dyn_cast<FloatType>(resultType)) {
1712 auto bitwidth = floatType.getWidth();
1713 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1714 return failure();
1715 }
1716
1717 APFloat value(0.f);
1718 if (floatType.isF64()) {
1719 // Double values are represented with two SPIR-V words. According to
1720 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1721 // literal’s low-order words appear first."
1722 struct DoubleWord {
1723 uint32_t word1;
1724 uint32_t word2;
1725 } words = {operands[2], operands[3]};
1726 value = APFloat(llvm::bit_cast<double>(words));
1727 } else if (floatType.isF32()) {
1728 value = APFloat(llvm::bit_cast<float>(operands[2]));
1729 } else if (floatType.isF16()) {
1730 APInt data(16, operands[2]);
1731 value = APFloat(APFloat::IEEEhalf(), data);
1732 } else if (floatType.isBF16()) {
1733 APInt data(16, operands[2]);
1734 value = APFloat(APFloat::BFloat(), data);
1735 }
1736
1737 auto attr = opBuilder.getFloatAttr(floatType, value);
1738 if (isSpec) {
1739 createSpecConstant(unknownLoc, resultID, attr);
1740 } else {
1741 // For normal constants, we just record the attribute (and its type) for
1742 // later materialization at use sites.
1743 constantMap.try_emplace(resultID, attr, floatType);
1744 }
1745
1746 return success();
1747 }
1748
1749 return emitError(unknownLoc, "OpConstant can only generate values of "
1750 "scalar integer or floating-point type");
1751}
1752
1754 bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1755 if (operands.size() != 2) {
1756 return emitError(unknownLoc, "Op")
1757 << (isSpec ? "Spec" : "") << "Constant"
1758 << (isTrue ? "True" : "False")
1759 << " must have type <id> and result <id>";
1760 }
1761
1762 auto attr = opBuilder.getBoolAttr(isTrue);
1763 auto resultID = operands[1];
1764 if (isSpec) {
1765 createSpecConstant(unknownLoc, resultID, attr);
1766 } else {
1767 // For normal constants, we just record the attribute (and its type) for
1768 // later materialization at use sites.
1769 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1770 }
1771
1772 return success();
1773}
1774
1775LogicalResult
1777 if (operands.size() < 2) {
1778 return emitError(unknownLoc,
1779 "OpConstantComposite must have type <id> and result <id>");
1780 }
1781 if (operands.size() < 3) {
1782 return emitError(unknownLoc,
1783 "OpConstantComposite must have at least 1 parameter");
1784 }
1785
1786 Type resultType = getType(operands[0]);
1787 if (!resultType) {
1788 return emitError(unknownLoc, "undefined result type from <id> ")
1789 << operands[0];
1790 }
1791
1793 elements.reserve(operands.size() - 2);
1794 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1795 auto elementInfo = getConstant(operands[i]);
1796 if (!elementInfo) {
1797 return emitError(unknownLoc, "OpConstantComposite component <id> ")
1798 << operands[i] << " must come from a normal constant";
1799 }
1800 elements.push_back(elementInfo->first);
1801 }
1802
1803 auto resultID = operands[1];
1804 if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1805 SmallVector<Attribute> flattenedElems;
1806 for (Attribute element : elements) {
1807 if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
1808 for (auto value : denseElemAttr.getValues<Attribute>())
1809 flattenedElems.push_back(value);
1810 } else {
1811 flattenedElems.push_back(element);
1812 }
1813 }
1814 auto attr = DenseElementsAttr::get(tensorType, flattenedElems);
1815 constantMap.try_emplace(resultID, attr, tensorType);
1816 } else if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
1817 auto attr = DenseElementsAttr::get(shapedType, elements);
1818 // For normal constants, we just record the attribute (and its type) for
1819 // later materialization at use sites.
1820 constantMap.try_emplace(resultID, attr, shapedType);
1821 } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1822 auto attr = opBuilder.getArrayAttr(elements);
1823 constantMap.try_emplace(resultID, attr, resultType);
1824 } else {
1825 return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1826 << resultType;
1827 }
1828
1829 return success();
1830}
1831
1833 ArrayRef<uint32_t> operands) {
1834 if (operands.size() != 3) {
1835 return emitError(
1836 unknownLoc,
1837 "OpConstantCompositeReplicateEXT expects 3 operands but found ")
1838 << operands.size();
1839 }
1840
1841 Type resultType = getType(operands[0]);
1842 if (!resultType) {
1843 return emitError(unknownLoc, "undefined result type from <id> ")
1844 << operands[0];
1845 }
1846
1847 auto compositeType = dyn_cast<CompositeType>(resultType);
1848 if (!compositeType) {
1849 return emitError(unknownLoc,
1850 "result type from <id> is not a composite type")
1851 << operands[0];
1852 }
1853
1854 uint32_t resultID = operands[1];
1855 uint32_t constantID = operands[2];
1856
1857 std::optional<std::pair<Attribute, Type>> constantInfo =
1858 getConstant(constantID);
1859 if (constantInfo.has_value()) {
1860 constantCompositeReplicateMap.try_emplace(
1861 resultID, constantInfo.value().first, resultType);
1862 return success();
1863 }
1864
1865 std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
1867 if (replicatedConstantCompositeInfo.has_value()) {
1868 constantCompositeReplicateMap.try_emplace(
1869 resultID, replicatedConstantCompositeInfo.value().first, resultType);
1870 return success();
1871 }
1872
1873 return emitError(unknownLoc, "OpConstantCompositeReplicateEXT operand <id> ")
1874 << constantID
1875 << " must come from a normal constant or a "
1876 "OpConstantCompositeReplicateEXT";
1877}
1878
1879LogicalResult
1881 if (operands.size() < 2) {
1882 return emitError(
1883 unknownLoc,
1884 "OpSpecConstantComposite must have type <id> and result <id>");
1885 }
1886 if (operands.size() < 3) {
1887 return emitError(unknownLoc,
1888 "OpSpecConstantComposite must have at least 1 parameter");
1889 }
1890
1891 Type resultType = getType(operands[0]);
1892 if (!resultType) {
1893 return emitError(unknownLoc, "undefined result type from <id> ")
1894 << operands[0];
1895 }
1896
1897 auto resultID = operands[1];
1898 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1899
1901 elements.reserve(operands.size() - 2);
1902 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1903 auto elementInfo = getSpecConstant(operands[i]);
1904 elements.push_back(SymbolRefAttr::get(elementInfo));
1905 }
1906
1907 auto op = spirv::SpecConstantCompositeOp::create(
1908 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1909 opBuilder.getArrayAttr(elements));
1910 specConstCompositeMap[resultID] = op;
1911
1912 return success();
1913}
1914
1916 ArrayRef<uint32_t> operands) {
// Handles OpSpecConstantCompositeReplicateEXT. Expects exactly three operands:
// <result type id> (must be a CompositeType), <result id>, and the <id> of the
// spec constant replicated into every element. Emits a
// spirv.EXTSpecConstantCompositeReplicate op that references the constituent
// by symbol and records it in specConstCompositeReplicateMap.
1917 if (operands.size() != 3) {
1918 return emitError(unknownLoc, "OpSpecConstantCompositeReplicateEXT expects "
1919 "3 operands but found ")
1920 << operands.size();
1921 }
1922
1923 Type resultType = getType(operands[0]);
1924 if (!resultType) {
1925 return emitError(unknownLoc, "undefined result type from <id> ")
1926 << operands[0];
1927 }
1928
1929 auto compositeType = dyn_cast<CompositeType>(resultType);
1930 if (!compositeType) {
// NOTE(review): the <id> is streamed immediately after "composite type" with
// no separator, so the message reads e.g. "...not a composite type5".
1931 return emitError(unknownLoc,
1932 "result type from <id> is not a composite type")
1933 << operands[0];
1934 }
1935
1936 uint32_t resultID = operands[1];
1937
1938 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
// NOTE(review): as in processSpecConstantComposite, getSpecConstant()'s result
// is not null-checked before SymbolRefAttr::get — confirm it cannot be null.
1939 spirv::SpecConstantOp constituentSpecConstantOp =
1940 getSpecConstant(operands[2]);
1941 auto op = spirv::EXTSpecConstantCompositeReplicateOp::create(
1942 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1943 SymbolRefAttr::get(constituentSpecConstantOp));
1944
1945 specConstCompositeReplicateMap[resultID] = op;
1946
1947 return success();
1948}
1949
// Handles OpSpecConstantOp. Operands: <result type id>, <result id>, the
// enclosed opcode, then the enclosed instruction's own operands. No op is
// built here: the opcode, result type <id> and trailing operands are stored in
// specConstOperationMap keyed by the result <id> (presumably materialized
// later through materializeSpecConstantOperation — confirm). Fails if the
// result <id> has already been recorded.
// NOTE(review): the first diagnostic says "OpConstantOperation", while the
// SPIR-V instruction is OpSpecConstantOp.
1950LogicalResult
1952 if (operands.size() < 3)
1953 return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1954 "result <id>, and operand opcode");
1955
1956 uint32_t resultTypeID = operands[0];
1957
1958 if (!getType(resultTypeID))
1959 return emitError(unknownLoc, "undefined result type from <id> ")
1960 << resultTypeID;
1961
1962 uint32_t resultID = operands[1];
1963 spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1964 auto emplaceResult = specConstOperationMap.try_emplace(
1965 resultID,
1967 enclosedOpcode, resultTypeID,
1968 SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1969
// try_emplace does not overwrite, so a duplicate result <id> is an error.
1970 if (!emplaceResult.second)
1971 return emitError(unknownLoc, "value with <id>: ")
1972 << resultID << " is probably defined before.";
1973
1974 return success();
1975}
1976
1978 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1979 ArrayRef<uint32_t> enclosedOpOperands) {
// Materializes a recorded OpSpecConstantOp as a spirv.SpecConstantOperation op.
// Its single-block region holds the enclosed op followed by a spirv::YieldOp of
// that op's first result. Returns the new op's result, or a null Value if
// processing the enclosed instruction fails.
// NOTE(review): this relies on curBlock being set, and on the enclosed
// instruction leaving the op to wrap as the last op of curBlock (only
// curBlock->back() is split off). resultType is not null-checked here —
// confirm callers guarantee both.
1980
1981 Type resultType = getType(resultTypeID);
1982
1983 // Instructions wrapped by OpSpecConstantOp need an ID for their
1984 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1985 // dialect wrapped op. For that purpose, a new value map is created and "fake"
1986 // ID in that map is assigned to the result of the enclosed instruction. Note
1987 // that there is no need to update this fake ID since we only need to
1988 // reference the created Value for the enclosed op from the spv::YieldOp
1989 // created later in this method (both of which are the only values in their
1990 // region: the SpecConstantOperation's region). If we encounter another
1991 // SpecConstantOperation in the module, we simply re-use the fake ID since the
1992 // previous Value assigned to it isn't visible in the current scope anyway.
1993 DenseMap<uint32_t, Value> newValueMap;
1994 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1995 constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
1996
// Rebuild a regular instruction operand list: <result type>, <fake result id>,
// then the enclosed instruction's operands.
1997 SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
1998 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1999 enclosedOpResultTypeAndOperands.push_back(fakeID);
2000 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
2001 enclosedOpOperands.end());
2002
2003 // Process enclosed instruction before creating the enclosing
2004 // specConstantOperation (and its region). This way, references to constants,
2005 // global variables, and spec constants will be materialized outside the new
2006 // op's region. For more info, see Deserializer::getValue's implementation.
2007 if (failed(
2008 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
2009 return Value();
2010
2011 // Since the enclosed op is emitted in the current block, split it in a
2012 // separate new block.
2013 Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
2014
2015 auto loc = createFileLineColLoc(opBuilder);
2016 auto specConstOperationOp =
2017 spirv::SpecConstantOperationOp::create(opBuilder, loc, resultType);
2018
2019 Region &body = specConstOperationOp.getBody();
2020 // Move the new block into SpecConstantOperation's body.
2021 body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
2022 Region::iterator(enclosedBlock));
2023 Block &block = body.back();
2024
2025 // RAII guard to reset the insertion point to the module's region after
2026 // deserializing the body of the specConstantOperation.
2027 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
2028 opBuilder.setInsertionPointToEnd(&block);
2029
2030 spirv::YieldOp::create(opBuilder, loc, block.front().getResult(0));
2031 return specConstOperationOp.getResult();
2032}
2033
// Handles OpConstantNull (exactly <result type id> and <result id>).
// Supported result types:
//   - integer/float scalars and vectors, which get a zero attribute;
//   - TensorArmType, which gets a dense splat of the element type's zero.
// The attribute is only recorded in constantMap; no op is created here. Any
// other type yields an "unsupported OpConstantNull type" diagnostic.
2034LogicalResult
2036 if (operands.size() != 2) {
2037 return emitError(unknownLoc,
2038 "OpConstantNull must only have type <id> and result <id>");
2039 }
2040
2041 Type resultType = getType(operands[0]);
2042 if (!resultType) {
2043 return emitError(unknownLoc, "undefined result type from <id> ")
2044 << operands[0];
2045 }
2046
2047 auto resultID = operands[1];
2048 Attribute attr;
2049 if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
2050 attr = opBuilder.getZeroAttr(resultType);
2051 } else if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
// Only build the splat if a zero attribute exists for the element type.
2052 if (auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
2053 attr = DenseElementsAttr::get(tensorType, element);
2054 }
2055
2056 if (attr) {
2057 // For normal constants, we just record the attribute (and its type) for
2058 // later materialization at use sites.
2059 constantMap.try_emplace(resultID, attr, resultType);
2060 return success();
2061 }
2062
2063 return emitError(unknownLoc, "unsupported OpConstantNull type: ")
2064 << resultType;
2065}
2066
// Handles OpGraphConstantARM. Operands: <result type id> (must be a
// TensorArmType), <result id>, and a literal graph-constant id. Records the
// result type and the id (as a 32-bit IntegerAttr) in graphConstantMap; no op
// is created here.
// NOTE(review): the check requires at least 3 operands, but the diagnostic says
// "at least 2 operands". `dyn_cast` is only used as a boolean test below;
// `isa` would state the intent directly.
2067LogicalResult
2069 if (operands.size() < 3) {
2070 return emitError(unknownLoc)
2071 << "OpGraphConstantARM must have at least 2 operands";
2072 }
2073
2074 Type resultType = getType(operands[0]);
2075 if (!resultType) {
2076 return emitError(unknownLoc, "undefined result type from <id> ")
2077 << operands[0];
2078 }
2079
2080 uint32_t resultID = operands[1];
2081
2082 if (!dyn_cast<spirv::TensorArmType>(resultType)) {
2083 return emitError(unknownLoc, "result must be of type OpTypeTensorARM");
2084 }
2085
2086 APInt graph_constant_id = APInt(32, operands[2], /*isSigned=*/true);
2087 Type i32Ty = opBuilder.getIntegerType(32);
2088 IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
2089 graphConstantMap.try_emplace(
2090 resultID, GraphConstantARMOpMaterializationInfo{resultType, attr});
2091
2092 return success();
2093}
2094
2095//===----------------------------------------------------------------------===//
2096// Control flow
2097//===----------------------------------------------------------------------===//
2098
// Returns the block already known for `id` (via getBlock). Otherwise creates a
// new block at the end of curFunction, records it in blockMap and returns it.
// This is how branch targets are forward-declared before their OpLabel.
// NOTE(review): curFunction is dereferenced without a check; createGraphBlock
// also calls this — confirm curFunction is valid in that context.
// NOTE(review): the first debug message says "exiting" where "existing" is
// meant.
2100 if (auto *block = getBlock(id)) {
2101 LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
2102 << " @ " << block << "\n");
2103 return block;
2104 }
2105
2106 // We don't know where this block will be placed finally (in a
2107 // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
2108 // function for now and sort out the proper place later.
2109 auto *block = curFunction->addBlock();
2110 LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
2111 << " @ " << block << "\n");
2112 return blockMap[id] = block;
2113}
2114
// Handles OpBranch. Emits a spirv.Branch at the current insertion point to the
// block for the single target label, forward-declaring that block if it has
// not been seen yet. Must appear inside a block.
2116 if (!curBlock) {
2117 return emitError(unknownLoc, "OpBranch must appear inside a block");
2118 }
2119
2120 if (operands.size() != 1) {
2121 return emitError(unknownLoc, "OpBranch must take exactly one target label");
2122 }
2123
2124 auto *target = getOrCreateBlock(operands[0]);
2125 auto loc = createFileLineColLoc(opBuilder);
2126 // The preceding instruction for the OpBranch instruction could be an
2127 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
2128 // the same OpLine information.
2129 spirv::BranchOp::create(opBuilder, loc, target);
2130
2132 return success();
2133}
2134
// Handles OpBranchConditional. Operands are a condition <id>, the true and
// false labels, and optionally two literal branch weights (3 or 5 operands in
// total). Emits a spirv.BranchConditional with no block arguments at the
// current insertion point.
// NOTE(review): the Value returned by getValue(operands[0]) is not
// null-checked before it is used as the condition.
2135LogicalResult
2137 if (!curBlock) {
2138 return emitError(unknownLoc,
2139 "OpBranchConditional must appear inside a block");
2140 }
2141
2142 if (operands.size() != 3 && operands.size() != 5) {
2143 return emitError(unknownLoc,
2144 "OpBranchConditional must have condition, true label, "
2145 "false label, and optionally two branch weights");
2146 }
2147
2148 auto condition = getValue(operands[0]);
2149 auto *trueBlock = getOrCreateBlock(operands[1]);
2150 auto *falseBlock = getOrCreateBlock(operands[2]);
2151
2152 std::optional<std::pair<uint32_t, uint32_t>> weights;
2153 if (operands.size() == 5) {
2154 weights = std::make_pair(operands[3], operands[4]);
2155 }
2156 // The preceding instruction for the OpBranchConditional instruction could be
2157 // an OpSelectionMerge instruction, in this case they will have the same
2158 // OpLine information.
2159 auto loc = createFileLineColLoc(opBuilder);
2160 spirv::BranchConditionalOp::create(
2161 opBuilder, loc, condition, trueBlock,
2162 /*trueArguments=*/ArrayRef<Value>(), falseBlock,
2163 /*falseArguments=*/ArrayRef<Value>(), weights);
2164
2166 return success();
2167}
2168
// Handles OpLabel. Starts populating the block for the label <id>, which may
// already exist as a forward declaration created by a branch. The block
// becomes curBlock, and subsequent ops are inserted at its start.
2170 if (!curFunction) {
2171 return emitError(unknownLoc, "OpLabel must appear inside a function");
2172 }
2173
2174 if (operands.size() != 1) {
2175 return emitError(unknownLoc, "OpLabel should only have result <id>");
2176 }
2177
2178 auto labelID = operands[0];
2179 // We may have forward declared this block.
2180 auto *block = getOrCreateBlock(labelID);
2181 LLVM_DEBUG(logger.startLine()
2182 << "[block] populating block " << block << "\n");
2183 // If we have seen this block, make sure it was just a forward declaration.
2184 assert(block->empty() && "re-deserialize the same block!");
2185
2186 opBuilder.setInsertionPointToStart(block);
2187 blockMap[labelID] = curBlock = block;
2188
2189 return success();
2190}
2191
2192LogicalResult spirv::Deserializer::createGraphBlock(uint32_t graphID) {
2193 if (!curGraph) {
2194 return emitError(unknownLoc, "a graph block must appear inside a graph");
2195 }
2196
2197 // We may have forward declared this block.
2198 Block *block = getOrCreateBlock(graphID);
2199 LLVM_DEBUG(logger.startLine()
2200 << "[block] populating block " << block << "\n");
2201 // If we have seen this block, make sure it was just a forward declaration.
2202 assert(block->empty() && "re-deserialize the same block!");
2203
2204 opBuilder.setInsertionPointToStart(block);
2205 blockMap[graphID] = curBlock = block;
2206
2207 return success();
2208}
2209
// Handles OpSelectionMerge. Records (location, selection control, merge block)
// in blockMergeInfo for curBlock, where the control-flow structurizer later
// picks it up. Only the first two operands are read. A second merge
// instruction in the same block is rejected.
2210LogicalResult
2212 if (!curBlock) {
2213 return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
2214 }
2215
2216 if (operands.size() < 2) {
2217 return emitError(
2218 unknownLoc,
2219 "OpSelectionMerge must specify merge target and selection control");
2220 }
2221
2222 auto *mergeBlock = getOrCreateBlock(operands[0]);
2223 auto loc = createFileLineColLoc(opBuilder);
2224 auto selectionControl = operands[1];
2225
2226 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
2227 .second) {
2228 return emitError(
2229 unknownLoc,
2230 "a block cannot have more than one OpSelectionMerge instruction");
2231 }
2232
2233 return success();
2234}
2235
// Handles OpLoopMerge. Records (location, loop control, merge block, continue
// block) in blockMergeInfo for curBlock. Only the first three operands are
// read, so any loop-control parameter operands that follow are ignored. A
// second merge instruction in the same block is rejected.
2236LogicalResult
2238 if (!curBlock) {
2239 return emitError(unknownLoc, "OpLoopMerge must appear in a block");
2240 }
2241
2242 if (operands.size() < 3) {
2243 return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
2244 "continue target and loop control");
2245 }
2246
2247 auto *mergeBlock = getOrCreateBlock(operands[0]);
2248 auto *continueBlock = getOrCreateBlock(operands[1]);
2249 auto loc = createFileLineColLoc(opBuilder);
2250 uint32_t loopControl = operands[2];
2251
2252 if (!blockMergeInfo
2253 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
2254 .second) {
2255 return emitError(
2256 unknownLoc,
2257 "a block cannot have more than one OpLoopMerge instruction");
2258 }
2259
2260 return success();
2261}
2262
// Handles OpPhi.
//   - Adds a block argument to curBlock and maps the result <id> to it.
//   - For each (value <id>, parent label) pair, appends the value <id> to
//     blockPhiInfo[(predecessor, curBlock)], so the block argument can be wired
//     up from the predecessor later.
// NOTE(review): only `size() < 4` is checked. With an odd operand count (e.g.
// 5) the loop below reads operands[i + 1] past the end. An even-count check,
// as processSwitch does, would guard this. getType(operands[0]) is also not
// null-checked before addArgument.
2264 if (!curBlock) {
2265 return emitError(unknownLoc, "OpPhi must appear in a block");
2266 }
2267
2268 if (operands.size() < 4) {
2269 return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
2270 "and variable-parent pairs");
2271 }
2272
2273 // Create a block argument for this OpPhi instruction.
2274 Type blockArgType = getType(operands[0]);
2275 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
2276 valueMap[operands[1]] = blockArg;
2277 LLVM_DEBUG(logger.startLine()
2278 << "[phi] created block argument " << blockArg
2279 << " id = " << operands[1] << " of type " << blockArgType << "\n");
2280
2281 // For each (value, predecessor) pair, insert the value to the predecessor's
2282 // blockPhiInfo entry so later we can fix the block argument there.
2283 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
2284 uint32_t value = operands[i];
2285 Block *predecessor = getOrCreateBlock(operands[i + 1]);
2286 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
2287 blockPhiInfo[predecessorTargetPair].push_back(value);
2288 LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
2289 << " with arg id = " << value << "\n");
2290 }
2291
2292 return success();
2293}
2294
// Handles OpSwitch. Operands are the selector <id>, the default label, then
// (literal, label) pairs. Emits a spirv.Switch whose targets (including the
// default) receive no block arguments.
// NOTE(review): each case literal is read as a single 32-bit word (stored as
// int32_t), and the even-operand check assumes one-word literals. Wider
// selectors presumably use multi-word literals — confirm against the SPIR-V
// spec before relying on this. The diagnostic text "must at have" has a stray
// "at".
2296 if (!curBlock)
2297 return emitError(unknownLoc, "OpSwitch must appear in a block");
2298
2299 if (operands.size() < 2)
2300 return emitError(unknownLoc, "OpSwitch must at least specify selector and "
2301 "a default target");
2302
2303 if (operands.size() % 2)
2304 return emitError(unknownLoc,
2305 "OpSwitch must at have an even number of operands: "
2306 "selector, default target and any number of literal and "
2307 "label <id> pairs");
2308
2309 Value selector = getValue(operands[0]);
2310 Block *defaultBlock = getOrCreateBlock(operands[1]);
2311 Location loc = createFileLineColLoc(opBuilder);
2312
2313 SmallVector<int32_t> literals;
2314 SmallVector<Block *> blocks;
2315 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
2316 literals.push_back(operands[i]);
2317 blocks.push_back(getOrCreateBlock(operands[i + 1]));
2318 }
2319
2320 SmallVector<ValueRange> targetOperands(blocks.size(), {});
2321 spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock,
2322 ArrayRef<Value>(), literals, blocks, targetOperands);
2323
2324 return success();
2325}
2326
2327namespace {
2328/// A class for putting all blocks in a structured selection/loop in a
2329/// spirv.mlir.selection/spirv.mlir.loop op.
2330class ControlFlowStructurizer {
2331public:
2332#ifndef NDEBUG
2333 ControlFlowStructurizer(Location loc, uint32_t control,
2334 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
2335 Block *merge, Block *cont,
2336 llvm::ScopedPrinter &logger)
2337 : location(loc), control(control), blockMergeInfo(mergeInfo),
2338 headerBlock(header), mergeBlock(merge), continueBlock(cont),
2339 logger(logger) {}
2340#else
2341 ControlFlowStructurizer(Location loc, uint32_t control,
2342 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
2343 Block *merge, Block *cont)
2344 : location(loc), control(control), blockMergeInfo(mergeInfo),
2345 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
2346#endif
2347
2348 /// Structurizes the loop at the given `headerBlock`.
2349 ///
2350 /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
2351 /// all blocks in the structured loop into the spirv.mlir.loop's region. All
2352 /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
2353 /// method will also update `mergeInfo` by remapping all blocks inside to the
2354 /// newly cloned ones inside structured control flow op's regions.
2355 LogicalResult structurize();
2356
2357private:
2358 /// Creates a new spirv.mlir.selection op at the beginning of the
2359 /// `mergeBlock`.
2360 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
2361
2362 /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
2363 spirv::LoopOp createLoopOp(uint32_t loopControl);
2364
2365 /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
2366 void collectBlocksInConstruct();
2367
2368 Location location;
2369 uint32_t control;
2370
2371 spirv::BlockMergeInfoMap &blockMergeInfo;
2372
2373 Block *headerBlock;
2374 Block *mergeBlock;
2375 Block *continueBlock; // nullptr for spirv.mlir.selection
2376
2377 SetVector<Block *> constructBlocks;
2378
2379#ifndef NDEBUG
2380 /// A logger used to emit information during the deserialzation process.
2381 llvm::ScopedPrinter &logger;
2382#endif
2383};
2384} // namespace
2385
2386spirv::SelectionOp
2387ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
2388 // Create a builder and set the insertion point to the beginning of the
2389 // merge block so that the newly created SelectionOp will be inserted there.
2390 OpBuilder builder(&mergeBlock->front());
2391
2392 auto control = static_cast<spirv::SelectionControl>(selectionControl);
2393 auto selectionOp = spirv::SelectionOp::create(builder, location, control);
2394 selectionOp.addMergeBlock(builder);
2395
2396 return selectionOp;
2397}
2398
2399spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
2400 // Create a builder and set the insertion point to the beginning of the
2401 // merge block so that the newly created LoopOp will be inserted there.
2402 OpBuilder builder(&mergeBlock->front());
2403
2404 auto control = static_cast<spirv::LoopControl>(loopControl);
2405 auto loopOp = spirv::LoopOp::create(builder, location, control);
2406 loopOp.addEntryAndMergeBlock(builder);
2407
2408 return loopOp;
2409}
2410
2411void ControlFlowStructurizer::collectBlocksInConstruct() {
2412 assert(constructBlocks.empty() && "expected empty constructBlocks");
2413
2414 // Put the header block in the work list first.
2415 constructBlocks.insert(headerBlock);
2416
2417 // For each item in the work list, add its successors excluding the merge
2418 // block.
2419 for (unsigned i = 0; i < constructBlocks.size(); ++i) {
2420 for (auto *successor : constructBlocks[i]->getSuccessors())
2421 if (successor != mergeBlock)
2422 constructBlocks.insert(successor);
2423 }
2424}
2425
2426LogicalResult ControlFlowStructurizer::structurize() {
2427 Operation *op = nullptr;
2428 bool isLoop = continueBlock != nullptr;
2429 if (isLoop) {
2430 if (auto loopOp = createLoopOp(control))
2431 op = loopOp.getOperation();
2432 } else {
2433 if (auto selectionOp = createSelectionOp(control))
2434 op = selectionOp.getOperation();
2435 }
2436 if (!op)
2437 return failure();
2438 Region &body = op->getRegion(0);
2439
2440 IRMapping mapper;
2441 // All references to the old merge block should be directed to the
2442 // selection/loop merge block in the SelectionOp/LoopOp's region.
2443 mapper.map(mergeBlock, &body.back());
2444
2445 collectBlocksInConstruct();
2446
2447 // We've identified all blocks belonging to the selection/loop's region. Now
2448 // need to "move" them into the selection/loop. Instead of really moving the
2449 // blocks, in the following we copy them and remap all values and branches.
2450 // This is because:
2451 // * Inserting a block into a region requires the block not in any region
2452 // before. But selections/loops can nest so we can create selection/loop ops
2453 // in a nested manner, which means some blocks may already be in a
2454 // selection/loop region when to be moved again.
2455 // * It's much trickier to fix up the branches into and out of the loop's
2456 // region: we need to treat not-moved blocks and moved blocks differently:
2457 // Not-moved blocks jumping to the loop header block need to jump to the
2458 // merge point containing the new loop op but not the loop continue block's
2459 // back edge. Moved blocks jumping out of the loop need to jump to the
2460 // merge block inside the loop region but not other not-moved blocks.
2461 // We cannot use replaceAllUsesWith clearly and it's harder to follow the
2462 // logic.
2463
2464 // Create a corresponding block in the SelectionOp/LoopOp's region for each
2465 // block in this loop construct.
2466 OpBuilder builder(body);
2467 for (auto *block : constructBlocks) {
2468 // Create a block and insert it before the selection/loop merge block in the
2469 // SelectionOp/LoopOp's region.
2470 auto *newBlock = builder.createBlock(&body.back());
2471 mapper.map(block, newBlock);
2472 LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
2473 << " from block " << block << "\n");
2474 if (!isFnEntryBlock(block)) {
2475 for (BlockArgument blockArg : block->getArguments()) {
2476 auto newArg =
2477 newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2478 mapper.map(blockArg, newArg);
2479 LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
2480 << blockArg << " to " << newArg << "\n");
2481 }
2482 } else {
2483 LLVM_DEBUG(logger.startLine()
2484 << "[cf] block " << block << " is a function entry block\n");
2485 }
2486
2487 for (auto &op : *block)
2488 newBlock->push_back(op.clone(mapper));
2489 }
2490
2491 // Go through all ops and remap the operands.
2492 auto remapOperands = [&](Operation *op) {
2493 for (auto &operand : op->getOpOperands())
2494 if (Value mappedOp = mapper.lookupOrNull(operand.get()))
2495 operand.set(mappedOp);
2496 for (auto &succOp : op->getBlockOperands())
2497 if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
2498 succOp.set(mappedOp);
2499 };
2500 for (auto &block : body)
2501 block.walk(remapOperands);
2502
2503 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
2504 // the selection/loop construct into its region. Next we need to fix the
2505 // connections between this new SelectionOp/LoopOp with existing blocks.
2506
2507 // All existing incoming branches should go to the merge block, where the
2508 // SelectionOp/LoopOp resides right now.
2509 headerBlock->replaceAllUsesWith(mergeBlock);
2510
2511 LLVM_DEBUG({
2512 logger.startLine() << "[cf] after cloning and fixing references:\n";
2513 headerBlock->getParentOp()->print(logger.getOStream());
2514 logger.startLine() << "\n";
2515 });
2516
2517 if (isLoop) {
2518 if (!mergeBlock->args_empty()) {
2519 return mergeBlock->getParentOp()->emitError(
2520 "OpPhi in loop merge block unsupported");
2521 }
2522
2523 // The loop header block may have block arguments. Since now we place the
2524 // loop op inside the old merge block, we need to make sure the old merge
2525 // block has the same block argument list.
2526 for (BlockArgument blockArg : headerBlock->getArguments())
2527 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2528
2529 // If the loop header block has block arguments, make sure the spirv.Branch
2530 // op matches.
2531 SmallVector<Value, 4> blockArgs;
2532 if (!headerBlock->args_empty())
2533 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2534
2535 // The loop entry block should have a unconditional branch jumping to the
2536 // loop header block.
2537 builder.setInsertionPointToEnd(&body.front());
2538 spirv::BranchOp::create(builder, location, mapper.lookupOrNull(headerBlock),
2539 ArrayRef<Value>(blockArgs));
2540 }
2541
2542 // Values defined inside the selection region that need to be yielded outside
2543 // the region.
2544 SmallVector<Value> valuesToYield;
2545 // Outside uses of values that were sunk into the selection region. Those uses
2546 // will be replaced with values returned by the SelectionOp.
2547 SmallVector<Value> outsideUses;
2548
2549 // Move block arguments of the original block (`mergeBlock`) into the merge
2550 // block inside the selection (`body.back()`). Values produced by block
2551 // arguments will be yielded by the selection region. We do not update uses or
2552 // erase original block arguments yet. It will be done later in the code.
2553 //
2554 // Code below is not executed for loops as it would interfere with the logic
2555 // above. Currently block arguments in the merge block are not supported, but
2556 // instead, the code above copies those arguments from the header block into
2557 // the merge block. As such, running the code would yield those copied
2558 // arguments that is most likely not a desired behaviour. This may need to be
2559 // revisited in the future.
2560 if (!isLoop)
2561 for (BlockArgument blockArg : mergeBlock->getArguments()) {
2562 // Create new block arguments in the last block ("merge block") of the
2563 // selection region. We create one argument for each argument in
2564 // `mergeBlock`. This new value will need to be yielded, and the original
2565 // value replaced, so add them to appropriate vectors.
2566 body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2567 valuesToYield.push_back(body.back().getArguments().back());
2568 outsideUses.push_back(blockArg);
2569 }
2570
2571 // All the blocks cloned into the SelectionOp/LoopOp's region can now be
2572 // cleaned up.
2573 LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
2574 // First we need to drop all operands' references inside all blocks. This is
2575 // needed because we can have blocks referencing SSA values from one another.
2576 for (auto *block : constructBlocks)
2577 block->dropAllReferences();
2578
2579 // All internal uses should be removed from original blocks by now, so
2580 // whatever is left is an outside use and will need to be yielded from
2581 // the newly created selection / loop region.
2582 for (Block *block : constructBlocks) {
2583 for (Operation &op : *block) {
2584 if (!op.use_empty())
2585 for (Value result : op.getResults()) {
2586 valuesToYield.push_back(mapper.lookupOrNull(result));
2587 outsideUses.push_back(result);
2588 }
2589 }
2590 for (BlockArgument &arg : block->getArguments()) {
2591 if (!arg.use_empty()) {
2592 valuesToYield.push_back(mapper.lookupOrNull(arg));
2593 outsideUses.push_back(arg);
2594 }
2595 }
2596 }
2597
2598 assert(valuesToYield.size() == outsideUses.size());
2599
2600 // If we need to yield any values from the selection / loop region we will
2601 // take care of it here.
2602 if (!valuesToYield.empty()) {
2603 LLVM_DEBUG(logger.startLine()
2604 << "[cf] yielding values from the selection / loop region\n");
2605
2606 // Update `mlir.merge` with values to be yield.
2607 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2608 Operation *merge = llvm::getSingleElement(mergeOps);
2609 assert(merge);
2610 merge->setOperands(valuesToYield);
2611
2612 // MLIR does not allow changing the number of results of an operation, so
2613 // we create a new SelectionOp / LoopOp with required list of results and
2614 // move the region from the initial SelectionOp / LoopOp. The initial
2615 // operation is then removed. Since we move the region to the new op all
2616 // links between blocks and remapping we have previously done should be
2617 // preserved.
2618 builder.setInsertionPoint(&mergeBlock->front());
2619
2620 Operation *newOp = nullptr;
2621
2622 if (isLoop)
2623 newOp = spirv::LoopOp::create(builder, location,
2624 TypeRange(ValueRange(outsideUses)),
2625 static_cast<spirv::LoopControl>(control));
2626 else
2627 newOp = spirv::SelectionOp::create(
2628 builder, location, TypeRange(ValueRange(outsideUses)),
2629 static_cast<spirv::SelectionControl>(control));
2630
2631 newOp->getRegion(0).takeBody(body);
2632
2633 // Remove initial op and swap the pointer to the newly created one.
2634 op->erase();
2635 op = newOp;
2636
2637 // Update all outside uses to use results of the SelectionOp / LoopOp and
2638 // remove block arguments from the original merge block.
2639 for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2640 outsideUses[i].replaceAllUsesWith(op->getResult(i));
2641
2642 // We do not support block arguments in loop merge block. Also running this
2643 // function with loop would break some of the loop specific code above
2644 // dealing with block arguments.
2645 if (!isLoop)
2646 mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2647 }
2648
2649 // Check that whether some op in the to-be-erased blocks still has uses. Those
2650 // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
2651 // region. We cannot handle such cases given that once a value is sinked into
2652 // the SelectionOp/LoopOp's region, there is no escape for it.
2653 for (auto *block : constructBlocks) {
2654 if (!block->use_empty())
2655 return emitError(block->getParent()->getLoc(),
2656 "failed control flow structurization: "
2657 "block has uses outside of the "
2658 "enclosing selection/loop construct");
2659 for (Operation &op : *block)
2660 if (!op.use_empty())
2661 return op.emitOpError("failed control flow structurization: value has "
2662 "uses outside of the "
2663 "enclosing selection/loop construct");
2664 for (BlockArgument &arg : block->getArguments())
2665 if (!arg.use_empty())
2666 return emitError(arg.getLoc(), "failed control flow structurization: "
2667 "block argument has uses outside of the "
2668 "enclosing selection/loop construct");
2669 }
2670
2671 // Then erase all old blocks.
2672 for (auto *block : constructBlocks) {
2673 // We've cloned all blocks belonging to this construct into the structured
2674 // control flow op's region. Among these blocks, some may compose another
2675 // selection/loop. If so, they will be recorded within blockMergeInfo.
2676 // We need to update the pointers there to the newly remapped ones so we can
2677 // continue structurizing them later.
2678 //
2679 // We need to walk each block as constructBlocks do not include blocks
2680 // internal to ops already structured within those blocks. It is not
2681 // fully clear to me why the mergeInfo of blocks (yet to be structured)
2682 // inside already structured selections/loops get invalidated and needs
2683 // updating, however the following example code can cause a crash (depending
2684 // on the structuring order), when the most inner selection is being
2685 // structured after the outer selection and loop have been already
2686 // structured:
2687 //
2688 // spirv.mlir.for {
2689 // // ...
2690 // spirv.mlir.selection {
2691 // // ..
2692 // // A selection region that hasn't been yet structured!
2693 // // ..
2694 // }
2695 // // ...
2696 // }
2697 //
2698 // If the loop gets structured after the outer selection, but before the
2699 // inner selection. Moving the already structured selection inside the loop
2700 // will invalidate the mergeInfo of the region that is not yet structured.
2701 // Just going over constructBlocks will not check and updated header blocks
2702 // inside the already structured selection region. Walking block fixes that.
2703 //
2704 // TODO: If structuring was done in a fixed order starting with inner
2705 // most constructs this most likely not be an issue and the whole code
2706 // section could be removed. However, with the current non-deterministic
2707 // order this is not possible.
2708 //
2709 // TODO: The asserts in the following assumes input SPIR-V blob forms
2710 // correctly nested selection/loop constructs. We should relax this and
2711 // support error cases better.
2712 auto updateMergeInfo = [&](Block *block) -> WalkResult {
2713 auto it = blockMergeInfo.find(block);
2714 if (it != blockMergeInfo.end()) {
2715 // Use the original location for nested selection/loop ops.
2716 Location loc = it->second.loc;
2717
2718 Block *newHeader = mapper.lookupOrNull(block);
2719 if (!newHeader)
2720 return emitError(loc, "failed control flow structurization: nested "
2721 "loop header block should be remapped!");
2722
2723 Block *newContinue = it->second.continueBlock;
2724 if (newContinue) {
2725 newContinue = mapper.lookupOrNull(newContinue);
2726 if (!newContinue)
2727 return emitError(loc, "failed control flow structurization: nested "
2728 "loop continue block should be remapped!");
2729 }
2730
2731 Block *newMerge = it->second.mergeBlock;
2732 if (Block *mappedTo = mapper.lookupOrNull(newMerge))
2733 newMerge = mappedTo;
2734
2735 // The iterator should be erased before adding a new entry into
2736 // blockMergeInfo to avoid iterator invalidation.
2737 blockMergeInfo.erase(it);
2738 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2739 newContinue);
2740 }
2741
2742 return WalkResult::advance();
2743 };
2744
2745 if (block->walk(updateMergeInfo).wasInterrupted())
2746 return failure();
2747
2748 // The structured selection/loop's entry block does not have arguments.
2749 // If the function's header block is also part of the structured control
2750 // flow, we cannot just simply erase it because it may contain arguments
2751 // matching the function signature and used by the cloned blocks.
2752 if (isFnEntryBlock(block)) {
2753 LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2754 << " to only contain a spirv.Branch op\n");
2755 // Still keep the function entry block for the potential block arguments,
2756 // but replace all ops inside with a branch to the merge block.
2757 block->clear();
2758 builder.setInsertionPointToEnd(block);
2759 spirv::BranchOp::create(builder, location, mergeBlock);
2760 } else {
2761 LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2762 block->erase();
2763 }
2764 }
2765
2766 LLVM_DEBUG(logger.startLine()
2767 << "[cf] after structurizing construct with header block "
2768 << headerBlock << ":\n"
2769 << *op << "\n");
2770
2771 return success();
2772}
2773
2775 LLVM_DEBUG({
2776 logger.startLine()
2777 << "//----- [phi] start wiring up block arguments -----//\n";
2778 logger.indent();
2779 });
2780
2781 OpBuilder::InsertionGuard guard(opBuilder);
2782
2783 for (const auto &info : blockPhiInfo) {
2784 Block *block = info.first.first;
2785 Block *target = info.first.second;
2786 const BlockPhiInfo &phiInfo = info.second;
2787 LLVM_DEBUG({
2788 logger.startLine() << "[phi] block " << block << "\n";
2789 logger.startLine() << "[phi] before creating block argument:\n";
2790 block->getParentOp()->print(logger.getOStream());
2791 logger.startLine() << "\n";
2792 });
2793
2794 // Set insertion point to before this block's terminator early because we
2795 // may materialize ops via getValue() call.
2796 auto *op = block->getTerminator();
2797 opBuilder.setInsertionPoint(op);
2798
2799 SmallVector<Value, 4> blockArgs;
2800 blockArgs.reserve(phiInfo.size());
2801 for (uint32_t valueId : phiInfo) {
2802 if (Value value = getValue(valueId)) {
2803 blockArgs.push_back(value);
2804 LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2805 << " id = " << valueId << "\n");
2806 } else {
2807 return emitError(unknownLoc, "OpPhi references undefined value!");
2808 }
2809 }
2810
2811 if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2812 // Replace the previous branch op with a new one with block arguments.
2813 spirv::BranchOp::create(opBuilder, branchOp.getLoc(),
2814 branchOp.getTarget(), blockArgs);
2815 branchOp.erase();
2816 } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2817 assert((branchCondOp.getTrueBlock() == target ||
2818 branchCondOp.getFalseBlock() == target) &&
2819 "expected target to be either the true or false target");
2820 if (target == branchCondOp.getTrueTarget())
2821 spirv::BranchConditionalOp::create(
2822 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2823 blockArgs, branchCondOp.getFalseBlockArguments(),
2824 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2825 branchCondOp.getFalseTarget());
2826 else
2827 spirv::BranchConditionalOp::create(
2828 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2829 branchCondOp.getTrueBlockArguments(), blockArgs,
2830 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2831 branchCondOp.getFalseBlock());
2832
2833 branchCondOp.erase();
2834 } else {
2835 return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2836 }
2837
2838 LLVM_DEBUG({
2839 logger.startLine() << "[phi] after creating block argument:\n";
2840 block->getParentOp()->print(logger.getOStream());
2841 logger.startLine() << "\n";
2842 });
2843 }
2844 blockPhiInfo.clear();
2845
2846 LLVM_DEBUG({
2847 logger.unindent();
2848 logger.startLine()
2849 << "//--- [phi] completed wiring up block arguments ---//\n";
2850 });
2851 return success();
2852}
2853
2855 // Create a copy, so we can modify keys in the original.
2856 BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
2857 for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2858 it != e; ++it) {
2859 auto &[block, mergeInfo] = *it;
2860
2861 // Skip processing loop regions. For loop regions continueBlock is non-null.
2862 if (mergeInfo.continueBlock)
2863 continue;
2864
2865 if (!block->mightHaveTerminator())
2866 continue;
2867
2868 Operation *terminator = block->getTerminator();
2869 assert(terminator);
2870
2871 if (!isa<spirv::BranchConditionalOp>(terminator))
2872 continue;
2873
2874 // Check if the current header block is a merge block of another construct.
2875 bool splitHeaderMergeBlock = false;
2876 for (const auto &[_, mergeInfo] : blockMergeInfo) {
2877 if (mergeInfo.mergeBlock == block)
2878 splitHeaderMergeBlock = true;
2879 }
2880
2881 // Do not split a block that only contains a conditional branch, unless it
2882 // is also a merge block of another construct - in that case we want to
2883 // split the block. We do not want two constructs to share header / merge
2884 // block.
2885 if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2886 Block *newBlock = block->splitBlock(terminator);
2887 OpBuilder builder(block, block->end());
2888 spirv::BranchOp::create(builder, block->getParent()->getLoc(), newBlock);
2889
2890 // After splitting we need to update the map to use the new block as a
2891 // header.
2892 blockMergeInfo.erase(block);
2893 blockMergeInfo.try_emplace(newBlock, mergeInfo);
2894 }
2895 }
2896
2897 return success();
2898}
2899
2901 if (!options.enableControlFlowStructurization) {
2902 LLVM_DEBUG(
2903 {
2904 logger.startLine()
2905 << "//----- [cf] skip structurizing control flow -----//\n";
2906 logger.indent();
2907 });
2908 return success();
2909 }
2910
2911 LLVM_DEBUG({
2912 logger.startLine()
2913 << "//----- [cf] start structurizing control flow -----//\n";
2914 logger.indent();
2915 });
2916
2917 LLVM_DEBUG({
2918 logger.startLine() << "[cf] split conditional blocks\n";
2919 logger.startLine() << "\n";
2920 });
2921
2922 if (failed(splitConditionalBlocks())) {
2923 return failure();
2924 }
2925
2926 // TODO: This loop is non-deterministic. Iteration order may vary between runs
2927 // for the same shader as the key to the map is a pointer. See:
2928 // https://github.com/llvm/llvm-project/issues/128547
2929 while (!blockMergeInfo.empty()) {
2930 Block *headerBlock = blockMergeInfo.begin()->first;
2931 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2932
2933 LLVM_DEBUG({
2934 logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2935 headerBlock->print(logger.getOStream());
2936 logger.startLine() << "\n";
2937 });
2938
2939 auto *mergeBlock = mergeInfo.mergeBlock;
2940 assert(mergeBlock && "merge block cannot be nullptr");
2941 if (mergeInfo.continueBlock && !mergeBlock->args_empty())
2942 return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2943 LLVM_DEBUG({
2944 logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2945 mergeBlock->print(logger.getOStream());
2946 logger.startLine() << "\n";
2947 });
2948
2949 auto *continueBlock = mergeInfo.continueBlock;
2950 LLVM_DEBUG(if (continueBlock) {
2951 logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2952 continueBlock->print(logger.getOStream());
2953 logger.startLine() << "\n";
2954 });
2955 // Erase this case before calling into structurizer, who will update
2956 // blockMergeInfo.
2957 blockMergeInfo.erase(blockMergeInfo.begin());
2958 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2959 blockMergeInfo, headerBlock,
2960 mergeBlock, continueBlock
2961#ifndef NDEBUG
2962 ,
2963 logger
2964#endif
2965 );
2966 if (failed(structurizer.structurize()))
2967 return failure();
2968 }
2969
2970 LLVM_DEBUG({
2971 logger.unindent();
2972 logger.startLine()
2973 << "//--- [cf] completed structurizing control flow ---//\n";
2974 });
2975 return success();
2976}
2977
2978//===----------------------------------------------------------------------===//
2979// Debug
2980//===----------------------------------------------------------------------===//
2981
2983 if (!debugLine)
2984 return unknownLoc;
2985
2986 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2987 if (fileName.empty())
2988 fileName = "<unknown>";
2989 return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2990 debugLine->column);
2991}
2992
2993LogicalResult
2995 // According to SPIR-V spec:
2996 // "This location information applies to the instructions physically
2997 // following this instruction, up to the first occurrence of any of the
2998 // following: the next end of block, the next OpLine instruction, or the next
2999 // OpNoLine instruction."
3000 if (operands.size() != 3)
3001 return emitError(unknownLoc, "OpLine must have 3 operands");
3002 debugLine = DebugLine{operands[0], operands[1], operands[2]};
3003 return success();
3004}
3005
3006void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
3007
3008LogicalResult
3010 if (operands.size() < 2)
3011 return emitError(unknownLoc, "OpString needs at least 2 operands");
3012
3013 if (!debugInfoMap.lookup(operands[0]).empty())
3014 return emitError(unknownLoc,
3015 "duplicate debug string found for result <id> ")
3016 << operands[0];
3017
3018 unsigned wordIndex = 1;
3019 StringRef debugString = decodeStringLiteral(operands, wordIndex);
3020 if (wordIndex != operands.size())
3021 return emitError(unknownLoc,
3022 "unexpected trailing words in OpString instruction");
3023
3024 debugInfoMap[operands[0]] = debugString;
3025 return success();
3026}
return success()
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
ArrayAttr()
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:148
void erase()
Unlink this Block from its parent region and delete it.
Definition Block.cpp:66
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition Block.cpp:318
Operation & front()
Definition Block.h:153
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
void print(raw_ostream &os)
bool args_empty()
Definition Block.h:99
iterator begin()
Definition Block.h:143
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition Block.cpp:36
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:98
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition Location.cpp:157
A symbol reference with a reference path containing a single element.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:58
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MutableArrayRef< BlockOperand > getBlockOperands()
Definition Operation.h:695
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:852
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:67
void print(raw_ostream &os, const OpPrintingFlags &flags={})
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & back()
Definition Region.h:64
iterator end()
Definition Region.h:56
BlockListType & getBlocks()
Definition Region.h:45
BlockListType::iterator iterator
Definition Region.h:52
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition Region.h:241
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
static WalkResult advance()
Definition WalkResult.h:47
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult wireUpBlockArgument()
Creates block arguments on predecessors previously recorded when handling OpPhi instructions.
Value materializeSpecConstantOperation(uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, ArrayRef< uint32_t > enclosedOpOperands)
Materializes/emits an OpSpecConstantOp instruction.
LogicalResult processOpTypePointer(ArrayRef< uint32_t > operands)
Value getValue(uint32_t id)
Get the Value associated with a result <id>.
LogicalResult processMatrixType(ArrayRef< uint32_t > operands)
LogicalResult processGlobalVariable(ArrayRef< uint32_t > operands)
Processes the OpVariable instructions at current offset into binary.
std::optional< SpecConstOperationMaterializationInfo > getSpecConstantOperation(uint32_t id)
Gets the info needed to materialize the spec constant operation op associated with the given <id>.
LogicalResult processConstantNull(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantNull instruction with the given operands.
LogicalResult processSpecConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantComposite instruction with the given operands.
LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef< uint32_t > operands, bool deferInstructions=true)
Processes a SPIR-V instruction with the given opcode and operands.
LogicalResult processBranchConditional(ArrayRef< uint32_t > operands)
spirv::GlobalVariableOp getGlobalVariable(uint32_t id)
Gets the global variable associated with a result <id> of OpVariable.
LogicalResult createGraphBlock(uint32_t graphID)
Creates a block for graph with the given graphID.
LogicalResult processStructType(ArrayRef< uint32_t > operands)
LogicalResult processGraphARM(ArrayRef< uint32_t > operands)
LogicalResult setFunctionArgAttrs(uint32_t argID, SmallVectorImpl< Attribute > &argAttrs, size_t argIndex)
Sets the function argument's attributes.
LogicalResult structurizeControlFlow()
Extracts blocks belonging to a structured selection/loop into a spirv.mlir.selection/spirv....
LogicalResult processLabel(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLabel instruction with the given operands.
LogicalResult processSampledImageType(ArrayRef< uint32_t > operands)
LogicalResult processTensorARMType(ArrayRef< uint32_t > operands)
std::optional< spirv::GraphConstantARMOpMaterializationInfo > getGraphConstantARM(uint32_t id)
Gets the GraphConstantARM ID attribute and result type with the given result <id>.
std::optional< std::pair< Attribute, Type > > getConstant(uint32_t id)
Gets the constant's attribute and type associated with the given <id>.
LogicalResult processType(spirv::Opcode opcode, ArrayRef< uint32_t > operands)
Processes a SPIR-V type instruction with given opcode and operands and registers the type into module...
LogicalResult processLoopMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLoopMerge instruction with the given operands.
LogicalResult processArrayType(ArrayRef< uint32_t > operands)
LogicalResult sliceInstruction(spirv::Opcode &opcode, ArrayRef< uint32_t > &operands, std::optional< spirv::Opcode > expectedOpcode=std::nullopt)
Slices the first instruction out of binary and returns its opcode and operands via opcode and operand...
spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id)
Gets the composite specialization constant with the given result <id>.
SmallVector< uint32_t, 2 > BlockPhiInfo
For OpPhi instructions, we use block arguments to represent them.
LogicalResult processSpecConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processCooperativeMatrixTypeKHR(ArrayRef< uint32_t > operands)
LogicalResult processGraphEntryPointARM(ArrayRef< uint32_t > operands)
LogicalResult processFunction(ArrayRef< uint32_t > operands)
Creates a deserializer for the given SPIR-V binary module.
StringAttr getSymbolDecoration(StringRef decorationName)
Gets the symbol name from the name of decoration.
Block * getOrCreateBlock(uint32_t id)
Gets or creates the block corresponding to the given label <id>.
bool isVoidType(Type type) const
Returns true if the given type is for SPIR-V void type.
std::string getSpecConstantSymbol(uint32_t id)
Returns a symbol to be used for the specialization constant with the given result <id>.
LogicalResult processDebugString(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpString instruction with the given operands.
LogicalResult processPhi(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpPhi instruction with the given operands.
std::string getFunctionSymbol(uint32_t id)
Returns a symbol to be used for the function name with the given result <id>.
void clearDebugLine()
Discontinues any source-level location information that might be active from a previous OpLine instru...
LogicalResult processFunctionType(ArrayRef< uint32_t > operands)
IntegerAttr getConstantInt(uint32_t id)
Gets the constant's integer attribute with the given <id>.
LogicalResult processTypeForwardPointer(ArrayRef< uint32_t > operands)
LogicalResult processSwitch(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSwitch instruction with the given operands.
LogicalResult processGraphEndARM(ArrayRef< uint32_t > operands)
LogicalResult processImageType(ArrayRef< uint32_t > operands)
LogicalResult processConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantComposite instruction with the given operands.
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, TypedAttr defaultValue)
Creates a spirv::SpecConstantOp.
Block * getBlock(uint32_t id) const
Returns the block for the given label <id>.
LogicalResult processGraphTypeARM(ArrayRef< uint32_t > operands)
LogicalResult processBranch(ArrayRef< uint32_t > operands)
std::optional< std::pair< Attribute, Type > > getConstantCompositeReplicate(uint32_t id)
Gets the replicated composite constant's attribute and type associated with the given <id>.
LogicalResult processFunctionEnd(ArrayRef< uint32_t > operands)
Processes OpFunctionEnd and finalizes function.
LogicalResult processRuntimeArrayType(ArrayRef< uint32_t > operands)
LogicalResult splitConditionalBlocks()
Move a conditional branch into a separate basic block to avoid unnecessary sinking of defs that may b...
LogicalResult processSpecConstantOperation(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantOp instruction with the given operands.
LogicalResult processConstant(ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant instruction with the given operands.
Location createFileLineColLoc(OpBuilder opBuilder)
Creates a FileLineColLoc with the OpLine location information.
LogicalResult processGraphConstantARM(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpGraphConstantARM instruction with the given operands.
LogicalResult processConstantBool(bool isTrue, ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the given operands.
spirv::SpecConstantOp getSpecConstant(uint32_t id)
Gets the specialization constant with the given result <id>.
LogicalResult processConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processSelectionMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSelectionMerge instruction with the given operands.
LogicalResult processOpGraphSetOutputARM(ArrayRef< uint32_t > operands)
LogicalResult processDebugLine(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLine instruction with the given operands.
std::string getGraphSymbol(uint32_t id)
Returns a symbol to be used for the graph name with the given result <id>.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition SPIRVTypes.h:147
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
constexpr uint32_t kMagicNumber
SPIR-V magic number.
DenseMap< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
A struct for containing a header block's merge and continue targets.
A struct for containing OpLine instruction information.
A struct that collects the info needed to materialize/emit a GraphConstantARMOp.
A struct that collects the info needed to materialize/emit a SpecConstantOperation op.