MLIR  20.0.0git
SerializeOps.cpp
Go to the documentation of this file.
1 //===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the serialization methods for MLIR SPIR-V module ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Serializer.h"
14 
19 #include "llvm/ADT/DepthFirstIterator.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/Support/Debug.h"
22 
23 #define DEBUG_TYPE "spirv-serialization"
24 
25 using namespace mlir;
26 
27 /// A pre-order depth-first visitor function for processing basic blocks.
28 ///
29 /// Visits the basic blocks starting from the given `headerBlock` in pre-order
30 /// depth-first manner and calls `blockHandler` on each block. Skips handling
31 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
32 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
33 /// successors.
34 ///
35 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
36 /// of blocks in a function must satisfy the rule that blocks appear before
37 /// all blocks they dominate." This can be achieved by a pre-order CFG
38 /// traversal algorithm. To make the serialization output more logical and
39 /// readable to human, we perform depth-first CFG traversal and delay the
40 /// serialization of the merge block and the continue block, if exists, until
41 /// after all other blocks have been processed.
42 static LogicalResult
44  function_ref<LogicalResult(Block *)> blockHandler,
45  bool skipHeader = false, BlockRange skipBlocks = {}) {
46  llvm::df_iterator_default_set<Block *, 4> doneBlocks;
47  doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
48 
49  for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
50  if (skipHeader && block == headerBlock)
51  continue;
52  if (failed(blockHandler(block)))
53  return failure();
54  }
55  return success();
56 }
57 
58 namespace mlir {
59 namespace spirv {
60 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
61  if (auto resultID =
62  prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
63  valueIDMap[op.getResult()] = resultID;
64  return success();
65  }
66  return failure();
67 }
68 
69 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
70  if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
71  /*isSpec=*/true)) {
72  // Emit the OpDecorate instruction for SpecId.
73  if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
74  auto val = static_cast<uint32_t>(specID.getInt());
75  if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
76  return failure();
77  }
78 
79  specConstIDMap[op.getSymName()] = resultID;
80  return processName(resultID, op.getSymName());
81  }
82  return failure();
83 }
84 
85 LogicalResult
86 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
87  uint32_t typeID = 0;
88  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
89  return failure();
90  }
91 
92  auto resultID = getNextID();
93 
94  SmallVector<uint32_t, 8> operands;
95  operands.push_back(typeID);
96  operands.push_back(resultID);
97 
98  auto constituents = op.getConstituents();
99 
100  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
101  auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
102 
103  auto constituentName = constituent.getValue();
104  auto constituentID = getSpecConstID(constituentName);
105 
106  if (!constituentID) {
107  return op.emitError("unknown result <id> for specialization constant ")
108  << constituentName;
109  }
110 
111  operands.push_back(constituentID);
112  }
113 
114  encodeInstructionInto(typesGlobalValues,
115  spirv::Opcode::OpSpecConstantComposite, operands);
116  specConstIDMap[op.getSymName()] = resultID;
117 
118  return processName(resultID, op.getSymName());
119 }
120 
121 LogicalResult
122 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
123  uint32_t typeID = 0;
124  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
125  return failure();
126  }
127 
128  auto resultID = getNextID();
129 
130  SmallVector<uint32_t, 8> operands;
131  operands.push_back(typeID);
132  operands.push_back(resultID);
133 
134  Block &block = op.getRegion().getBlocks().front();
135  Operation &enclosedOp = block.getOperations().front();
136 
137  std::string enclosedOpName;
138  llvm::raw_string_ostream rss(enclosedOpName);
139  rss << "Op" << enclosedOp.getName().stripDialect();
140  auto enclosedOpcode = spirv::symbolizeOpcode(rss.str());
141 
142  if (!enclosedOpcode) {
143  op.emitError("Couldn't find op code for op ")
144  << enclosedOp.getName().getStringRef();
145  return failure();
146  }
147 
148  operands.push_back(static_cast<uint32_t>(*enclosedOpcode));
149 
150  // Append operands to the enclosed op to the list of operands.
151  for (Value operand : enclosedOp.getOperands()) {
152  uint32_t id = getValueID(operand);
153  assert(id && "use before def!");
154  operands.push_back(id);
155  }
156 
157  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp,
158  operands);
159  valueIDMap[op.getResult()] = resultID;
160 
161  return success();
162 }
163 
164 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
165  auto undefType = op.getType();
166  auto &id = undefValIDMap[undefType];
167  if (!id) {
168  id = getNextID();
169  uint32_t typeID = 0;
170  if (failed(processType(op.getLoc(), undefType, typeID)))
171  return failure();
172  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
173  {typeID, id});
174  }
175  valueIDMap[op.getResult()] = id;
176  return success();
177 }
178 
179 LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
180  for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
181  uint32_t argTypeID = 0;
182  if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
183  return failure();
184  }
185  auto argValueID = getNextID();
186 
187  // Process decoration attributes of arguments.
188  auto funcOp = cast<FunctionOpInterface>(*op);
189  for (auto argAttr : funcOp.getArgAttrs(idx)) {
190  if (argAttr.getName() != DecorationAttr::name)
191  continue;
192 
193  if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
194  if (failed(processDecorationAttr(op->getLoc(), argValueID,
195  decAttr.getValue(), decAttr)))
196  return failure();
197  }
198  }
199 
200  valueIDMap[arg] = argValueID;
201  encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
202  {argTypeID, argValueID});
203  }
204  return success();
205 }
206 
207 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
208  LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
209  assert(functionHeader.empty() && functionBody.empty());
210 
211  uint32_t fnTypeID = 0;
212  // Generate type of the function.
213  if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
214  return failure();
215 
216  // Add the function definition.
217  SmallVector<uint32_t, 4> operands;
218  uint32_t resTypeID = 0;
219  auto resultTypes = op.getFunctionType().getResults();
220  if (resultTypes.size() > 1) {
221  return op.emitError("cannot serialize function with multiple return types");
222  }
223  if (failed(processType(op.getLoc(),
224  (resultTypes.empty() ? getVoidType() : resultTypes[0]),
225  resTypeID))) {
226  return failure();
227  }
228  operands.push_back(resTypeID);
229  auto funcID = getOrCreateFunctionID(op.getName());
230  operands.push_back(funcID);
231  operands.push_back(static_cast<uint32_t>(op.getFunctionControl()));
232  operands.push_back(fnTypeID);
233  encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
234 
235  // Add function name.
236  if (failed(processName(funcID, op.getName()))) {
237  return failure();
238  }
239  // Handle external functions with linkage_attributes(LinkageAttributes)
240  // differently.
241  auto linkageAttr = op.getLinkageAttributes();
242  auto hasImportLinkage =
243  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
244  spirv::LinkageType::Import);
245  if (op.isExternal() && !hasImportLinkage) {
246  return op.emitError(
247  "'spirv.module' cannot contain external functions "
248  "without 'Import' linkage_attributes (LinkageAttributes)");
249  }
250  if (op.isExternal() && hasImportLinkage) {
251  // Add an entry block to set up the block arguments
252  // to match the signature of the function.
253  // This is to generate OpFunctionParameter for functions with
254  // LinkageAttributes.
255  // WARNING: This operation has side-effect, it essentially adds a body
256  // to the func. Hence, making it not external anymore (isExternal()
257  // is going to return false for this function from now on)
258  // Hence, we'll remove the body once we are done with the serialization.
259  op.addEntryBlock();
260  if (failed(processFuncParameter(op)))
261  return failure();
262  // Don't need to process the added block, there is nothing to process,
263  // the fake body was added just to get the arguments, remove the body,
264  // since it's use is done.
265  op.eraseBody();
266  } else {
267  if (failed(processFuncParameter(op)))
268  return failure();
269 
270  // Some instructions (e.g., OpVariable) in a function must be in the first
271  // block in the function. These instructions will be put in
272  // functionHeader. Thus, we put the label in functionHeader first, and
273  // omit it from the first block. OpLabel only needs to be added for
274  // functions with body (including empty body). Since, we added a fake body
275  // for functions with 'Import' Linkage attributes, these functions are
276  // essentially function delcaration, so they should not have OpLabel and a
277  // terminating instruction. That's why we skipped it for those functions.
278  encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
279  {getOrCreateBlockID(&op.front())});
280  if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
281  return failure();
282  if (failed(visitInPrettyBlockOrder(
283  &op.front(), [&](Block *block) { return processBlock(block); },
284  /*skipHeader=*/true))) {
285  return failure();
286  }
287 
288  // There might be OpPhi instructions who have value references needing to
289  // fix.
290  for (const auto &deferredValue : deferredPhiValues) {
291  Value value = deferredValue.first;
292  uint32_t id = getValueID(value);
293  LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
294  << " to id = " << id << '\n');
295  assert(id && "OpPhi references undefined value!");
296  for (size_t offset : deferredValue.second)
297  functionBody[offset] = id;
298  }
299  deferredPhiValues.clear();
300  }
301  LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
302  << "' --\n");
303  // Insert Decorations based on Function Attributes.
304  // Only attributes we should be considering for decoration are the
305  // ::mlir::spirv::Decoration attributes.
306 
307  for (auto attr : op->getAttrs()) {
308  // Only generate OpDecorate op for spirv::Decoration attributes.
309  auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
310  llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
311  /*capitalizeFirst=*/true));
312  if (isValidDecoration != std::nullopt) {
313  if (failed(processDecoration(op.getLoc(), funcID, attr))) {
314  return failure();
315  }
316  }
317  }
318  // Insert OpFunctionEnd.
319  encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {});
320 
321  functions.append(functionHeader.begin(), functionHeader.end());
322  functions.append(functionBody.begin(), functionBody.end());
323  functionHeader.clear();
324  functionBody.clear();
325 
326  return success();
327 }
328 
329 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
330  SmallVector<uint32_t, 4> operands;
331  SmallVector<StringRef, 2> elidedAttrs;
332  uint32_t resultID = 0;
333  uint32_t resultTypeID = 0;
334  if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
335  return failure();
336  }
337  operands.push_back(resultTypeID);
338  resultID = getNextID();
339  valueIDMap[op.getResult()] = resultID;
340  operands.push_back(resultID);
341  auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
342  if (attr) {
343  operands.push_back(
344  static_cast<uint32_t>(cast<spirv::StorageClassAttr>(attr).getValue()));
345  }
346  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
347  for (auto arg : op.getODSOperands(0)) {
348  auto argID = getValueID(arg);
349  if (!argID) {
350  return emitError(op.getLoc(), "operand 0 has a use before def");
351  }
352  operands.push_back(argID);
353  }
354  if (failed(emitDebugLine(functionHeader, op.getLoc())))
355  return failure();
356  encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands);
357  for (auto attr : op->getAttrs()) {
358  if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
359  return attr.getName() == elided;
360  })) {
361  continue;
362  }
363  if (failed(processDecoration(op.getLoc(), resultID, attr))) {
364  return failure();
365  }
366  }
367  return success();
368 }
369 
370 LogicalResult
371 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
372  // Get TypeID.
373  uint32_t resultTypeID = 0;
374  SmallVector<StringRef, 4> elidedAttrs;
375  if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
376  return failure();
377  }
378 
379  elidedAttrs.push_back("type");
380  SmallVector<uint32_t, 4> operands;
381  operands.push_back(resultTypeID);
382  auto resultID = getNextID();
383 
384  // Encode the name.
385  auto varName = varOp.getSymName();
386  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
387  if (failed(processName(resultID, varName))) {
388  return failure();
389  }
390  globalVarIDMap[varName] = resultID;
391  operands.push_back(resultID);
392 
393  // Encode StorageClass.
394  operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
395 
396  // Encode initialization.
397  StringRef initAttrName = varOp.getInitializerAttrName().getValue();
398  if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
399  uint32_t initializerID = 0;
400  auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName);
402  varOp->getParentOp(), initRef.getAttr());
403 
404  // Check if initializer is GlobalVariable or SpecConstant* cases.
405  if (isa<spirv::GlobalVariableOp>(initOp))
406  initializerID = getVariableID(*initSymbolName);
407  else
408  initializerID = getSpecConstID(*initSymbolName);
409 
410  if (!initializerID)
411  return emitError(varOp.getLoc(),
412  "invalid usage of undefined variable as initializer");
413 
414  operands.push_back(initializerID);
415  elidedAttrs.push_back(initAttrName);
416  }
417 
418  if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
419  return failure();
420  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
421  elidedAttrs.push_back(initAttrName);
422 
423  // Encode decorations.
424  for (auto attr : varOp->getAttrs()) {
425  if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
426  return attr.getName() == elided;
427  })) {
428  continue;
429  }
430  if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
431  return failure();
432  }
433  }
434  return success();
435 }
436 
437 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
438  // Assign <id>s to all blocks so that branches inside the SelectionOp can
439  // resolve properly.
440  auto &body = selectionOp.getBody();
441  for (Block &block : body)
442  getOrCreateBlockID(&block);
443 
444  auto *headerBlock = selectionOp.getHeaderBlock();
445  auto *mergeBlock = selectionOp.getMergeBlock();
446  auto headerID = getBlockID(headerBlock);
447  auto mergeID = getBlockID(mergeBlock);
448  auto loc = selectionOp.getLoc();
449 
450  // This SelectionOp is in some MLIR block with preceding and following ops. In
451  // the binary format, it should reside in separate SPIR-V blocks from its
452  // preceding and following ops. So we need to emit unconditional branches to
453  // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal
454  // flow afterwards.
455  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
456 
457  // Emit the selection header block, which dominates all other blocks, first.
458  // We need to emit an OpSelectionMerge instruction before the selection header
459  // block's terminator.
460  auto emitSelectionMerge = [&]() {
461  if (failed(emitDebugLine(functionBody, loc)))
462  return failure();
463  lastProcessedWasMergeInst = true;
465  functionBody, spirv::Opcode::OpSelectionMerge,
466  {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())});
467  return success();
468  };
469  if (failed(
470  processBlock(headerBlock, /*omitLabel=*/false, emitSelectionMerge)))
471  return failure();
472 
473  // Process all blocks with a depth-first visitor starting from the header
474  // block. The selection header block and merge block are skipped by this
475  // visitor.
476  if (failed(visitInPrettyBlockOrder(
477  headerBlock, [&](Block *block) { return processBlock(block); },
478  /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
479  return failure();
480 
481  // There is nothing to do for the merge block in the selection, which just
482  // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel
483  // instruction to start a new SPIR-V block for ops following this SelectionOp.
484  // The block should use the <id> for the merge block.
485  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
486  LLVM_DEBUG(llvm::dbgs() << "done merge ");
487  LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
488  LLVM_DEBUG(llvm::dbgs() << "\n");
489  return success();
490 }
491 
492 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
493  // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
494  // properly. We don't need to assign for the entry block, which is just for
495  // satisfying MLIR region's structural requirement.
496  auto &body = loopOp.getBody();
497  for (Block &block : llvm::drop_begin(body))
498  getOrCreateBlockID(&block);
499 
500  auto *headerBlock = loopOp.getHeaderBlock();
501  auto *continueBlock = loopOp.getContinueBlock();
502  auto *mergeBlock = loopOp.getMergeBlock();
503  auto headerID = getBlockID(headerBlock);
504  auto continueID = getBlockID(continueBlock);
505  auto mergeID = getBlockID(mergeBlock);
506  auto loc = loopOp.getLoc();
507 
508  // This LoopOp is in some MLIR block with preceding and following ops. In the
509  // binary format, it should reside in separate SPIR-V blocks from its
510  // preceding and following ops. So we need to emit unconditional branches to
511  // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
512  // afterwards.
513  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
514 
515  // LoopOp's entry block is just there for satisfying MLIR's structural
516  // requirements so we omit it and start serialization from the loop header
517  // block.
518 
519  // Emit the loop header block, which dominates all other blocks, first. We
520  // need to emit an OpLoopMerge instruction before the loop header block's
521  // terminator.
522  auto emitLoopMerge = [&]() {
523  if (failed(emitDebugLine(functionBody, loc)))
524  return failure();
525  lastProcessedWasMergeInst = true;
527  functionBody, spirv::Opcode::OpLoopMerge,
528  {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())});
529  return success();
530  };
531  if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
532  return failure();
533 
534  // Process all blocks with a depth-first visitor starting from the header
535  // block. The loop header block, loop continue block, and loop merge block are
536  // skipped by this visitor and handled later in this function.
537  if (failed(visitInPrettyBlockOrder(
538  headerBlock, [&](Block *block) { return processBlock(block); },
539  /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
540  return failure();
541 
542  // We have handled all other blocks. Now get to the loop continue block.
543  if (failed(processBlock(continueBlock)))
544  return failure();
545 
546  // There is nothing to do for the merge block in the loop, which just contains
547  // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction
548  // to start a new SPIR-V block for ops following this LoopOp. The block should
549  // use the <id> for the merge block.
550  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
551  LLVM_DEBUG(llvm::dbgs() << "done merge ");
552  LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
553  LLVM_DEBUG(llvm::dbgs() << "\n");
554  return success();
555 }
556 
557 LogicalResult Serializer::processBranchConditionalOp(
558  spirv::BranchConditionalOp condBranchOp) {
559  auto conditionID = getValueID(condBranchOp.getCondition());
560  auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
561  auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
562  SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
563 
564  if (auto weights = condBranchOp.getBranchWeights()) {
565  for (auto val : weights->getValue())
566  arguments.push_back(cast<IntegerAttr>(val).getInt());
567  }
568 
569  if (failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
570  return failure();
571  encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
572  arguments);
573  return success();
574 }
575 
576 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
577  if (failed(emitDebugLine(functionBody, branchOp.getLoc())))
578  return failure();
579  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
580  {getOrCreateBlockID(branchOp.getTarget())});
581  return success();
582 }
583 
584 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
585  auto varName = addressOfOp.getVariable();
586  auto variableID = getVariableID(varName);
587  if (!variableID) {
588  return addressOfOp.emitError("unknown result <id> for variable ")
589  << varName;
590  }
591  valueIDMap[addressOfOp.getPointer()] = variableID;
592  return success();
593 }
594 
595 LogicalResult
596 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
597  auto constName = referenceOfOp.getSpecConst();
598  auto constID = getSpecConstID(constName);
599  if (!constID) {
600  return referenceOfOp.emitError(
601  "unknown result <id> for specialization constant ")
602  << constName;
603  }
604  valueIDMap[referenceOfOp.getReference()] = constID;
605  return success();
606 }
607 
608 template <>
609 LogicalResult
610 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
611  SmallVector<uint32_t, 4> operands;
612  // Add the ExecutionModel.
613  operands.push_back(static_cast<uint32_t>(op.getExecutionModel()));
614  // Add the function <id>.
615  auto funcID = getFunctionID(op.getFn());
616  if (!funcID) {
617  return op.emitError("missing <id> for function ")
618  << op.getFn()
619  << "; function needs to be defined before spirv.EntryPoint is "
620  "serialized";
621  }
622  operands.push_back(funcID);
623  // Add the name of the function.
624  spirv::encodeStringLiteralInto(operands, op.getFn());
625 
626  // Add the interface values.
627  if (auto interface = op.getInterface()) {
628  for (auto var : interface.getValue()) {
629  auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
630  if (!id) {
631  return op.emitError(
632  "referencing undefined global variable."
633  "spirv.EntryPoint is at the end of spirv.module. All "
634  "referenced variables should already be defined");
635  }
636  operands.push_back(id);
637  }
638  }
639  encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands);
640  return success();
641 }
642 
643 template <>
644 LogicalResult
645 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
646  SmallVector<uint32_t, 4> operands;
647  // Add the function <id>.
648  auto funcID = getFunctionID(op.getFn());
649  if (!funcID) {
650  return op.emitError("missing <id> for function ")
651  << op.getFn()
652  << "; function needs to be serialized before ExecutionModeOp is "
653  "serialized";
654  }
655  operands.push_back(funcID);
656  // Add the ExecutionMode.
657  operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
658 
659  // Serialize values if any.
660  auto values = op.getValues();
661  if (values) {
662  for (auto &intVal : values.getValue()) {
663  operands.push_back(static_cast<uint32_t>(
664  llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
665  }
666  }
667  encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
668  operands);
669  return success();
670 }
671 
672 template <>
673 LogicalResult
674 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
675  auto funcName = op.getCallee();
676  uint32_t resTypeID = 0;
677 
678  Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
679  if (failed(processType(op.getLoc(), resultTy, resTypeID)))
680  return failure();
681 
682  auto funcID = getOrCreateFunctionID(funcName);
683  auto funcCallID = getNextID();
684  SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
685 
686  for (auto value : op.getArguments()) {
687  auto valueID = getValueID(value);
688  assert(valueID && "cannot find a value for spirv.FunctionCall");
689  operands.push_back(valueID);
690  }
691 
692  if (!isa<NoneType>(resultTy))
693  valueIDMap[op.getResult(0)] = funcCallID;
694 
695  encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands);
696  return success();
697 }
698 
699 template <>
700 LogicalResult
701 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
702  SmallVector<uint32_t, 4> operands;
703  SmallVector<StringRef, 2> elidedAttrs;
704 
705  for (Value operand : op->getOperands()) {
706  auto id = getValueID(operand);
707  assert(id && "use before def!");
708  operands.push_back(id);
709  }
710 
711  StringAttr memoryAccess = op.getMemoryAccessAttrName();
712  if (auto attr = op->getAttr(memoryAccess)) {
713  operands.push_back(
714  static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
715  }
716 
717  elidedAttrs.push_back(memoryAccess.strref());
718 
719  StringAttr alignment = op.getAlignmentAttrName();
720  if (auto attr = op->getAttr(alignment)) {
721  operands.push_back(static_cast<uint32_t>(
722  cast<IntegerAttr>(attr).getValue().getZExtValue()));
723  }
724 
725  elidedAttrs.push_back(alignment.strref());
726 
727  StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
728  if (auto attr = op->getAttr(sourceMemoryAccess)) {
729  operands.push_back(
730  static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
731  }
732 
733  elidedAttrs.push_back(sourceMemoryAccess.strref());
734 
735  StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
736  if (auto attr = op->getAttr(sourceAlignment)) {
737  operands.push_back(static_cast<uint32_t>(
738  cast<IntegerAttr>(attr).getValue().getZExtValue()));
739  }
740 
741  elidedAttrs.push_back(sourceAlignment.strref());
742  if (failed(emitDebugLine(functionBody, op.getLoc())))
743  return failure();
744  encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
745 
746  return success();
747 }
748 template <>
749 LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
750  spirv::GenericCastToPtrExplicitOp op) {
751  SmallVector<uint32_t, 4> operands;
752  Type resultTy;
753  Location loc = op->getLoc();
754  uint32_t resultTypeID = 0;
755  uint32_t resultID = 0;
756  resultTy = op->getResult(0).getType();
757  if (failed(processType(loc, resultTy, resultTypeID)))
758  return failure();
759  operands.push_back(resultTypeID);
760 
761  resultID = getNextID();
762  operands.push_back(resultID);
763  valueIDMap[op->getResult(0)] = resultID;
764 
765  for (Value operand : op->getOperands())
766  operands.push_back(getValueID(operand));
767  spirv::StorageClass resultStorage =
768  cast<spirv::PointerType>(resultTy).getStorageClass();
769  operands.push_back(static_cast<uint32_t>(resultStorage));
770  encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit,
771  operands);
772  return success();
773 }
774 
775 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
776 // various Serializer::processOp<...>() specializations.
777 #define GET_SERIALIZATION_FNS
778 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
779 
780 } // namespace spirv
781 } // namespace mlir
static LogicalResult visitInPrettyBlockOrder(Block *headerBlock, function_ref< LogicalResult(Block *)> blockHandler, bool skipHeader=false, BlockRange skipBlocks={})
A pre-order depth-first visitor function for processing basic blocks.
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType & getOperations()
Definition: Block.h:135
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringRef stripDialect() const
Return the operation name with dialect name stripped, if it has one.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:545
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:529
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
result_type_iterator result_type_begin()
Definition: Operation.h:421
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
BlockListType & getBlocks()
Definition: Region.h:45
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
Definition: Serializer.cpp:78
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.