MLIR 23.0.0git
SerializeOps.cpp
Go to the documentation of this file.
1//===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the serialization methods for MLIR SPIR-V module ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Serializer.h"
14
20#include "llvm/ADT/DepthFirstIterator.h"
21#include "llvm/ADT/StringExtras.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/FormatVariadic.h"
24
25#define DEBUG_TYPE "spirv-serialization"
26
27using namespace mlir;
28
29namespace {
30// Location::print() emits MLIR syntax such as `loc("name")` or
31// `loc(fused["op", "file":1:2])`. NonSemantic.Graph.DebugInfo stores the
32// source/debug name itself in an OpString, so keep this conversion to the
33// payload string explicit.
34std::string getDebugInfoStringFromLoc(Location loc) {
35 if (auto fileLineCol = dyn_cast<FileLineColLoc>(loc)) {
36 return llvm::formatv("{0}:{1}:{2}", fileLineCol.getFilename(),
37 fileLineCol.getLine(), fileLineCol.getColumn());
38 }
39 if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
40 return nameLoc.getName().str();
41 }
42 if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
43 std::string result;
44 llvm::raw_string_ostream os(result);
45 llvm::interleave(
46 map_range(fusedLoc.getLocations(), getDebugInfoStringFromLoc), os, ";");
47 return result;
48 }
49 return "";
50}
51} // namespace
52
53/// A pre-order depth-first visitor function for processing basic blocks.
54///
55/// Visits the basic blocks starting from the given `headerBlock` in pre-order
56/// depth-first manner and calls `blockHandler` on each block. Skips handling
57/// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
58/// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
59/// successors.
60///
61/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
62/// of blocks in a function must satisfy the rule that blocks appear before
63/// all blocks they dominate." This can be achieved by a pre-order CFG
64/// traversal algorithm. To make the serialization output more logical and
65/// readable to human, we perform depth-first CFG traversal and delay the
66/// serialization of the merge block and the continue block, if exists, until
67/// after all other blocks have been processed.
68static LogicalResult
70 function_ref<LogicalResult(Block *)> blockHandler,
71 bool skipHeader = false, BlockRange skipBlocks = {}) {
72 llvm::df_iterator_default_set<Block *, 4> doneBlocks;
73 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
74
75 for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
76 if (skipHeader && block == headerBlock)
77 continue;
78 if (failed(blockHandler(block)))
79 return failure();
80 }
81 return success();
82}
83
84namespace mlir {
85namespace spirv {
86LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
87 if (auto resultID =
88 prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
89 valueIDMap[op.getResult()] = resultID;
90 if (isa<spirv::TensorArmType>(op.getType()) &&
91 failed(encodeDebugInfoTensorInst(op.getResult())))
92 return failure();
93 return success();
94 }
95 return failure();
96}
97
98LogicalResult Serializer::processConstantCompositeReplicateOp(
99 spirv::EXTConstantCompositeReplicateOp op) {
100 if (uint32_t resultID = prepareConstantCompositeReplicate(
101 op.getLoc(), op.getType(), op.getValue())) {
102 valueIDMap[op.getResult()] = resultID;
103 return success();
104 }
105 return failure();
106}
107
108LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
109 if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
110 /*isSpec=*/true)) {
111 // Emit the OpDecorate instruction for SpecId.
112 if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
113 auto val = static_cast<uint32_t>(specID.getInt());
114 if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
115 return failure();
116 }
117
118 specConstIDMap[op.getSymName()] = resultID;
119 return processName(resultID, op.getSymName());
120 }
121 return failure();
122}
123
124LogicalResult
125Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
126 uint32_t typeID = 0;
127 if (failed(processType(op.getLoc(), op.getType(), typeID))) {
128 return failure();
129 }
130
131 auto resultID = getNextID();
132
134 operands.push_back(typeID);
135 operands.push_back(resultID);
136
137 auto constituents = op.getConstituents();
138
139 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
140 auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
141
142 auto constituentName = constituent.getValue();
143 auto constituentID = getSpecConstID(constituentName);
144
145 if (!constituentID) {
146 return op.emitError("unknown result <id> for specialization constant ")
147 << constituentName;
148 }
149
150 operands.push_back(constituentID);
151 }
152
153 encodeInstructionWithContinuationInto(
154 typesGlobalValues, spirv::Opcode::OpSpecConstantComposite, operands);
155 specConstIDMap[op.getSymName()] = resultID;
156
157 return processName(resultID, op.getSymName());
158}
159
160LogicalResult Serializer::processSpecConstantCompositeReplicateOp(
161 spirv::EXTSpecConstantCompositeReplicateOp op) {
162 uint32_t typeID = 0;
163 if (failed(processType(op.getLoc(), op.getType(), typeID))) {
164 return failure();
165 }
166
167 auto constituent = dyn_cast<FlatSymbolRefAttr>(op.getConstituent());
168 if (!constituent)
169 return op.emitError(
170 "expected flat symbol reference for constituent instead of ")
171 << op.getConstituent();
172
173 StringRef constituentName = constituent.getValue();
174 uint32_t constituentID = getSpecConstID(constituentName);
175 if (!constituentID) {
176 return op.emitError("unknown result <id> for replicated spec constant ")
177 << constituentName;
179
180 uint32_t resultID = getNextID();
181 uint32_t operands[] = {typeID, resultID, constituentID};
182
183 encodeInstructionInto(typesGlobalValues,
184 spirv::Opcode::OpSpecConstantCompositeReplicateEXT,
185 operands);
186
187 specConstIDMap[op.getSymName()] = resultID;
188
189 return processName(resultID, op.getSymName());
190}
191
192LogicalResult
193Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
194 uint32_t typeID = 0;
195 if (failed(processType(op.getLoc(), op.getType(), typeID))) {
196 return failure();
197 }
198
199 auto resultID = getNextID();
200
201 SmallVector<uint32_t, 8> operands;
202 operands.push_back(typeID);
203 operands.push_back(resultID);
204
205 Block &block = op.getRegion().getBlocks().front();
206 Operation &enclosedOp = block.getOperations().front();
207
208 std::string enclosedOpName;
209 llvm::raw_string_ostream rss(enclosedOpName);
210 rss << "Op" << enclosedOp.getName().stripDialect();
211 auto enclosedOpcode = spirv::symbolizeOpcode(enclosedOpName);
212
213 if (!enclosedOpcode) {
214 op.emitError("Couldn't find op code for op ")
215 << enclosedOp.getName().getStringRef();
216 return failure();
217 }
218
219 operands.push_back(static_cast<uint32_t>(*enclosedOpcode));
220
221 // Append operands to the enclosed op to the list of operands.
222 for (Value operand : enclosedOp.getOperands()) {
223 uint32_t id = getValueID(operand);
224 assert(id && "use before def!");
225 operands.push_back(id);
226 }
227
228 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp,
229 operands);
230 valueIDMap[op.getResult()] = resultID;
231
232 return success();
233}
234
235LogicalResult
236Serializer::processGraphConstantARMOp(spirv::GraphConstantARMOp op) {
237 if (uint32_t resultID = prepareGraphConstantId(op.getLoc(), op.getType(),
238 op.getGraphConstantIdAttr())) {
239 valueIDMap[op.getResult()] = resultID;
240 return success();
241 }
242 return failure();
243}
244
245LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
246 auto undefType = op.getType();
247 auto &id = undefValIDMap[undefType];
248 if (!id) {
249 id = getNextID();
250 uint32_t typeID = 0;
251 if (failed(processType(op.getLoc(), undefType, typeID)))
252 return failure();
253 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
254 {typeID, id});
255 }
256 valueIDMap[op.getResult()] = id;
257 return success();
258}
259
260LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
261 for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
262 uint32_t argTypeID = 0;
263 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
264 return failure();
265 }
266 auto argValueID = getNextID();
267
268 // Process decoration attributes of arguments.
269 auto funcOp = cast<FunctionOpInterface>(*op);
270 for (auto argAttr : funcOp.getArgAttrs(idx)) {
271 if (argAttr.getName() != DecorationAttr::name)
272 continue;
273
274 if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
275 if (failed(processDecorationAttr(op->getLoc(), argValueID,
276 decAttr.getValue(), decAttr)))
277 return failure();
278 }
279 }
280
281 valueIDMap[arg] = argValueID;
282 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
283 {argTypeID, argValueID});
284 }
285 return success();
286}
287
288LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
289 LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
290 assert(functionHeader.empty() && functionBody.empty());
291
292 uint32_t fnTypeID = 0;
293 // Generate type of the function.
294 if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
295 return failure();
296
297 // Add the function definition.
298 SmallVector<uint32_t, 4> operands;
299 uint32_t resTypeID = 0;
300 auto resultTypes = op.getFunctionType().getResults();
301 if (resultTypes.size() > 1) {
302 return op.emitError("cannot serialize function with multiple return types");
304 if (failed(processType(op.getLoc(),
305 (resultTypes.empty() ? getVoidType() : resultTypes[0]),
306 resTypeID))) {
307 return failure();
308 }
309 operands.push_back(resTypeID);
310 auto funcID = getOrCreateFunctionID(op.getName());
311 operands.push_back(funcID);
312 operands.push_back(static_cast<uint32_t>(op.getFunctionControl()));
313 operands.push_back(fnTypeID);
314 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
315
316 // Add function name.
317 if (failed(processName(funcID, op.getName()))) {
318 return failure();
319 }
320 // Handle external functions with linkage_attributes(LinkageAttributes)
321 // differently.
322 auto linkageAttr = op.getLinkageAttributes();
323 auto hasImportLinkage =
324 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
325 spirv::LinkageType::Import);
326 if (op.isExternal() && !hasImportLinkage) {
327 return op.emitError(
328 "'spirv.module' cannot contain external functions "
329 "without 'Import' linkage_attributes (LinkageAttributes)");
330 }
331 if (op.isExternal() && hasImportLinkage) {
332 // Add an entry block to set up the block arguments
333 // to match the signature of the function.
334 // This is to generate OpFunctionParameter for functions with
335 // LinkageAttributes.
336 // WARNING: This operation has side-effect, it essentially adds a body
337 // to the func. Hence, making it not external anymore (isExternal()
338 // is going to return false for this function from now on)
339 // Hence, we'll remove the body once we are done with the serialization.
340 op.addEntryBlock();
341 if (failed(processFuncParameter(op)))
342 return failure();
343
344 // Erasing the body of the function destroys arguments, so we need to remove
345 // them from the map to avoid problems when processing invalid values used
346 // as keys. We have already serialized function arguments so we probably can
347 // remove them from the map as external function will not have any uses.
348 for (Value arg : op.getArguments())
349 valueIDMap.erase(arg);
350
351 // Don't need to process the added block, there is nothing to process,
352 // the fake body was added just to get the arguments, remove the body,
353 // since it's use is done.
354 op.eraseBody();
355 } else {
356 if (failed(processFuncParameter(op)))
357 return failure();
359 // Some instructions (e.g., OpVariable) in a function must be in the first
360 // block in the function. These instructions will be put in
361 // functionHeader. Thus, we put the label in functionHeader first, and
362 // omit it from the first block. OpLabel only needs to be added for
363 // functions with body (including empty body). Since, we added a fake body
364 // for functions with 'Import' Linkage attributes, these functions are
365 // essentially function delcaration, so they should not have OpLabel and a
366 // terminating instruction. That's why we skipped it for those functions.
367 encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
368 {getOrCreateBlockID(&op.front())});
369 if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
370 return failure();
371 if (failed(visitInPrettyBlockOrder(
372 &op.front(), [&](Block *block) { return processBlock(block); },
373 /*skipHeader=*/true))) {
374 return failure();
375 }
376
377 // There might be OpPhi instructions who have value references needing to
378 // fix.
379 for (const auto &deferredValue : deferredPhiValues) {
380 Value value = deferredValue.first;
381 uint32_t id = getValueID(value);
382 LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
383 << " to id = " << id << '\n');
384 assert(id && "OpPhi references undefined value!");
385 for (size_t offset : deferredValue.second)
386 functionBody[offset] = id;
387 }
388 deferredPhiValues.clear();
389 }
390 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
391 << "' --\n");
392 // Insert Decorations based on Function Attributes.
393 // Only attributes we should be considering for decoration are the
394 // ::mlir::spirv::Decoration attributes.
395
396 for (auto attr : op->getAttrs()) {
397 // Only generate OpDecorate op for spirv::Decoration attributes.
398 auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
399 llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
400 /*capitalizeFirst=*/true));
401 if (isValidDecoration != std::nullopt) {
402 if (failed(processDecoration(op.getLoc(), funcID, attr))) {
403 return failure();
404 }
405 }
406 }
407 // Insert OpFunctionEnd.
408 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {});
409
410 functions.append(functionHeader.begin(), functionHeader.end());
411 functions.append(functionBody.begin(), functionBody.end());
412 functionHeader.clear();
413 functionBody.clear();
414
415 return success();
416}
417
418uint32_t Serializer::encodeDebugStringInst(StringRef str) {
419 uint32_t stringID = debugStringIDMap.lookup(str);
420 if (stringID > 0) {
421 return stringID;
422 }
423
424 SmallVector<uint32_t, 2> operands;
425 stringID = getNextID();
426 debugStringIDMap[str] = stringID;
427 operands.push_back(stringID);
428 spirv::encodeStringLiteralInto(operands, str);
429 encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
430
431 return stringID;
432}
433
434LogicalResult Serializer::encodeDebugInfoGraphInst(spirv::GraphARMOp op,
435 uint32_t &debugGraphID) {
436 if (!options.emitDebugInfo)
437 return success();
438
439 uint32_t voidTypeID = 0;
440 if (failed(processType(op.getLoc(), getVoidType(), voidTypeID)))
441 return failure();
442
443 uint32_t stringID =
444 encodeDebugStringInst(getDebugInfoStringFromLoc(op.getLoc()));
445
447 operands.push_back(voidTypeID);
448 debugGraphID = getNextID();
449 operands.push_back(debugGraphID);
450 uint32_t graphID = getOrCreateFunctionID(op.getName());
451 operands.push_back(graphID);
452 operands.push_back(stringID);
453
454 if (failed(encodeExtensionInstruction(
455 nullptr, extDebugInfo,
456 static_cast<uint32_t>(GraphDebugInfoExtInst::DebugGraph), operands,
457 graphsDebugInfo)))
458 return failure();
459
460 return success();
461}
462
463LogicalResult
464Serializer::encodeDebugInfoOperationInst(uint32_t debugGraphID,
465 const SetVector<Operation *> &ops) {
466 if (!options.emitDebugInfo)
467 return success();
468
469 if (ops.empty())
470 return success();
471
472 SmallVector<uint32_t, 4> instructionIDs;
473 for (Operation *op : ops)
474 for (OpResult result : op->getOpResults())
475 instructionIDs.push_back(getValueID(result));
476
477 if (instructionIDs.empty())
478 return success();
479
480 uint32_t voidTypeID = 0;
481 if (failed(processType(ops[0]->getLoc(), getVoidType(), voidTypeID)))
482 return failure();
483
484 uint32_t stringID =
485 encodeDebugStringInst(getDebugInfoStringFromLoc(ops[0]->getLoc()));
486
488 operands.push_back(voidTypeID);
489 operands.push_back(getNextID());
490 operands.push_back(debugGraphID);
491 operands.push_back(stringID);
492 operands.append(instructionIDs);
493
494 if (failed(encodeExtensionInstruction(
495 nullptr, extDebugInfo,
496 static_cast<uint32_t>(GraphDebugInfoExtInst::DebugOperation),
497 operands, graphsDebugInfo)))
498 return failure();
499
500 return success();
501}
502
503LogicalResult Serializer::encodeDebugInfoTensorInst(Value tensor) {
504 if (!options.emitDebugInfo)
505 return success();
506
507 uint32_t voidTypeID = 0;
508 if (failed(processType(tensor.getLoc(), getVoidType(), voidTypeID)))
509 return failure();
510
511 uint32_t tensorID = valueIDMap.lookup(tensor);
512 if (tensorID == 0)
513 return success();
514
515 uint32_t stringID =
516 encodeDebugStringInst(getDebugInfoStringFromLoc(tensor.getLoc()));
517
519 operands.push_back(voidTypeID);
520 operands.push_back(getNextID());
521 operands.push_back(tensorID);
522 operands.push_back(stringID);
523
524 if (failed(encodeExtensionInstruction(
525 nullptr, extDebugInfo,
526 static_cast<uint32_t>(GraphDebugInfoExtInst::DebugTensor), operands,
527 graphsDebugInfo)))
528 return failure();
529
530 return success();
531}
532
533LogicalResult Serializer::processGraphARMOp(spirv::GraphARMOp op) {
534 if (op.getNumResults() < 1) {
535 return op.emitError("cannot serialize graph with no return types");
536 }
537
538 LLVM_DEBUG(llvm::dbgs() << "-- start graph '" << op.getName() << "' --\n");
539 assert(functionHeader.empty() && functionBody.empty());
540
541 uint32_t funcID = getOrCreateFunctionID(op.getName());
542 uint32_t fnTypeID = 0;
543 // Generate type of the function.
544 if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
545 return failure();
546 encodeInstructionInto(functionHeader, spirv::Opcode::OpGraphARM,
547 {fnTypeID, funcID});
548
549 // Declare the parameters.
550 for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
551 uint32_t argTypeID = 0;
552 SmallVector<uint32_t, 3> inputOperands;
553
554 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
555 return failure();
556 }
557
558 uint32_t argValueID = getNextID();
559 valueIDMap[arg] = argValueID;
560
561 auto attr = IntegerAttr::get(IntegerType::get(op.getContext(), 32), idx);
562 uint32_t indexID = prepareConstantInt(op.getLoc(), attr, false);
563
564 inputOperands.push_back(argTypeID);
565 inputOperands.push_back(argValueID);
566 inputOperands.push_back(indexID);
567
568 encodeInstructionInto(functionHeader, spirv::Opcode::OpGraphInputARM,
569 inputOperands);
570
571 if (failed(encodeDebugInfoTensorInst(arg)))
572 return failure();
573 }
574
575 if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
576 return failure();
578 &op.front(), [&](Block *block) { return processBlock(block); },
579 /*skipHeader=*/true))) {
580 return failure();
581 }
583 LLVM_DEBUG(llvm::dbgs() << "-- completed graph '" << op.getName()
584 << "' --\n");
585 // Insert OpGraphEndARM.
586 encodeInstructionInto(functionBody, spirv::Opcode::OpGraphEndARM, {});
587
588 llvm::append_range(graphs, functionHeader);
589 llvm::append_range(graphs, functionBody);
590 functionHeader.clear();
591 functionBody.clear();
592
593 uint32_t debugGraphID = 0;
594 if (failed(encodeDebugInfoGraphInst(op, debugGraphID)))
595 return failure();
596
597 for (const auto &debugEntry : tosaOpsMap[funcID]) {
598 if (failed(encodeDebugInfoOperationInst(debugGraphID, debugEntry.second)))
599 return failure();
600 }
601
602 return success();
603}
604
605LogicalResult
606Serializer::processGraphEntryPointARMOp(spirv::GraphEntryPointARMOp op) {
607 SmallVector<uint32_t, 4> operands;
608 StringRef graph = op.getFn();
609 // Add the graph <id>.
610 uint32_t graphID = getOrCreateFunctionID(graph);
611 operands.push_back(graphID);
612 // Add the name of the graph.
613 spirv::encodeStringLiteralInto(operands, graph);
614
615 // Add the interface values.
616 if (ArrayAttr interface = op.getInterface()) {
617 for (Attribute var : interface.getValue()) {
618 StringRef value = cast<FlatSymbolRefAttr>(var).getValue();
619 if (uint32_t id = getVariableID(value)) {
620 operands.push_back(id);
621 } else {
622 return op.emitError(
623 "referencing undefined global variable."
624 "spirv.GraphEntryPointARM is at the end of spirv.module. All "
625 "referenced variables should already be defined");
626 }
627 }
628 }
629 encodeInstructionInto(graphs, spirv::Opcode::OpGraphEntryPointARM, operands);
630 return success();
631}
632
633LogicalResult
634Serializer::processGraphOutputsARMOp(spirv::GraphOutputsARMOp op) {
635 for (auto [idx, value] : llvm::enumerate(op->getOperands())) {
636 SmallVector<uint32_t, 2> outputOperands;
638 Type resType = value.getType();
639 uint32_t resTypeID = 0;
640 if (failed(processType(op.getLoc(), resType, resTypeID))) {
641 return failure();
642 }
643
644 uint32_t outputID = getValueID(value);
645 auto attr = IntegerAttr::get(IntegerType::get(op.getContext(), 32), idx);
646 uint32_t indexID = prepareConstantInt(op.getLoc(), attr, false);
647
648 outputOperands.push_back(outputID);
649 outputOperands.push_back(indexID);
650
651 if (failed(encodeDebugInfoTensorInst(value)))
652 return failure();
653
654 encodeInstructionInto(functionBody, spirv::Opcode::OpGraphSetOutputARM,
655 outputOperands);
656 }
657 return success();
658}
659
660LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
661 SmallVector<uint32_t, 4> operands;
662 SmallVector<StringRef, 2> elidedAttrs;
663 uint32_t resultID = 0;
664 uint32_t resultTypeID = 0;
665 if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
666 return failure();
667 }
668 operands.push_back(resultTypeID);
669 resultID = getNextID();
670 valueIDMap[op.getResult()] = resultID;
671 operands.push_back(resultID);
672 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
673 if (attr) {
674 operands.push_back(
675 static_cast<uint32_t>(cast<spirv::StorageClassAttr>(attr).getValue()));
676 }
677 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
678 for (auto arg : op.getODSOperands(0)) {
679 auto argID = getValueID(arg);
680 if (!argID) {
681 return emitError(op.getLoc(), "operand 0 has a use before def");
682 }
683 operands.push_back(argID);
684 }
685 if (failed(emitDebugLine(functionHeader, op.getLoc())))
686 return failure();
687 encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands);
688 for (auto attr : op->getAttrs()) {
689 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
690 return attr.getName() == elided;
691 })) {
692 continue;
693 }
694 if (failed(processDecoration(op.getLoc(), resultID, attr))) {
695 return failure();
696 }
697 }
698 return success();
699}
700
701LogicalResult
702Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
703 // Get TypeID.
704 uint32_t resultTypeID = 0;
705 SmallVector<StringRef, 4> elidedAttrs;
706 if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
707 return failure();
708 }
709
710 elidedAttrs.push_back("type");
711 SmallVector<uint32_t, 4> operands;
712 operands.push_back(resultTypeID);
713 auto resultID = getNextID();
714
715 // Encode the name.
716 auto varName = varOp.getSymName();
717 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
718 if (failed(processName(resultID, varName))) {
719 return failure();
720 }
721 globalVarIDMap[varName] = resultID;
722 operands.push_back(resultID);
723
724 // Encode StorageClass.
725 operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
726
727 // Encode initialization.
728 StringRef initAttrName = varOp.getInitializerAttrName().getValue();
729 if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
730 uint32_t initializerID = 0;
731 auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName);
733 varOp->getParentOp(), initRef.getAttr());
734
735 // Check if initializer is GlobalVariable or SpecConstant* cases.
736 if (isa<spirv::GlobalVariableOp>(initOp))
737 initializerID = getVariableID(*initSymbolName);
738 else
739 initializerID = getSpecConstID(*initSymbolName);
740
741 if (!initializerID)
742 return emitError(varOp.getLoc(),
743 "invalid usage of undefined variable as initializer");
744
745 operands.push_back(initializerID);
746 elidedAttrs.push_back(initAttrName);
747 }
748
749 if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
750 return failure();
751 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
752 elidedAttrs.push_back(initAttrName);
753
754 // Encode decorations.
755 for (auto attr : varOp->getAttrs()) {
756 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
757 return attr.getName() == elided;
758 })) {
759 continue;
760 }
761 if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
762 return failure();
763 }
764 }
765 return success();
766}
767
768LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
769 // Assign <id>s to all blocks so that branches inside the SelectionOp can
770 // resolve properly.
771 auto &body = selectionOp.getBody();
772 for (Block &block : body)
773 getOrCreateBlockID(&block);
774
775 auto *headerBlock = selectionOp.getHeaderBlock();
776 auto *mergeBlock = selectionOp.getMergeBlock();
777 auto headerID = getBlockID(headerBlock);
778 auto mergeID = getBlockID(mergeBlock);
779 auto loc = selectionOp.getLoc();
780
781 // Before we do anything replace results of the selection operation with
782 // values yielded (with `mlir.merge`) from inside the region. The selection op
783 // is being flattened so we do not have to worry about values being defined
784 // inside a region and used outside it anymore.
785 auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
786 assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
787 for (unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i)
788 selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
789
790 // This SelectionOp is in some MLIR block with preceding and following ops. In
791 // the binary format, it should reside in separate SPIR-V blocks from its
792 // preceding and following ops. So we need to emit unconditional branches to
793 // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal
794 // flow afterwards.
795 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
796
797 // Emit the selection header block, which dominates all other blocks, first.
798 // We need to emit an OpSelectionMerge instruction before the selection header
799 // block's terminator.
800 auto emitSelectionMerge = [&]() {
801 if (failed(emitDebugLine(functionBody, loc)))
802 return failure();
803 lastProcessedWasMergeInst = true;
805 functionBody, spirv::Opcode::OpSelectionMerge,
806 {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())});
807 return success();
808 };
809 if (failed(
810 processBlock(headerBlock, /*omitLabel=*/false, emitSelectionMerge)))
811 return failure();
812
813 // Process all blocks with a depth-first visitor starting from the header
814 // block. The selection header block and merge block are skipped by this
815 // visitor.
817 headerBlock, [&](Block *block) { return processBlock(block); },
818 /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
819 return failure();
820
821 // There is nothing to do for the merge block in the selection, which just
822 // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel
823 // instruction to start a new SPIR-V block for ops following this SelectionOp.
824 // The block should use the <id> for the merge block.
825 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
826
827 // We do not process the mergeBlock but we still need to generate phi
828 // functions from its block arguments.
829 if (failed(emitPhiForBlockArguments(mergeBlock)))
830 return failure();
831
832 LLVM_DEBUG(llvm::dbgs() << "done merge ");
833 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
834 LLVM_DEBUG(llvm::dbgs() << "\n");
835 return success();
836}
837
838LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
839 // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
840 // properly. We don't need to assign for the entry block, which is just for
841 // satisfying MLIR region's structural requirement.
842 auto &body = loopOp.getBody();
843 for (Block &block : llvm::drop_begin(body))
844 getOrCreateBlockID(&block);
845
846 auto *headerBlock = loopOp.getHeaderBlock();
847 auto *continueBlock = loopOp.getContinueBlock();
848 auto *mergeBlock = loopOp.getMergeBlock();
849 auto headerID = getBlockID(headerBlock);
850 auto continueID = getBlockID(continueBlock);
851 auto mergeID = getBlockID(mergeBlock);
852 auto loc = loopOp.getLoc();
853
854 // Before we do anything replace results of the selection operation with
855 // values yielded (with `mlir.merge`) from inside the region.
856 auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
857 assert(loopOp.getNumResults() == mergeOp.getNumOperands());
858 for (unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i)
859 loopOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
860
861 // This LoopOp is in some MLIR block with preceding and following ops. In the
862 // binary format, it should reside in separate SPIR-V blocks from its
863 // preceding and following ops. So we need to emit unconditional branches to
864 // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
865 // afterwards.
866 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
867
868 // LoopOp's entry block is just there for satisfying MLIR's structural
869 // requirements so we omit it and start serialization from the loop header
870 // block.
871
872 // Emit the loop header block, which dominates all other blocks, first. We
873 // need to emit an OpLoopMerge instruction before the loop header block's
874 // terminator.
875 auto emitLoopMerge = [&]() {
876 if (failed(emitDebugLine(functionBody, loc)))
877 return failure();
878 lastProcessedWasMergeInst = true;
880 functionBody, spirv::Opcode::OpLoopMerge,
881 {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())});
882 return success();
883 };
884 if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
885 return failure();
886
887 // Process all blocks with a depth-first visitor starting from the header
888 // block. The loop header block, loop continue block, and loop merge block are
889 // skipped by this visitor and handled later in this function.
891 headerBlock, [&](Block *block) { return processBlock(block); },
892 /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
893 return failure();
894
895 // We have handled all other blocks. Now get to the loop continue block.
896 if (failed(processBlock(continueBlock)))
897 return failure();
898
899 // There is nothing to do for the merge block in the loop, which just contains
900 // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction
901 // to start a new SPIR-V block for ops following this LoopOp. The block should
902 // use the <id> for the merge block.
903 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
904 LLVM_DEBUG(llvm::dbgs() << "done merge ");
905 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
906 LLVM_DEBUG(llvm::dbgs() << "\n");
907 return success();
908}
909
910LogicalResult Serializer::processBranchConditionalOp(
911 spirv::BranchConditionalOp condBranchOp) {
912 auto conditionID = getValueID(condBranchOp.getCondition());
913 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
914 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
915 SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
917 if (auto weights = condBranchOp.getBranchWeights()) {
918 for (auto val : weights->getValue())
919 arguments.push_back(cast<IntegerAttr>(val).getInt());
920 }
922 if (failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
923 return failure();
924 encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
925 arguments);
926 return success();
927}
928
929LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
930 if (failed(emitDebugLine(functionBody, branchOp.getLoc())))
931 return failure();
932 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
933 {getOrCreateBlockID(branchOp.getTarget())});
934 return success();
935}
937LogicalResult Serializer::processSwitchOp(spirv::SwitchOp switchOp) {
938 uint32_t selectorID = getValueID(switchOp.getSelector());
939 uint32_t defaultLabelID = getOrCreateBlockID(switchOp.getDefaultTarget());
940 SmallVector<uint32_t> arguments{selectorID, defaultLabelID};
942 std::optional<mlir::DenseIntElementsAttr> literals = switchOp.getLiterals();
943 BlockRange targets = switchOp.getTargets();
944 if (literals) {
945 for (auto [literal, target] : llvm::zip_equal(*literals, targets)) {
946 arguments.push_back(literal.getLimitedValue());
947 uint32_t targetLabelID = getOrCreateBlockID(target);
948 arguments.push_back(targetLabelID);
949 }
950 }
952 if (failed(emitDebugLine(functionBody, switchOp.getLoc())))
953 return failure();
954 encodeInstructionInto(functionBody, spirv::Opcode::OpSwitch, arguments);
955 return success();
957
958LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
959 auto varName = addressOfOp.getVariable();
960 auto variableID = getVariableID(varName);
961 if (!variableID) {
962 return addressOfOp.emitError("unknown result <id> for variable ")
963 << varName;
964 }
965 valueIDMap[addressOfOp.getPointer()] = variableID;
966 return success();
967}
968
969LogicalResult
970Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
971 auto constName = referenceOfOp.getSpecConst();
972 auto constID = getSpecConstID(constName);
973 if (!constID) {
974 return referenceOfOp.emitError(
975 "unknown result <id> for specialization constant ")
976 << constName;
977 }
978 valueIDMap[referenceOfOp.getReference()] = constID;
979 return success();
980}
982template <>
983LogicalResult
984Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
986 // Add the ExecutionModel.
987 operands.push_back(static_cast<uint32_t>(op.getExecutionModel()));
988 // Add the function <id>.
989 auto funcID = getFunctionID(op.getFn());
990 if (!funcID) {
991 return op.emitError("missing <id> for function ")
992 << op.getFn()
993 << "; function needs to be defined before spirv.EntryPoint is "
994 "serialized";
995 }
996 operands.push_back(funcID);
997 // Add the name of the function.
998 spirv::encodeStringLiteralInto(operands, op.getFn());
999
1000 // Add the interface values.
1001 if (auto interface = op.getInterface()) {
1002 for (auto var : interface.getValue()) {
1003 auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
1004 if (!id) {
1005 return op.emitError(
1006 "referencing undefined global variable."
1007 "spirv.EntryPoint is at the end of spirv.module. All "
1008 "referenced variables should already be defined");
1009 }
1010 operands.push_back(id);
1012 }
1013 encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands);
1014 return success();
1015}
1017template <>
1018LogicalResult
1019Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
1020 SmallVector<uint32_t, 4> operands;
1021 // Add the function <id>.
1022 auto funcID = getFunctionID(op.getFn());
1023 if (!funcID) {
1024 return op.emitError("missing <id> for function ")
1025 << op.getFn()
1026 << "; function needs to be serialized before ExecutionModeOp is "
1027 "serialized";
1028 }
1029 operands.push_back(funcID);
1030 // Add the ExecutionMode.
1031 operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
1032
1033 // Serialize values if any.
1034 auto values = op.getValues();
1035 if (values) {
1036 for (auto &intVal : values.getValue()) {
1037 operands.push_back(static_cast<uint32_t>(
1038 cast<IntegerAttr>(intVal).getValue().getZExtValue()));
1039 }
1040 }
1041 encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
1042 operands);
1043 return success();
1044}
1045
1046template <>
1047LogicalResult
1048Serializer::processOp<spirv::ExecutionModeIdOp>(spirv::ExecutionModeIdOp op) {
1049 SmallVector<uint32_t, 4> operands;
1050 // Add the function <id>.
1051 uint32_t funcID = getFunctionID(op.getFn());
1052 if (!funcID)
1053 return op.emitError("missing <id> for function ")
1054 << op.getFn()
1055 << "; function needs to be serialized before ExecutionModeIdOp is "
1056 "serialized";
1057
1058 operands.push_back(funcID);
1059 operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
1060
1061 for (Attribute refVal : op.getValues().getValue()) {
1062 uint32_t id = getSpecConstID(cast<FlatSymbolRefAttr>(refVal).getValue());
1063 if (!id)
1064 return op.emitError("unknown <id> for specialization constant ")
1065 << cast<FlatSymbolRefAttr>(refVal).getValue();
1067 operands.push_back(id);
1068 }
1069 encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionModeId,
1070 operands);
1071 return success();
1072}
1073
1074template <>
1075LogicalResult
1076Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
1077 auto funcName = op.getCallee();
1078 uint32_t resTypeID = 0;
1079
1080 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
1081 if (failed(processType(op.getLoc(), resultTy, resTypeID)))
1082 return failure();
1083
1084 auto funcID = getOrCreateFunctionID(funcName);
1085 auto funcCallID = getNextID();
1086 SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
1087
1088 for (auto value : op.getArguments()) {
1089 auto valueID = getValueID(value);
1090 assert(valueID && "cannot find a value for spirv.FunctionCall");
1091 operands.push_back(valueID);
1092 }
1093
1094 if (!isa<NoneType>(resultTy))
1095 valueIDMap[op.getResult(0)] = funcCallID;
1096
1097 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands);
1098 return success();
1099}
1100
1101template <>
1102LogicalResult
1103Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
1104 SmallVector<uint32_t, 4> operands;
1106
1107 for (Value operand : op->getOperands()) {
1108 auto id = getValueID(operand);
1109 assert(id && "use before def!");
1110 operands.push_back(id);
1111 }
1112
1113 StringAttr memoryAccess = op.getMemoryAccessAttrName();
1114 if (auto attr = op->getAttr(memoryAccess)) {
1115 operands.push_back(
1116 static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
1117 }
1118
1119 elidedAttrs.push_back(memoryAccess.strref());
1121 StringAttr alignment = op.getAlignmentAttrName();
1122 if (auto attr = op->getAttr(alignment)) {
1123 operands.push_back(static_cast<uint32_t>(
1124 cast<IntegerAttr>(attr).getValue().getZExtValue()));
1126
1127 elidedAttrs.push_back(alignment.strref());
1128
1129 StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
1130 if (auto attr = op->getAttr(sourceMemoryAccess)) {
1131 operands.push_back(
1132 static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
1133 }
1134
1135 elidedAttrs.push_back(sourceMemoryAccess.strref());
1136
1137 StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
1138 if (auto attr = op->getAttr(sourceAlignment)) {
1139 operands.push_back(static_cast<uint32_t>(
1140 cast<IntegerAttr>(attr).getValue().getZExtValue()));
1141 }
1142
1143 elidedAttrs.push_back(sourceAlignment.strref());
1144 if (failed(emitDebugLine(functionBody, op.getLoc())))
1145 return failure();
1146 encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
1147
1148 return success();
1149}
1150template <>
1151LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
1152 spirv::GenericCastToPtrExplicitOp op) {
1153 SmallVector<uint32_t, 4> operands;
1154 Type resultTy;
1155 Location loc = op->getLoc();
1156 uint32_t resultTypeID = 0;
1157 uint32_t resultID = 0;
1158 resultTy = op->getResult(0).getType();
1159 if (failed(processType(loc, resultTy, resultTypeID)))
1160 return failure();
1161 operands.push_back(resultTypeID);
1162
1163 resultID = getNextID();
1164 operands.push_back(resultID);
1165 valueIDMap[op->getResult(0)] = resultID;
1166
1167 for (Value operand : op->getOperands())
1168 operands.push_back(getValueID(operand));
1169 spirv::StorageClass resultStorage =
1170 cast<spirv::PointerType>(resultTy).getStorageClass();
1171 operands.push_back(static_cast<uint32_t>(resultStorage));
1172 encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit,
1173 operands);
1174 return success();
1176
1177// Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
1178// various Serializer::processOp<...>() specializations.
1179#define GET_SERIALIZATION_FNS
1180#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
1181
1182} // namespace spirv
1183} // namespace mlir
return success()
ArrayAttr()
static llvm::ManagedStatic< PassManagerOptions > options
static LogicalResult visitInPrettyBlockOrder(Block *headerBlock, function_ref< LogicalResult(Block *)> blockHandler, bool skipHeader=false, BlockRange skipBlocks={})
A pre-order depth-first visitor function for processing basic blocks.
static void printBlock(llvm::raw_ostream &os, Block *block, OpPrintingFlags &flags)
Definition Unit.cpp:37
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType & getOperations()
Definition Block.h:147
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This is a value defined by a result of an operation.
Definition Value.h:454
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringRef stripDialect() const
Return the operation name with dialect name stripped, if it has one.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
constexpr llvm::StringLiteral extDebugInfo
Extension set name for non-semantic graph debug info.
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
constexpr StringRef attributeName()
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147