MLIR  21.0.0git
SerializeOps.cpp
Go to the documentation of this file.
1 //===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the serialization methods for MLIR SPIR-V module ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Serializer.h"
14 
19 #include "llvm/ADT/DepthFirstIterator.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/Support/Debug.h"
22 
23 #define DEBUG_TYPE "spirv-serialization"
24 
25 using namespace mlir;
26 
27 /// A pre-order depth-first visitor function for processing basic blocks.
28 ///
29 /// Visits the basic blocks starting from the given `headerBlock` in pre-order
30 /// depth-first manner and calls `blockHandler` on each block. Skips handling
31 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
32 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
33 /// successors.
34 ///
35 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
36 /// of blocks in a function must satisfy the rule that blocks appear before
37 /// all blocks they dominate." This can be achieved by a pre-order CFG
38 /// traversal algorithm. To make the serialization output more logical and
39 /// readable to human, we perform depth-first CFG traversal and delay the
40 /// serialization of the merge block and the continue block, if exists, until
41 /// after all other blocks have been processed.
42 static LogicalResult
44  function_ref<LogicalResult(Block *)> blockHandler,
45  bool skipHeader = false, BlockRange skipBlocks = {}) {
46  llvm::df_iterator_default_set<Block *, 4> doneBlocks;
47  doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
48 
49  for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
50  if (skipHeader && block == headerBlock)
51  continue;
52  if (failed(blockHandler(block)))
53  return failure();
54  }
55  return success();
56 }
57 
58 namespace mlir {
59 namespace spirv {
60 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
61  if (auto resultID =
62  prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
63  valueIDMap[op.getResult()] = resultID;
64  return success();
65  }
66  return failure();
67 }
68 
69 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
70  if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
71  /*isSpec=*/true)) {
72  // Emit the OpDecorate instruction for SpecId.
73  if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
74  auto val = static_cast<uint32_t>(specID.getInt());
75  if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
76  return failure();
77  }
78 
79  specConstIDMap[op.getSymName()] = resultID;
80  return processName(resultID, op.getSymName());
81  }
82  return failure();
83 }
84 
85 LogicalResult
86 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
87  uint32_t typeID = 0;
88  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
89  return failure();
90  }
91 
92  auto resultID = getNextID();
93 
94  SmallVector<uint32_t, 8> operands;
95  operands.push_back(typeID);
96  operands.push_back(resultID);
97 
98  auto constituents = op.getConstituents();
99 
100  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
101  auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
102 
103  auto constituentName = constituent.getValue();
104  auto constituentID = getSpecConstID(constituentName);
105 
106  if (!constituentID) {
107  return op.emitError("unknown result <id> for specialization constant ")
108  << constituentName;
109  }
110 
111  operands.push_back(constituentID);
112  }
113 
114  encodeInstructionInto(typesGlobalValues,
115  spirv::Opcode::OpSpecConstantComposite, operands);
116  specConstIDMap[op.getSymName()] = resultID;
117 
118  return processName(resultID, op.getSymName());
119 }
120 
121 LogicalResult
122 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
123  uint32_t typeID = 0;
124  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
125  return failure();
126  }
127 
128  auto resultID = getNextID();
129 
130  SmallVector<uint32_t, 8> operands;
131  operands.push_back(typeID);
132  operands.push_back(resultID);
133 
134  Block &block = op.getRegion().getBlocks().front();
135  Operation &enclosedOp = block.getOperations().front();
136 
137  std::string enclosedOpName;
138  llvm::raw_string_ostream rss(enclosedOpName);
139  rss << "Op" << enclosedOp.getName().stripDialect();
140  auto enclosedOpcode = spirv::symbolizeOpcode(enclosedOpName);
141 
142  if (!enclosedOpcode) {
143  op.emitError("Couldn't find op code for op ")
144  << enclosedOp.getName().getStringRef();
145  return failure();
146  }
147 
148  operands.push_back(static_cast<uint32_t>(*enclosedOpcode));
149 
150  // Append operands to the enclosed op to the list of operands.
151  for (Value operand : enclosedOp.getOperands()) {
152  uint32_t id = getValueID(operand);
153  assert(id && "use before def!");
154  operands.push_back(id);
155  }
156 
157  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp,
158  operands);
159  valueIDMap[op.getResult()] = resultID;
160 
161  return success();
162 }
163 
164 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
165  auto undefType = op.getType();
166  auto &id = undefValIDMap[undefType];
167  if (!id) {
168  id = getNextID();
169  uint32_t typeID = 0;
170  if (failed(processType(op.getLoc(), undefType, typeID)))
171  return failure();
172  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
173  {typeID, id});
174  }
175  valueIDMap[op.getResult()] = id;
176  return success();
177 }
178 
179 LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
180  for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
181  uint32_t argTypeID = 0;
182  if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
183  return failure();
184  }
185  auto argValueID = getNextID();
186 
187  // Process decoration attributes of arguments.
188  auto funcOp = cast<FunctionOpInterface>(*op);
189  for (auto argAttr : funcOp.getArgAttrs(idx)) {
190  if (argAttr.getName() != DecorationAttr::name)
191  continue;
192 
193  if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
194  if (failed(processDecorationAttr(op->getLoc(), argValueID,
195  decAttr.getValue(), decAttr)))
196  return failure();
197  }
198  }
199 
200  valueIDMap[arg] = argValueID;
201  encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
202  {argTypeID, argValueID});
203  }
204  return success();
205 }
206 
207 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
208  LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
209  assert(functionHeader.empty() && functionBody.empty());
210 
211  uint32_t fnTypeID = 0;
212  // Generate type of the function.
213  if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
214  return failure();
215 
216  // Add the function definition.
217  SmallVector<uint32_t, 4> operands;
218  uint32_t resTypeID = 0;
219  auto resultTypes = op.getFunctionType().getResults();
220  if (resultTypes.size() > 1) {
221  return op.emitError("cannot serialize function with multiple return types");
222  }
223  if (failed(processType(op.getLoc(),
224  (resultTypes.empty() ? getVoidType() : resultTypes[0]),
225  resTypeID))) {
226  return failure();
227  }
228  operands.push_back(resTypeID);
229  auto funcID = getOrCreateFunctionID(op.getName());
230  operands.push_back(funcID);
231  operands.push_back(static_cast<uint32_t>(op.getFunctionControl()));
232  operands.push_back(fnTypeID);
233  encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
234 
235  // Add function name.
236  if (failed(processName(funcID, op.getName()))) {
237  return failure();
238  }
239  // Handle external functions with linkage_attributes(LinkageAttributes)
240  // differently.
241  auto linkageAttr = op.getLinkageAttributes();
242  auto hasImportLinkage =
243  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
244  spirv::LinkageType::Import);
245  if (op.isExternal() && !hasImportLinkage) {
246  return op.emitError(
247  "'spirv.module' cannot contain external functions "
248  "without 'Import' linkage_attributes (LinkageAttributes)");
249  }
250  if (op.isExternal() && hasImportLinkage) {
251  // Add an entry block to set up the block arguments
252  // to match the signature of the function.
253  // This is to generate OpFunctionParameter for functions with
254  // LinkageAttributes.
255  // WARNING: This operation has side-effect, it essentially adds a body
256  // to the func. Hence, making it not external anymore (isExternal()
257  // is going to return false for this function from now on)
258  // Hence, we'll remove the body once we are done with the serialization.
259  op.addEntryBlock();
260  if (failed(processFuncParameter(op)))
261  return failure();
262  // Don't need to process the added block, there is nothing to process,
263  // the fake body was added just to get the arguments, remove the body,
264  // since it's use is done.
265  op.eraseBody();
266  } else {
267  if (failed(processFuncParameter(op)))
268  return failure();
269 
270  // Some instructions (e.g., OpVariable) in a function must be in the first
271  // block in the function. These instructions will be put in
272  // functionHeader. Thus, we put the label in functionHeader first, and
273  // omit it from the first block. OpLabel only needs to be added for
274  // functions with body (including empty body). Since, we added a fake body
275  // for functions with 'Import' Linkage attributes, these functions are
276  // essentially function delcaration, so they should not have OpLabel and a
277  // terminating instruction. That's why we skipped it for those functions.
278  encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
279  {getOrCreateBlockID(&op.front())});
280  if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
281  return failure();
282  if (failed(visitInPrettyBlockOrder(
283  &op.front(), [&](Block *block) { return processBlock(block); },
284  /*skipHeader=*/true))) {
285  return failure();
286  }
287 
288  // There might be OpPhi instructions who have value references needing to
289  // fix.
290  for (const auto &deferredValue : deferredPhiValues) {
291  Value value = deferredValue.first;
292  uint32_t id = getValueID(value);
293  LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
294  << " to id = " << id << '\n');
295  assert(id && "OpPhi references undefined value!");
296  for (size_t offset : deferredValue.second)
297  functionBody[offset] = id;
298  }
299  deferredPhiValues.clear();
300  }
301  LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
302  << "' --\n");
303  // Insert Decorations based on Function Attributes.
304  // Only attributes we should be considering for decoration are the
305  // ::mlir::spirv::Decoration attributes.
306 
307  for (auto attr : op->getAttrs()) {
308  // Only generate OpDecorate op for spirv::Decoration attributes.
309  auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
310  llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
311  /*capitalizeFirst=*/true));
312  if (isValidDecoration != std::nullopt) {
313  if (failed(processDecoration(op.getLoc(), funcID, attr))) {
314  return failure();
315  }
316  }
317  }
318  // Insert OpFunctionEnd.
319  encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {});
320 
321  functions.append(functionHeader.begin(), functionHeader.end());
322  functions.append(functionBody.begin(), functionBody.end());
323  functionHeader.clear();
324  functionBody.clear();
325 
326  return success();
327 }
328 
329 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
330  SmallVector<uint32_t, 4> operands;
331  SmallVector<StringRef, 2> elidedAttrs;
332  uint32_t resultID = 0;
333  uint32_t resultTypeID = 0;
334  if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
335  return failure();
336  }
337  operands.push_back(resultTypeID);
338  resultID = getNextID();
339  valueIDMap[op.getResult()] = resultID;
340  operands.push_back(resultID);
341  auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
342  if (attr) {
343  operands.push_back(
344  static_cast<uint32_t>(cast<spirv::StorageClassAttr>(attr).getValue()));
345  }
346  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
347  for (auto arg : op.getODSOperands(0)) {
348  auto argID = getValueID(arg);
349  if (!argID) {
350  return emitError(op.getLoc(), "operand 0 has a use before def");
351  }
352  operands.push_back(argID);
353  }
354  if (failed(emitDebugLine(functionHeader, op.getLoc())))
355  return failure();
356  encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands);
357  for (auto attr : op->getAttrs()) {
358  if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
359  return attr.getName() == elided;
360  })) {
361  continue;
362  }
363  if (failed(processDecoration(op.getLoc(), resultID, attr))) {
364  return failure();
365  }
366  }
367  return success();
368 }
369 
370 LogicalResult
371 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
372  // Get TypeID.
373  uint32_t resultTypeID = 0;
374  SmallVector<StringRef, 4> elidedAttrs;
375  if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
376  return failure();
377  }
378 
379  elidedAttrs.push_back("type");
380  SmallVector<uint32_t, 4> operands;
381  operands.push_back(resultTypeID);
382  auto resultID = getNextID();
383 
384  // Encode the name.
385  auto varName = varOp.getSymName();
386  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
387  if (failed(processName(resultID, varName))) {
388  return failure();
389  }
390  globalVarIDMap[varName] = resultID;
391  operands.push_back(resultID);
392 
393  // Encode StorageClass.
394  operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
395 
396  // Encode initialization.
397  StringRef initAttrName = varOp.getInitializerAttrName().getValue();
398  if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
399  uint32_t initializerID = 0;
400  auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName);
402  varOp->getParentOp(), initRef.getAttr());
403 
404  // Check if initializer is GlobalVariable or SpecConstant* cases.
405  if (isa<spirv::GlobalVariableOp>(initOp))
406  initializerID = getVariableID(*initSymbolName);
407  else
408  initializerID = getSpecConstID(*initSymbolName);
409 
410  if (!initializerID)
411  return emitError(varOp.getLoc(),
412  "invalid usage of undefined variable as initializer");
413 
414  operands.push_back(initializerID);
415  elidedAttrs.push_back(initAttrName);
416  }
417 
418  if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
419  return failure();
420  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
421  elidedAttrs.push_back(initAttrName);
422 
423  // Encode decorations.
424  for (auto attr : varOp->getAttrs()) {
425  if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
426  return attr.getName() == elided;
427  })) {
428  continue;
429  }
430  if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
431  return failure();
432  }
433  }
434  return success();
435 }
436 
437 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
438  // Assign <id>s to all blocks so that branches inside the SelectionOp can
439  // resolve properly.
440  auto &body = selectionOp.getBody();
441  for (Block &block : body)
442  getOrCreateBlockID(&block);
443 
444  auto *headerBlock = selectionOp.getHeaderBlock();
445  auto *mergeBlock = selectionOp.getMergeBlock();
446  auto headerID = getBlockID(headerBlock);
447  auto mergeID = getBlockID(mergeBlock);
448  auto loc = selectionOp.getLoc();
449 
450  // Before we do anything replace results of the selection operation with
451  // values yielded (with `mlir.merge`) from inside the region. The selection op
452  // is being flattened so we do not have to worry about values being defined
453  // inside a region and used outside it anymore.
454  auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
455  assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
456  for (unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i)
457  selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
458 
459  // This SelectionOp is in some MLIR block with preceding and following ops. In
460  // the binary format, it should reside in separate SPIR-V blocks from its
461  // preceding and following ops. So we need to emit unconditional branches to
462  // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal
463  // flow afterwards.
464  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
465 
466  // Emit the selection header block, which dominates all other blocks, first.
467  // We need to emit an OpSelectionMerge instruction before the selection header
468  // block's terminator.
469  auto emitSelectionMerge = [&]() {
470  if (failed(emitDebugLine(functionBody, loc)))
471  return failure();
472  lastProcessedWasMergeInst = true;
474  functionBody, spirv::Opcode::OpSelectionMerge,
475  {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())});
476  return success();
477  };
478  if (failed(
479  processBlock(headerBlock, /*omitLabel=*/false, emitSelectionMerge)))
480  return failure();
481 
482  // Process all blocks with a depth-first visitor starting from the header
483  // block. The selection header block and merge block are skipped by this
484  // visitor.
485  if (failed(visitInPrettyBlockOrder(
486  headerBlock, [&](Block *block) { return processBlock(block); },
487  /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
488  return failure();
489 
490  // There is nothing to do for the merge block in the selection, which just
491  // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel
492  // instruction to start a new SPIR-V block for ops following this SelectionOp.
493  // The block should use the <id> for the merge block.
494  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
495 
496  // We do not process the mergeBlock but we still need to generate phi
497  // functions from its block arguments.
498  if (failed(emitPhiForBlockArguments(mergeBlock)))
499  return failure();
500 
501  LLVM_DEBUG(llvm::dbgs() << "done merge ");
502  LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
503  LLVM_DEBUG(llvm::dbgs() << "\n");
504  return success();
505 }
506 
507 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
508  // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
509  // properly. We don't need to assign for the entry block, which is just for
510  // satisfying MLIR region's structural requirement.
511  auto &body = loopOp.getBody();
512  for (Block &block : llvm::drop_begin(body))
513  getOrCreateBlockID(&block);
514 
515  auto *headerBlock = loopOp.getHeaderBlock();
516  auto *continueBlock = loopOp.getContinueBlock();
517  auto *mergeBlock = loopOp.getMergeBlock();
518  auto headerID = getBlockID(headerBlock);
519  auto continueID = getBlockID(continueBlock);
520  auto mergeID = getBlockID(mergeBlock);
521  auto loc = loopOp.getLoc();
522 
523  // Before we do anything replace results of the selection operation with
524  // values yielded (with `mlir.merge`) from inside the region.
525  auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
526  assert(loopOp.getNumResults() == mergeOp.getNumOperands());
527  for (unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i)
528  loopOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
529 
530  // This LoopOp is in some MLIR block with preceding and following ops. In the
531  // binary format, it should reside in separate SPIR-V blocks from its
532  // preceding and following ops. So we need to emit unconditional branches to
533  // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
534  // afterwards.
535  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
536 
537  // LoopOp's entry block is just there for satisfying MLIR's structural
538  // requirements so we omit it and start serialization from the loop header
539  // block.
540 
541  // Emit the loop header block, which dominates all other blocks, first. We
542  // need to emit an OpLoopMerge instruction before the loop header block's
543  // terminator.
544  auto emitLoopMerge = [&]() {
545  if (failed(emitDebugLine(functionBody, loc)))
546  return failure();
547  lastProcessedWasMergeInst = true;
549  functionBody, spirv::Opcode::OpLoopMerge,
550  {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())});
551  return success();
552  };
553  if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
554  return failure();
555 
556  // Process all blocks with a depth-first visitor starting from the header
557  // block. The loop header block, loop continue block, and loop merge block are
558  // skipped by this visitor and handled later in this function.
559  if (failed(visitInPrettyBlockOrder(
560  headerBlock, [&](Block *block) { return processBlock(block); },
561  /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
562  return failure();
563 
564  // We have handled all other blocks. Now get to the loop continue block.
565  if (failed(processBlock(continueBlock)))
566  return failure();
567 
568  // There is nothing to do for the merge block in the loop, which just contains
569  // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction
570  // to start a new SPIR-V block for ops following this LoopOp. The block should
571  // use the <id> for the merge block.
572  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
573  LLVM_DEBUG(llvm::dbgs() << "done merge ");
574  LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
575  LLVM_DEBUG(llvm::dbgs() << "\n");
576  return success();
577 }
578 
579 LogicalResult Serializer::processBranchConditionalOp(
580  spirv::BranchConditionalOp condBranchOp) {
581  auto conditionID = getValueID(condBranchOp.getCondition());
582  auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
583  auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
584  SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
585 
586  if (auto weights = condBranchOp.getBranchWeights()) {
587  for (auto val : weights->getValue())
588  arguments.push_back(cast<IntegerAttr>(val).getInt());
589  }
590 
591  if (failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
592  return failure();
593  encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
594  arguments);
595  return success();
596 }
597 
598 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
599  if (failed(emitDebugLine(functionBody, branchOp.getLoc())))
600  return failure();
601  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
602  {getOrCreateBlockID(branchOp.getTarget())});
603  return success();
604 }
605 
606 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
607  auto varName = addressOfOp.getVariable();
608  auto variableID = getVariableID(varName);
609  if (!variableID) {
610  return addressOfOp.emitError("unknown result <id> for variable ")
611  << varName;
612  }
613  valueIDMap[addressOfOp.getPointer()] = variableID;
614  return success();
615 }
616 
617 LogicalResult
618 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
619  auto constName = referenceOfOp.getSpecConst();
620  auto constID = getSpecConstID(constName);
621  if (!constID) {
622  return referenceOfOp.emitError(
623  "unknown result <id> for specialization constant ")
624  << constName;
625  }
626  valueIDMap[referenceOfOp.getReference()] = constID;
627  return success();
628 }
629 
630 template <>
631 LogicalResult
632 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
633  SmallVector<uint32_t, 4> operands;
634  // Add the ExecutionModel.
635  operands.push_back(static_cast<uint32_t>(op.getExecutionModel()));
636  // Add the function <id>.
637  auto funcID = getFunctionID(op.getFn());
638  if (!funcID) {
639  return op.emitError("missing <id> for function ")
640  << op.getFn()
641  << "; function needs to be defined before spirv.EntryPoint is "
642  "serialized";
643  }
644  operands.push_back(funcID);
645  // Add the name of the function.
646  spirv::encodeStringLiteralInto(operands, op.getFn());
647 
648  // Add the interface values.
649  if (auto interface = op.getInterface()) {
650  for (auto var : interface.getValue()) {
651  auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
652  if (!id) {
653  return op.emitError(
654  "referencing undefined global variable."
655  "spirv.EntryPoint is at the end of spirv.module. All "
656  "referenced variables should already be defined");
657  }
658  operands.push_back(id);
659  }
660  }
661  encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands);
662  return success();
663 }
664 
665 template <>
666 LogicalResult
667 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
668  SmallVector<uint32_t, 4> operands;
669  // Add the function <id>.
670  auto funcID = getFunctionID(op.getFn());
671  if (!funcID) {
672  return op.emitError("missing <id> for function ")
673  << op.getFn()
674  << "; function needs to be serialized before ExecutionModeOp is "
675  "serialized";
676  }
677  operands.push_back(funcID);
678  // Add the ExecutionMode.
679  operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
680 
681  // Serialize values if any.
682  auto values = op.getValues();
683  if (values) {
684  for (auto &intVal : values.getValue()) {
685  operands.push_back(static_cast<uint32_t>(
686  llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
687  }
688  }
689  encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
690  operands);
691  return success();
692 }
693 
694 template <>
695 LogicalResult
696 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
697  auto funcName = op.getCallee();
698  uint32_t resTypeID = 0;
699 
700  Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
701  if (failed(processType(op.getLoc(), resultTy, resTypeID)))
702  return failure();
703 
704  auto funcID = getOrCreateFunctionID(funcName);
705  auto funcCallID = getNextID();
706  SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
707 
708  for (auto value : op.getArguments()) {
709  auto valueID = getValueID(value);
710  assert(valueID && "cannot find a value for spirv.FunctionCall");
711  operands.push_back(valueID);
712  }
713 
714  if (!isa<NoneType>(resultTy))
715  valueIDMap[op.getResult(0)] = funcCallID;
716 
717  encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands);
718  return success();
719 }
720 
721 template <>
722 LogicalResult
723 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
724  SmallVector<uint32_t, 4> operands;
725  SmallVector<StringRef, 2> elidedAttrs;
726 
727  for (Value operand : op->getOperands()) {
728  auto id = getValueID(operand);
729  assert(id && "use before def!");
730  operands.push_back(id);
731  }
732 
733  StringAttr memoryAccess = op.getMemoryAccessAttrName();
734  if (auto attr = op->getAttr(memoryAccess)) {
735  operands.push_back(
736  static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
737  }
738 
739  elidedAttrs.push_back(memoryAccess.strref());
740 
741  StringAttr alignment = op.getAlignmentAttrName();
742  if (auto attr = op->getAttr(alignment)) {
743  operands.push_back(static_cast<uint32_t>(
744  cast<IntegerAttr>(attr).getValue().getZExtValue()));
745  }
746 
747  elidedAttrs.push_back(alignment.strref());
748 
749  StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
750  if (auto attr = op->getAttr(sourceMemoryAccess)) {
751  operands.push_back(
752  static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
753  }
754 
755  elidedAttrs.push_back(sourceMemoryAccess.strref());
756 
757  StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
758  if (auto attr = op->getAttr(sourceAlignment)) {
759  operands.push_back(static_cast<uint32_t>(
760  cast<IntegerAttr>(attr).getValue().getZExtValue()));
761  }
762 
763  elidedAttrs.push_back(sourceAlignment.strref());
764  if (failed(emitDebugLine(functionBody, op.getLoc())))
765  return failure();
766  encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
767 
768  return success();
769 }
770 template <>
771 LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
772  spirv::GenericCastToPtrExplicitOp op) {
773  SmallVector<uint32_t, 4> operands;
774  Type resultTy;
775  Location loc = op->getLoc();
776  uint32_t resultTypeID = 0;
777  uint32_t resultID = 0;
778  resultTy = op->getResult(0).getType();
779  if (failed(processType(loc, resultTy, resultTypeID)))
780  return failure();
781  operands.push_back(resultTypeID);
782 
783  resultID = getNextID();
784  operands.push_back(resultID);
785  valueIDMap[op->getResult(0)] = resultID;
786 
787  for (Value operand : op->getOperands())
788  operands.push_back(getValueID(operand));
789  spirv::StorageClass resultStorage =
790  cast<spirv::PointerType>(resultTy).getStorageClass();
791  operands.push_back(static_cast<uint32_t>(resultStorage));
792  encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit,
793  operands);
794  return success();
795 }
796 
797 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
798 // various Serializer::processOp<...>() specializations.
799 #define GET_SERIALIZATION_FNS
800 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
801 
802 } // namespace spirv
803 } // namespace mlir
static LogicalResult visitInPrettyBlockOrder(Block *headerBlock, function_ref< LogicalResult(Block *)> blockHandler, bool skipHeader=false, BlockRange skipBlocks={})
A pre-order depth-first visitor function for processing basic blocks.
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringRef stripDialect() const
Return the operation name with dialect name stripped, if it has one.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
Definition: Serializer.cpp:78
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.