MLIR  22.0.0git
SerializeOps.cpp
Go to the documentation of this file.
1 //===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the serialization methods for MLIR SPIR-V module ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Serializer.h"
14 
19 #include "llvm/ADT/DepthFirstIterator.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/Support/Debug.h"
22 
23 #define DEBUG_TYPE "spirv-serialization"
24 
25 using namespace mlir;
26 
27 /// A pre-order depth-first visitor function for processing basic blocks.
28 ///
29 /// Visits the basic blocks starting from the given `headerBlock` in pre-order
30 /// depth-first manner and calls `blockHandler` on each block. Skips handling
31 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
32 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
33 /// successors.
34 ///
35 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
36 /// of blocks in a function must satisfy the rule that blocks appear before
37 /// all blocks they dominate." This can be achieved by a pre-order CFG
38 /// traversal algorithm. To make the serialization output more logical and
39 /// readable to human, we perform depth-first CFG traversal and delay the
40 /// serialization of the merge block and the continue block, if exists, until
41 /// after all other blocks have been processed.
42 static LogicalResult
44  function_ref<LogicalResult(Block *)> blockHandler,
45  bool skipHeader = false, BlockRange skipBlocks = {}) {
46  llvm::df_iterator_default_set<Block *, 4> doneBlocks;
47  doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
48 
49  for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
50  if (skipHeader && block == headerBlock)
51  continue;
52  if (failed(blockHandler(block)))
53  return failure();
54  }
55  return success();
56 }
57 
58 namespace mlir {
59 namespace spirv {
60 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
61  if (auto resultID =
62  prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
63  valueIDMap[op.getResult()] = resultID;
64  return success();
65  }
66  return failure();
67 }
68 
69 LogicalResult Serializer::processConstantCompositeReplicateOp(
70  spirv::EXTConstantCompositeReplicateOp op) {
71  if (uint32_t resultID = prepareConstantCompositeReplicate(
72  op.getLoc(), op.getType(), op.getValue())) {
73  valueIDMap[op.getResult()] = resultID;
74  return success();
75  }
76  return failure();
77 }
78 
79 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
80  if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
81  /*isSpec=*/true)) {
82  // Emit the OpDecorate instruction for SpecId.
83  if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
84  auto val = static_cast<uint32_t>(specID.getInt());
85  if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
86  return failure();
87  }
88 
89  specConstIDMap[op.getSymName()] = resultID;
90  return processName(resultID, op.getSymName());
91  }
92  return failure();
93 }
94 
95 LogicalResult
96 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
97  uint32_t typeID = 0;
98  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
99  return failure();
100  }
101 
102  auto resultID = getNextID();
103 
104  SmallVector<uint32_t, 8> operands;
105  operands.push_back(typeID);
106  operands.push_back(resultID);
107 
108  auto constituents = op.getConstituents();
109 
110  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
111  auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
112 
113  auto constituentName = constituent.getValue();
114  auto constituentID = getSpecConstID(constituentName);
115 
116  if (!constituentID) {
117  return op.emitError("unknown result <id> for specialization constant ")
118  << constituentName;
119  }
120 
121  operands.push_back(constituentID);
122  }
123 
124  encodeInstructionInto(typesGlobalValues,
125  spirv::Opcode::OpSpecConstantComposite, operands);
126  specConstIDMap[op.getSymName()] = resultID;
127 
128  return processName(resultID, op.getSymName());
129 }
130 
131 LogicalResult Serializer::processSpecConstantCompositeReplicateOp(
132  spirv::EXTSpecConstantCompositeReplicateOp op) {
133  uint32_t typeID = 0;
134  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
135  return failure();
136  }
137 
138  auto constituent = dyn_cast<FlatSymbolRefAttr>(op.getConstituent());
139  if (!constituent)
140  return op.emitError(
141  "expected flat symbol reference for constituent instead of ")
142  << op.getConstituent();
143 
144  StringRef constituentName = constituent.getValue();
145  uint32_t constituentID = getSpecConstID(constituentName);
146  if (!constituentID) {
147  return op.emitError("unknown result <id> for replicated spec constant ")
148  << constituentName;
149  }
150 
151  uint32_t resultID = getNextID();
152  uint32_t operands[] = {typeID, resultID, constituentID};
153 
154  encodeInstructionInto(typesGlobalValues,
155  spirv::Opcode::OpSpecConstantCompositeReplicateEXT,
156  operands);
157 
158  specConstIDMap[op.getSymName()] = resultID;
159 
160  return processName(resultID, op.getSymName());
161 }
162 
163 LogicalResult
164 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
165  uint32_t typeID = 0;
166  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
167  return failure();
168  }
169 
170  auto resultID = getNextID();
171 
172  SmallVector<uint32_t, 8> operands;
173  operands.push_back(typeID);
174  operands.push_back(resultID);
175 
176  Block &block = op.getRegion().getBlocks().front();
177  Operation &enclosedOp = block.getOperations().front();
178 
179  std::string enclosedOpName;
180  llvm::raw_string_ostream rss(enclosedOpName);
181  rss << "Op" << enclosedOp.getName().stripDialect();
182  auto enclosedOpcode = spirv::symbolizeOpcode(enclosedOpName);
183 
184  if (!enclosedOpcode) {
185  op.emitError("Couldn't find op code for op ")
186  << enclosedOp.getName().getStringRef();
187  return failure();
188  }
189 
190  operands.push_back(static_cast<uint32_t>(*enclosedOpcode));
191 
192  // Append operands to the enclosed op to the list of operands.
193  for (Value operand : enclosedOp.getOperands()) {
194  uint32_t id = getValueID(operand);
195  assert(id && "use before def!");
196  operands.push_back(id);
197  }
198 
199  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp,
200  operands);
201  valueIDMap[op.getResult()] = resultID;
202 
203  return success();
204 }
205 
206 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
207  auto undefType = op.getType();
208  auto &id = undefValIDMap[undefType];
209  if (!id) {
210  id = getNextID();
211  uint32_t typeID = 0;
212  if (failed(processType(op.getLoc(), undefType, typeID)))
213  return failure();
214  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
215  {typeID, id});
216  }
217  valueIDMap[op.getResult()] = id;
218  return success();
219 }
220 
221 LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
222  for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
223  uint32_t argTypeID = 0;
224  if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
225  return failure();
226  }
227  auto argValueID = getNextID();
228 
229  // Process decoration attributes of arguments.
230  auto funcOp = cast<FunctionOpInterface>(*op);
231  for (auto argAttr : funcOp.getArgAttrs(idx)) {
232  if (argAttr.getName() != DecorationAttr::name)
233  continue;
234 
235  if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
236  if (failed(processDecorationAttr(op->getLoc(), argValueID,
237  decAttr.getValue(), decAttr)))
238  return failure();
239  }
240  }
241 
242  valueIDMap[arg] = argValueID;
243  encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
244  {argTypeID, argValueID});
245  }
246  return success();
247 }
248 
249 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
250  LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
251  assert(functionHeader.empty() && functionBody.empty());
252 
253  uint32_t fnTypeID = 0;
254  // Generate type of the function.
255  if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
256  return failure();
257 
258  // Add the function definition.
259  SmallVector<uint32_t, 4> operands;
260  uint32_t resTypeID = 0;
261  auto resultTypes = op.getFunctionType().getResults();
262  if (resultTypes.size() > 1) {
263  return op.emitError("cannot serialize function with multiple return types");
264  }
265  if (failed(processType(op.getLoc(),
266  (resultTypes.empty() ? getVoidType() : resultTypes[0]),
267  resTypeID))) {
268  return failure();
269  }
270  operands.push_back(resTypeID);
271  auto funcID = getOrCreateFunctionID(op.getName());
272  operands.push_back(funcID);
273  operands.push_back(static_cast<uint32_t>(op.getFunctionControl()));
274  operands.push_back(fnTypeID);
275  encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
276 
277  // Add function name.
278  if (failed(processName(funcID, op.getName()))) {
279  return failure();
280  }
281  // Handle external functions with linkage_attributes(LinkageAttributes)
282  // differently.
283  auto linkageAttr = op.getLinkageAttributes();
284  auto hasImportLinkage =
285  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
286  spirv::LinkageType::Import);
287  if (op.isExternal() && !hasImportLinkage) {
288  return op.emitError(
289  "'spirv.module' cannot contain external functions "
290  "without 'Import' linkage_attributes (LinkageAttributes)");
291  }
292  if (op.isExternal() && hasImportLinkage) {
293  // Add an entry block to set up the block arguments
294  // to match the signature of the function.
295  // This is to generate OpFunctionParameter for functions with
296  // LinkageAttributes.
297  // WARNING: This operation has side-effect, it essentially adds a body
298  // to the func. Hence, making it not external anymore (isExternal()
299  // is going to return false for this function from now on)
300  // Hence, we'll remove the body once we are done with the serialization.
301  op.addEntryBlock();
302  if (failed(processFuncParameter(op)))
303  return failure();
304  // Don't need to process the added block, there is nothing to process,
305  // the fake body was added just to get the arguments, remove the body,
306  // since it's use is done.
307  op.eraseBody();
308  } else {
309  if (failed(processFuncParameter(op)))
310  return failure();
311 
312  // Some instructions (e.g., OpVariable) in a function must be in the first
313  // block in the function. These instructions will be put in
314  // functionHeader. Thus, we put the label in functionHeader first, and
315  // omit it from the first block. OpLabel only needs to be added for
316  // functions with body (including empty body). Since, we added a fake body
317  // for functions with 'Import' Linkage attributes, these functions are
318  // essentially function delcaration, so they should not have OpLabel and a
319  // terminating instruction. That's why we skipped it for those functions.
320  encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
321  {getOrCreateBlockID(&op.front())});
322  if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
323  return failure();
325  &op.front(), [&](Block *block) { return processBlock(block); },
326  /*skipHeader=*/true))) {
327  return failure();
328  }
329 
330  // There might be OpPhi instructions who have value references needing to
331  // fix.
332  for (const auto &deferredValue : deferredPhiValues) {
333  Value value = deferredValue.first;
334  uint32_t id = getValueID(value);
335  LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
336  << " to id = " << id << '\n');
337  assert(id && "OpPhi references undefined value!");
338  for (size_t offset : deferredValue.second)
339  functionBody[offset] = id;
340  }
341  deferredPhiValues.clear();
342  }
343  LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
344  << "' --\n");
345  // Insert Decorations based on Function Attributes.
346  // Only attributes we should be considering for decoration are the
347  // ::mlir::spirv::Decoration attributes.
348 
349  for (auto attr : op->getAttrs()) {
350  // Only generate OpDecorate op for spirv::Decoration attributes.
351  auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
352  llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
353  /*capitalizeFirst=*/true));
354  if (isValidDecoration != std::nullopt) {
355  if (failed(processDecoration(op.getLoc(), funcID, attr))) {
356  return failure();
357  }
358  }
359  }
360  // Insert OpFunctionEnd.
361  encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {});
362 
363  functions.append(functionHeader.begin(), functionHeader.end());
364  functions.append(functionBody.begin(), functionBody.end());
365  functionHeader.clear();
366  functionBody.clear();
367 
368  return success();
369 }
370 
371 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
372  SmallVector<uint32_t, 4> operands;
373  SmallVector<StringRef, 2> elidedAttrs;
374  uint32_t resultID = 0;
375  uint32_t resultTypeID = 0;
376  if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
377  return failure();
378  }
379  operands.push_back(resultTypeID);
380  resultID = getNextID();
381  valueIDMap[op.getResult()] = resultID;
382  operands.push_back(resultID);
383  auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
384  if (attr) {
385  operands.push_back(
386  static_cast<uint32_t>(cast<spirv::StorageClassAttr>(attr).getValue()));
387  }
388  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
389  for (auto arg : op.getODSOperands(0)) {
390  auto argID = getValueID(arg);
391  if (!argID) {
392  return emitError(op.getLoc(), "operand 0 has a use before def");
393  }
394  operands.push_back(argID);
395  }
396  if (failed(emitDebugLine(functionHeader, op.getLoc())))
397  return failure();
398  encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands);
399  for (auto attr : op->getAttrs()) {
400  if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
401  return attr.getName() == elided;
402  })) {
403  continue;
404  }
405  if (failed(processDecoration(op.getLoc(), resultID, attr))) {
406  return failure();
407  }
408  }
409  return success();
410 }
411 
412 LogicalResult
413 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
414  // Get TypeID.
415  uint32_t resultTypeID = 0;
416  SmallVector<StringRef, 4> elidedAttrs;
417  if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
418  return failure();
419  }
420 
421  elidedAttrs.push_back("type");
422  SmallVector<uint32_t, 4> operands;
423  operands.push_back(resultTypeID);
424  auto resultID = getNextID();
425 
426  // Encode the name.
427  auto varName = varOp.getSymName();
428  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
429  if (failed(processName(resultID, varName))) {
430  return failure();
431  }
432  globalVarIDMap[varName] = resultID;
433  operands.push_back(resultID);
434 
435  // Encode StorageClass.
436  operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
437 
438  // Encode initialization.
439  StringRef initAttrName = varOp.getInitializerAttrName().getValue();
440  if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
441  uint32_t initializerID = 0;
442  auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName);
444  varOp->getParentOp(), initRef.getAttr());
445 
446  // Check if initializer is GlobalVariable or SpecConstant* cases.
447  if (isa<spirv::GlobalVariableOp>(initOp))
448  initializerID = getVariableID(*initSymbolName);
449  else
450  initializerID = getSpecConstID(*initSymbolName);
451 
452  if (!initializerID)
453  return emitError(varOp.getLoc(),
454  "invalid usage of undefined variable as initializer");
455 
456  operands.push_back(initializerID);
457  elidedAttrs.push_back(initAttrName);
458  }
459 
460  if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
461  return failure();
462  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
463  elidedAttrs.push_back(initAttrName);
464 
465  // Encode decorations.
466  for (auto attr : varOp->getAttrs()) {
467  if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
468  return attr.getName() == elided;
469  })) {
470  continue;
471  }
472  if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
473  return failure();
474  }
475  }
476  return success();
477 }
478 
479 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
480  // Assign <id>s to all blocks so that branches inside the SelectionOp can
481  // resolve properly.
482  auto &body = selectionOp.getBody();
483  for (Block &block : body)
484  getOrCreateBlockID(&block);
485 
486  auto *headerBlock = selectionOp.getHeaderBlock();
487  auto *mergeBlock = selectionOp.getMergeBlock();
488  auto headerID = getBlockID(headerBlock);
489  auto mergeID = getBlockID(mergeBlock);
490  auto loc = selectionOp.getLoc();
491 
492  // Before we do anything replace results of the selection operation with
493  // values yielded (with `mlir.merge`) from inside the region. The selection op
494  // is being flattened so we do not have to worry about values being defined
495  // inside a region and used outside it anymore.
496  auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
497  assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
498  for (unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i)
499  selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
500 
501  // This SelectionOp is in some MLIR block with preceding and following ops. In
502  // the binary format, it should reside in separate SPIR-V blocks from its
503  // preceding and following ops. So we need to emit unconditional branches to
504  // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal
505  // flow afterwards.
506  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
507 
508  // Emit the selection header block, which dominates all other blocks, first.
509  // We need to emit an OpSelectionMerge instruction before the selection header
510  // block's terminator.
511  auto emitSelectionMerge = [&]() {
512  if (failed(emitDebugLine(functionBody, loc)))
513  return failure();
514  lastProcessedWasMergeInst = true;
516  functionBody, spirv::Opcode::OpSelectionMerge,
517  {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())});
518  return success();
519  };
520  if (failed(
521  processBlock(headerBlock, /*omitLabel=*/false, emitSelectionMerge)))
522  return failure();
523 
524  // Process all blocks with a depth-first visitor starting from the header
525  // block. The selection header block and merge block are skipped by this
526  // visitor.
528  headerBlock, [&](Block *block) { return processBlock(block); },
529  /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
530  return failure();
531 
532  // There is nothing to do for the merge block in the selection, which just
533  // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel
534  // instruction to start a new SPIR-V block for ops following this SelectionOp.
535  // The block should use the <id> for the merge block.
536  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
537 
538  // We do not process the mergeBlock but we still need to generate phi
539  // functions from its block arguments.
540  if (failed(emitPhiForBlockArguments(mergeBlock)))
541  return failure();
542 
543  LLVM_DEBUG(llvm::dbgs() << "done merge ");
544  LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
545  LLVM_DEBUG(llvm::dbgs() << "\n");
546  return success();
547 }
548 
549 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
550  // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
551  // properly. We don't need to assign for the entry block, which is just for
552  // satisfying MLIR region's structural requirement.
553  auto &body = loopOp.getBody();
554  for (Block &block : llvm::drop_begin(body))
555  getOrCreateBlockID(&block);
556 
557  auto *headerBlock = loopOp.getHeaderBlock();
558  auto *continueBlock = loopOp.getContinueBlock();
559  auto *mergeBlock = loopOp.getMergeBlock();
560  auto headerID = getBlockID(headerBlock);
561  auto continueID = getBlockID(continueBlock);
562  auto mergeID = getBlockID(mergeBlock);
563  auto loc = loopOp.getLoc();
564 
565  // Before we do anything replace results of the selection operation with
566  // values yielded (with `mlir.merge`) from inside the region.
567  auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
568  assert(loopOp.getNumResults() == mergeOp.getNumOperands());
569  for (unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i)
570  loopOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
571 
572  // This LoopOp is in some MLIR block with preceding and following ops. In the
573  // binary format, it should reside in separate SPIR-V blocks from its
574  // preceding and following ops. So we need to emit unconditional branches to
575  // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
576  // afterwards.
577  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
578 
579  // LoopOp's entry block is just there for satisfying MLIR's structural
580  // requirements so we omit it and start serialization from the loop header
581  // block.
582 
583  // Emit the loop header block, which dominates all other blocks, first. We
584  // need to emit an OpLoopMerge instruction before the loop header block's
585  // terminator.
586  auto emitLoopMerge = [&]() {
587  if (failed(emitDebugLine(functionBody, loc)))
588  return failure();
589  lastProcessedWasMergeInst = true;
591  functionBody, spirv::Opcode::OpLoopMerge,
592  {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())});
593  return success();
594  };
595  if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
596  return failure();
597 
598  // Process all blocks with a depth-first visitor starting from the header
599  // block. The loop header block, loop continue block, and loop merge block are
600  // skipped by this visitor and handled later in this function.
602  headerBlock, [&](Block *block) { return processBlock(block); },
603  /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
604  return failure();
605 
606  // We have handled all other blocks. Now get to the loop continue block.
607  if (failed(processBlock(continueBlock)))
608  return failure();
609 
610  // There is nothing to do for the merge block in the loop, which just contains
611  // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction
612  // to start a new SPIR-V block for ops following this LoopOp. The block should
613  // use the <id> for the merge block.
614  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
615  LLVM_DEBUG(llvm::dbgs() << "done merge ");
616  LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
617  LLVM_DEBUG(llvm::dbgs() << "\n");
618  return success();
619 }
620 
621 LogicalResult Serializer::processBranchConditionalOp(
622  spirv::BranchConditionalOp condBranchOp) {
623  auto conditionID = getValueID(condBranchOp.getCondition());
624  auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
625  auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
626  SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
627 
628  if (auto weights = condBranchOp.getBranchWeights()) {
629  for (auto val : weights->getValue())
630  arguments.push_back(cast<IntegerAttr>(val).getInt());
631  }
632 
633  if (failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
634  return failure();
635  encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
636  arguments);
637  return success();
638 }
639 
640 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
641  if (failed(emitDebugLine(functionBody, branchOp.getLoc())))
642  return failure();
643  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
644  {getOrCreateBlockID(branchOp.getTarget())});
645  return success();
646 }
647 
648 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
649  auto varName = addressOfOp.getVariable();
650  auto variableID = getVariableID(varName);
651  if (!variableID) {
652  return addressOfOp.emitError("unknown result <id> for variable ")
653  << varName;
654  }
655  valueIDMap[addressOfOp.getPointer()] = variableID;
656  return success();
657 }
658 
659 LogicalResult
660 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
661  auto constName = referenceOfOp.getSpecConst();
662  auto constID = getSpecConstID(constName);
663  if (!constID) {
664  return referenceOfOp.emitError(
665  "unknown result <id> for specialization constant ")
666  << constName;
667  }
668  valueIDMap[referenceOfOp.getReference()] = constID;
669  return success();
670 }
671 
672 template <>
673 LogicalResult
674 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
675  SmallVector<uint32_t, 4> operands;
676  // Add the ExecutionModel.
677  operands.push_back(static_cast<uint32_t>(op.getExecutionModel()));
678  // Add the function <id>.
679  auto funcID = getFunctionID(op.getFn());
680  if (!funcID) {
681  return op.emitError("missing <id> for function ")
682  << op.getFn()
683  << "; function needs to be defined before spirv.EntryPoint is "
684  "serialized";
685  }
686  operands.push_back(funcID);
687  // Add the name of the function.
688  spirv::encodeStringLiteralInto(operands, op.getFn());
689 
690  // Add the interface values.
691  if (auto interface = op.getInterface()) {
692  for (auto var : interface.getValue()) {
693  auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
694  if (!id) {
695  return op.emitError(
696  "referencing undefined global variable."
697  "spirv.EntryPoint is at the end of spirv.module. All "
698  "referenced variables should already be defined");
699  }
700  operands.push_back(id);
701  }
702  }
703  encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands);
704  return success();
705 }
706 
707 template <>
708 LogicalResult
709 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
710  SmallVector<uint32_t, 4> operands;
711  // Add the function <id>.
712  auto funcID = getFunctionID(op.getFn());
713  if (!funcID) {
714  return op.emitError("missing <id> for function ")
715  << op.getFn()
716  << "; function needs to be serialized before ExecutionModeOp is "
717  "serialized";
718  }
719  operands.push_back(funcID);
720  // Add the ExecutionMode.
721  operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
722 
723  // Serialize values if any.
724  auto values = op.getValues();
725  if (values) {
726  for (auto &intVal : values.getValue()) {
727  operands.push_back(static_cast<uint32_t>(
728  llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
729  }
730  }
731  encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
732  operands);
733  return success();
734 }
735 
736 template <>
737 LogicalResult
738 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
739  auto funcName = op.getCallee();
740  uint32_t resTypeID = 0;
741 
742  Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
743  if (failed(processType(op.getLoc(), resultTy, resTypeID)))
744  return failure();
745 
746  auto funcID = getOrCreateFunctionID(funcName);
747  auto funcCallID = getNextID();
748  SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
749 
750  for (auto value : op.getArguments()) {
751  auto valueID = getValueID(value);
752  assert(valueID && "cannot find a value for spirv.FunctionCall");
753  operands.push_back(valueID);
754  }
755 
756  if (!isa<NoneType>(resultTy))
757  valueIDMap[op.getResult(0)] = funcCallID;
758 
759  encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands);
760  return success();
761 }
762 
763 template <>
764 LogicalResult
765 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
766  SmallVector<uint32_t, 4> operands;
767  SmallVector<StringRef, 2> elidedAttrs;
768 
769  for (Value operand : op->getOperands()) {
770  auto id = getValueID(operand);
771  assert(id && "use before def!");
772  operands.push_back(id);
773  }
774 
775  StringAttr memoryAccess = op.getMemoryAccessAttrName();
776  if (auto attr = op->getAttr(memoryAccess)) {
777  operands.push_back(
778  static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
779  }
780 
781  elidedAttrs.push_back(memoryAccess.strref());
782 
783  StringAttr alignment = op.getAlignmentAttrName();
784  if (auto attr = op->getAttr(alignment)) {
785  operands.push_back(static_cast<uint32_t>(
786  cast<IntegerAttr>(attr).getValue().getZExtValue()));
787  }
788 
789  elidedAttrs.push_back(alignment.strref());
790 
791  StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
792  if (auto attr = op->getAttr(sourceMemoryAccess)) {
793  operands.push_back(
794  static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
795  }
796 
797  elidedAttrs.push_back(sourceMemoryAccess.strref());
798 
799  StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
800  if (auto attr = op->getAttr(sourceAlignment)) {
801  operands.push_back(static_cast<uint32_t>(
802  cast<IntegerAttr>(attr).getValue().getZExtValue()));
803  }
804 
805  elidedAttrs.push_back(sourceAlignment.strref());
806  if (failed(emitDebugLine(functionBody, op.getLoc())))
807  return failure();
808  encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
809 
810  return success();
811 }
812 template <>
813 LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
814  spirv::GenericCastToPtrExplicitOp op) {
815  SmallVector<uint32_t, 4> operands;
816  Type resultTy;
817  Location loc = op->getLoc();
818  uint32_t resultTypeID = 0;
819  uint32_t resultID = 0;
820  resultTy = op->getResult(0).getType();
821  if (failed(processType(loc, resultTy, resultTypeID)))
822  return failure();
823  operands.push_back(resultTypeID);
824 
825  resultID = getNextID();
826  operands.push_back(resultID);
827  valueIDMap[op->getResult(0)] = resultID;
828 
829  for (Value operand : op->getOperands())
830  operands.push_back(getValueID(operand));
831  spirv::StorageClass resultStorage =
832  cast<spirv::PointerType>(resultTy).getStorageClass();
833  operands.push_back(static_cast<uint32_t>(resultStorage));
834  encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit,
835  operands);
836  return success();
837 }
838 
839 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
840 // various Serializer::processOp<...>() specializations.
841 #define GET_SERIALIZATION_FNS
842 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
843 
844 } // namespace spirv
845 } // namespace mlir
static LogicalResult visitInPrettyBlockOrder(Block *headerBlock, function_ref< LogicalResult(Block *)> blockHandler, bool skipHeader=false, BlockRange skipBlocks={})
A pre-order depth-first visitor function for processing basic blocks.
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringRef stripDialect() const
Return the operation name with dialect name stripped, if it has one.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
Definition: Serializer.cpp:97
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.