MLIR  18.0.0git
MemoryOps.cpp
Go to the documentation of this file.
1 //===- MemoryOps.cpp - MLIR SPIR-V Memory Ops ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the memory operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 
16 #include "SPIRVOpUtils.h"
17 #include "SPIRVParsingUtils.h"
19 #include "mlir/IR/Diagnostics.h"
20 
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Casting.h"
23 
24 using namespace mlir::spirv::AttrNames;
25 
26 namespace mlir::spirv {
27 
28 // TODO Make sure to merge this and the previous function into one template
29 // parameterized by memory access attribute name and alignment. Doing so now
30 // results in VS2017 in producing an internal error (at the call site) that's
31 // not detailed enough to understand what is happening.
33  OperationState &state) {
34  // Parse an optional list of attributes staring with '['
35  if (parser.parseOptionalLSquare()) {
36  // Nothing to do
37  return success();
38  }
39 
40  spirv::MemoryAccess memoryAccessAttr;
41  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
42  memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName))
43  return failure();
44 
45  if (spirv::bitEnumContainsAll(memoryAccessAttr,
46  spirv::MemoryAccess::Aligned)) {
47  // Parse integer attribute for alignment.
48  Attribute alignmentAttr;
49  Type i32Type = parser.getBuilder().getIntegerType(32);
50  if (parser.parseComma() ||
51  parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
52  state.attributes)) {
53  return failure();
54  }
55  }
56  return parser.parseRSquare();
57 }
58 
59 // TODO Make sure to merge this and the previous function into one template
60 // parameterized by memory access attribute name and alignment. Doing so now
61 // results in VS2017 in producing an internal error (at the call site) that's
62 // not detailed enough to understand what is happening.
63 template <typename MemoryOpTy>
65  MemoryOpTy memoryOp, OpAsmPrinter &printer,
66  SmallVectorImpl<StringRef> &elidedAttrs,
67  std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
68  std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
69 
70  printer << ", ";
71 
72  // Print optional memory access attribute.
73  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
74  : memoryOp.getMemoryAccess())) {
75  elidedAttrs.push_back(kSourceMemoryAccessAttrName);
76 
77  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
78 
79  if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
80  // Print integer alignment attribute.
81  if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
82  : memoryOp.getAlignment())) {
83  elidedAttrs.push_back(kSourceAlignmentAttrName);
84  printer << ", " << *alignment;
85  }
86  }
87  printer << "]";
88  }
89  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
90 }
91 
92 template <typename MemoryOpTy>
94  MemoryOpTy memoryOp, OpAsmPrinter &printer,
95  SmallVectorImpl<StringRef> &elidedAttrs,
96  std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
97  std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
98  // Print optional memory access attribute.
99  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
100  : memoryOp.getMemoryAccess())) {
101  elidedAttrs.push_back(kMemoryAccessAttrName);
102 
103  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
104 
105  if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
106  // Print integer alignment attribute.
107  if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
108  : memoryOp.getAlignment())) {
109  elidedAttrs.push_back(kAlignmentAttrName);
110  printer << ", " << *alignment;
111  }
112  }
113  printer << "]";
114  }
115  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
116 }
117 
118 template <typename LoadStoreOpTy>
119 static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
120  Value val) {
121  // ODS already checks ptr is spirv::PointerType. Just check that the pointee
122  // type of the pointer and the type of the value are the same
123  //
124  // TODO: Check that the value type satisfies restrictions of
125  // SPIR-V OpLoad/OpStore operations
126  if (val.getType() !=
127  llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
128  return op.emitOpError("mismatch in result type and pointer type");
129  }
130  return success();
131 }
132 
133 template <typename MemoryOpTy>
134 static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
135  // ODS checks for attributes values. Just need to verify that if the
136  // memory-access attribute is Aligned, then the alignment attribute must be
137  // present.
138  auto *op = memoryOp.getOperation();
139  auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
140  if (!memAccessAttr) {
141  // Alignment attribute shouldn't be present if memory access attribute is
142  // not present.
143  if (op->getAttr(kAlignmentAttrName)) {
144  return memoryOp.emitOpError(
145  "invalid alignment specification without aligned memory access "
146  "specification");
147  }
148  return success();
149  }
150 
151  auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
152 
153  if (!memAccess) {
154  return memoryOp.emitOpError("invalid memory access specifier: ")
155  << memAccessAttr;
156  }
157 
158  if (spirv::bitEnumContainsAll(memAccess.getValue(),
159  spirv::MemoryAccess::Aligned)) {
160  if (!op->getAttr(kAlignmentAttrName)) {
161  return memoryOp.emitOpError("missing alignment value");
162  }
163  } else {
164  if (op->getAttr(kAlignmentAttrName)) {
165  return memoryOp.emitOpError(
166  "invalid alignment specification with non-aligned memory access "
167  "specification");
168  }
169  }
170  return success();
171 }
172 
173 // TODO Make sure to merge this and the previous function into one template
174 // parameterized by memory access attribute name and alignment. Doing so now
175 // results in VS2017 in producing an internal error (at the call site) that's
176 // not detailed enough to understand what is happening.
177 template <typename MemoryOpTy>
179  // ODS checks for attributes values. Just need to verify that if the
180  // memory-access attribute is Aligned, then the alignment attribute must be
181  // present.
182  auto *op = memoryOp.getOperation();
183  auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
184  if (!memAccessAttr) {
185  // Alignment attribute shouldn't be present if memory access attribute is
186  // not present.
187  if (op->getAttr(kSourceAlignmentAttrName)) {
188  return memoryOp.emitOpError(
189  "invalid alignment specification without aligned memory access "
190  "specification");
191  }
192  return success();
193  }
194 
195  auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
196 
197  if (!memAccess) {
198  return memoryOp.emitOpError("invalid memory access specifier: ")
199  << memAccess;
200  }
201 
202  if (spirv::bitEnumContainsAll(memAccess.getValue(),
203  spirv::MemoryAccess::Aligned)) {
204  if (!op->getAttr(kSourceAlignmentAttrName)) {
205  return memoryOp.emitOpError("missing alignment value");
206  }
207  } else {
208  if (op->getAttr(kSourceAlignmentAttrName)) {
209  return memoryOp.emitOpError(
210  "invalid alignment specification with non-aligned memory access "
211  "specification");
212  }
213  }
214  return success();
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // spirv.AccessChainOp
219 //===----------------------------------------------------------------------===//
220 
221 static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
222  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
223  if (!ptrType) {
224  emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
225  "to composite type, but provided ")
226  << type;
227  return nullptr;
228  }
229 
230  auto resultType = ptrType.getPointeeType();
231  auto resultStorageClass = ptrType.getStorageClass();
232  int32_t index = 0;
233 
234  for (auto indexSSA : indices) {
235  auto cType = llvm::dyn_cast<spirv::CompositeType>(resultType);
236  if (!cType) {
237  emitError(
238  baseLoc,
239  "'spirv.AccessChain' op cannot extract from non-composite type ")
240  << resultType << " with index " << index;
241  return nullptr;
242  }
243  index = 0;
244  if (llvm::isa<spirv::StructType>(resultType)) {
245  Operation *op = indexSSA.getDefiningOp();
246  if (!op) {
247  emitError(baseLoc, "'spirv.AccessChain' op index must be an "
248  "integer spirv.Constant to access "
249  "element of spirv.struct");
250  return nullptr;
251  }
252 
253  // TODO: this should be relaxed to allow
254  // integer literals of other bitwidths.
255  if (failed(spirv::extractValueFromConstOp(op, index))) {
256  emitError(
257  baseLoc,
258  "'spirv.AccessChain' index must be an integer spirv.Constant to "
259  "access element of spirv.struct, but provided ")
260  << op->getName();
261  return nullptr;
262  }
263  if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
264  emitError(baseLoc, "'spirv.AccessChain' op index ")
265  << index << " out of bounds for " << resultType;
266  return nullptr;
267  }
268  }
269  resultType = cType.getElementType(index);
270  }
271  return spirv::PointerType::get(resultType, resultStorageClass);
272 }
273 
274 void AccessChainOp::build(OpBuilder &builder, OperationState &state,
275  Value basePtr, ValueRange indices) {
276  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
277  assert(type && "Unable to deduce return type based on basePtr and indices");
278  build(builder, state, type, basePtr, indices);
279 }
280 
281 ParseResult AccessChainOp::parse(OpAsmParser &parser, OperationState &result) {
282  OpAsmParser::UnresolvedOperand ptrInfo;
283  SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
284  Type type;
285  auto loc = parser.getCurrentLocation();
286  SmallVector<Type, 4> indicesTypes;
287 
288  if (parser.parseOperand(ptrInfo) ||
289  parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
290  parser.parseColonType(type) ||
291  parser.resolveOperand(ptrInfo, type, result.operands)) {
292  return failure();
293  }
294 
295  // Check that the provided indices list is not empty before parsing their
296  // type list.
297  if (indicesInfo.empty()) {
298  return mlir::emitError(result.location,
299  "'spirv.AccessChain' op expected at "
300  "least one index ");
301  }
302 
303  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
304  return failure();
305 
306  // Check that the indices types list is not empty and that it has a one-to-one
307  // mapping to the provided indices.
308  if (indicesTypes.size() != indicesInfo.size()) {
309  return mlir::emitError(
310  result.location, "'spirv.AccessChain' op indices types' count must be "
311  "equal to indices info count");
312  }
313 
314  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands))
315  return failure();
316 
317  auto resultType = getElementPtrType(
318  type, llvm::ArrayRef(result.operands).drop_front(), result.location);
319  if (!resultType) {
320  return failure();
321  }
322 
323  result.addTypes(resultType);
324  return success();
325 }
326 
327 template <typename Op>
328 static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
329  printer << ' ' << op.getBasePtr() << '[' << indices
330  << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
331 }
332 
334  printAccessChain(*this, getIndices(), printer);
335 }
336 
337 template <typename Op>
338 static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
339  auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
340  indices, accessChainOp.getLoc());
341  if (!resultType)
342  return failure();
343 
344  auto providedResultType =
345  llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType());
346  if (!providedResultType)
347  return accessChainOp.emitOpError(
348  "result type must be a pointer, but provided")
349  << providedResultType;
350 
351  if (resultType != providedResultType)
352  return accessChainOp.emitOpError("invalid result type: expected ")
353  << resultType << ", but provided " << providedResultType;
354 
355  return success();
356 }
357 
359  return verifyAccessChain(*this, getIndices());
360 }
361 
362 //===----------------------------------------------------------------------===//
363 // spirv.LoadOp
364 //===----------------------------------------------------------------------===//
365 
366 void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr,
367  MemoryAccessAttr memoryAccess, IntegerAttr alignment) {
368  auto ptrType = llvm::cast<spirv::PointerType>(basePtr.getType());
369  build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
370  alignment);
371 }
372 
373 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
374  // Parse the storage class specification
375  spirv::StorageClass storageClass;
376  OpAsmParser::UnresolvedOperand ptrInfo;
377  Type elementType;
378  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
379  parseMemoryAccessAttributes(parser, result) ||
380  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
381  parser.parseType(elementType)) {
382  return failure();
383  }
384 
385  auto ptrType = spirv::PointerType::get(elementType, storageClass);
386  if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
387  return failure();
388  }
389 
390  result.addTypes(elementType);
391  return success();
392 }
393 
394 void LoadOp::print(OpAsmPrinter &printer) {
395  SmallVector<StringRef, 4> elidedAttrs;
396  StringRef sc = stringifyStorageClass(
397  llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
398  printer << " \"" << sc << "\" " << getPtr();
399 
400  printMemoryAccessAttribute(*this, printer, elidedAttrs);
401 
402  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
403  printer << " : " << getType();
404 }
405 
406 LogicalResult LoadOp::verify() {
407  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
408  // type with fixed size; i.e., it cannot be, nor include, any
409  // OpTypeRuntimeArray types."
410  if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) {
411  return failure();
412  }
413  return verifyMemoryAccessAttribute(*this);
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // spirv.StoreOp
418 //===----------------------------------------------------------------------===//
419 
420 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
421  // Parse the storage class specification
422  spirv::StorageClass storageClass;
423  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
424  auto loc = parser.getCurrentLocation();
425  Type elementType;
426  if (parseEnumStrAttr(storageClass, parser) ||
427  parser.parseOperandList(operandInfo, 2) ||
428  parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
429  parser.parseType(elementType)) {
430  return failure();
431  }
432 
433  auto ptrType = spirv::PointerType::get(elementType, storageClass);
434  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
435  result.operands)) {
436  return failure();
437  }
438  return success();
439 }
440 
441 void StoreOp::print(OpAsmPrinter &printer) {
442  SmallVector<StringRef, 4> elidedAttrs;
443  StringRef sc = stringifyStorageClass(
444  llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
445  printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
446 
447  printMemoryAccessAttribute(*this, printer, elidedAttrs);
448 
449  printer << " : " << getValue().getType();
450  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
451 }
452 
453 LogicalResult StoreOp::verify() {
454  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
455  // OpTypePointer whose Type operand is the same as the type of Object."
456  if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
457  return failure();
458  return verifyMemoryAccessAttribute(*this);
459 }
460 
461 //===----------------------------------------------------------------------===//
462 // spirv.CopyMemory
463 //===----------------------------------------------------------------------===//
464 
465 void CopyMemoryOp::print(OpAsmPrinter &printer) {
466  printer << ' ';
467 
468  StringRef targetStorageClass = stringifyStorageClass(
469  llvm::cast<spirv::PointerType>(getTarget().getType()).getStorageClass());
470  printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
471 
472  StringRef sourceStorageClass = stringifyStorageClass(
473  llvm::cast<spirv::PointerType>(getSource().getType()).getStorageClass());
474  printer << " \"" << sourceStorageClass << "\" " << getSource();
475 
476  SmallVector<StringRef, 4> elidedAttrs;
477  printMemoryAccessAttribute(*this, printer, elidedAttrs);
478  printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
479  getSourceMemoryAccess(),
480  getSourceAlignment());
481 
482  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
483 
484  Type pointeeType =
485  llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
486  printer << " : " << pointeeType;
487 }
488 
489 ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
490  spirv::StorageClass targetStorageClass;
491  OpAsmParser::UnresolvedOperand targetPtrInfo;
492 
493  spirv::StorageClass sourceStorageClass;
494  OpAsmParser::UnresolvedOperand sourcePtrInfo;
495 
496  Type elementType;
497 
498  if (parseEnumStrAttr(targetStorageClass, parser) ||
499  parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
500  parseEnumStrAttr(sourceStorageClass, parser) ||
501  parser.parseOperand(sourcePtrInfo) ||
502  parseMemoryAccessAttributes(parser, result)) {
503  return failure();
504  }
505 
506  if (!parser.parseOptionalComma()) {
507  // Parse 2nd memory access attributes.
508  if (parseSourceMemoryAccessAttributes(parser, result)) {
509  return failure();
510  }
511  }
512 
513  if (parser.parseColon() || parser.parseType(elementType))
514  return failure();
515 
516  if (parser.parseOptionalAttrDict(result.attributes))
517  return failure();
518 
519  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
520  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
521 
522  if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
523  parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
524  return failure();
525  }
526 
527  return success();
528 }
529 
530 LogicalResult CopyMemoryOp::verify() {
531  Type targetType =
532  llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
533 
534  Type sourceType =
535  llvm::cast<spirv::PointerType>(getSource().getType()).getPointeeType();
536 
537  if (targetType != sourceType)
538  return emitOpError("both operands must be pointers to the same type");
539 
541  return failure();
542 
543  // TODO - According to the spec:
544  //
545  // If two masks are present, the first applies to Target and cannot include
546  // MakePointerVisible, and the second applies to Source and cannot include
547  // MakePointerAvailable.
548  //
549  // Add such verification here.
550 
551  return verifySourceMemoryAccessAttribute(*this);
552 }
553 
554 static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
555  OpAsmParser &parser,
556  OperationState &state) {
559  Type type;
560  auto loc = parser.getCurrentLocation();
561  SmallVector<Type, 4> indicesTypes;
562 
563  if (parser.parseOperand(ptrInfo) ||
564  parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
565  parser.parseColonType(type) ||
566  parser.resolveOperand(ptrInfo, type, state.operands))
567  return failure();
568 
569  // Check that the provided indices list is not empty before parsing their
570  // type list.
571  if (indicesInfo.empty())
572  return emitError(state.location) << opName << " expected element";
573 
574  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
575  return failure();
576 
577  // Check that the indices types list is not empty and that it has a one-to-one
578  // mapping to the provided indices.
579  if (indicesTypes.size() != indicesInfo.size())
580  return emitError(state.location)
581  << opName
582  << " indices types' count must be equal to indices info count";
583 
584  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
585  return failure();
586 
587  auto resultType = getElementPtrType(
588  type, llvm::ArrayRef(state.operands).drop_front(2), state.location);
589  if (!resultType)
590  return failure();
591 
592  state.addTypes(resultType);
593  return success();
594 }
595 
596 template <typename Op>
597 static auto concatElemAndIndices(Op op) {
598  SmallVector<Value> ret(op.getIndices().size() + 1);
599  ret[0] = op.getElement();
600  llvm::copy(op.getIndices(), ret.begin() + 1);
601  return ret;
602 }
603 
604 //===----------------------------------------------------------------------===//
605 // spirv.InBoundsPtrAccessChainOp
606 //===----------------------------------------------------------------------===//
607 
608 void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
609  Value basePtr, Value element,
610  ValueRange indices) {
611  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
612  assert(type && "Unable to deduce return type based on basePtr and indices");
613  build(builder, state, type, basePtr, element, indices);
614 }
615 
616 ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
617  OperationState &result) {
619  spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
620 }
621 
622 void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
623  printAccessChain(*this, concatElemAndIndices(*this), printer);
624 }
625 
626 LogicalResult InBoundsPtrAccessChainOp::verify() {
627  return verifyAccessChain(*this, getIndices());
628 }
629 
630 //===----------------------------------------------------------------------===//
631 // spirv.PtrAccessChainOp
632 //===----------------------------------------------------------------------===//
633 
634 void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
635  Value basePtr, Value element, ValueRange indices) {
636  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
637  assert(type && "Unable to deduce return type based on basePtr and indices");
638  build(builder, state, type, basePtr, element, indices);
639 }
640 
641 ParseResult PtrAccessChainOp::parse(OpAsmParser &parser,
642  OperationState &result) {
643  return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
644  parser, result);
645 }
646 
647 void PtrAccessChainOp::print(OpAsmPrinter &printer) {
648  printAccessChain(*this, concatElemAndIndices(*this), printer);
649 }
650 
651 LogicalResult PtrAccessChainOp::verify() {
652  return verifyAccessChain(*this, getIndices());
653 }
654 
655 //===----------------------------------------------------------------------===//
656 // spirv.Variable
657 //===----------------------------------------------------------------------===//
658 
659 ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
660  // Parse optional initializer
661  std::optional<OpAsmParser::UnresolvedOperand> initInfo;
662  if (succeeded(parser.parseOptionalKeyword("init"))) {
663  initInfo = OpAsmParser::UnresolvedOperand();
664  if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
665  parser.parseRParen())
666  return failure();
667  }
668 
669  if (parseVariableDecorations(parser, result)) {
670  return failure();
671  }
672 
673  // Parse result pointer type
674  Type type;
675  if (parser.parseColon())
676  return failure();
677  auto loc = parser.getCurrentLocation();
678  if (parser.parseType(type))
679  return failure();
680 
681  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
682  if (!ptrType)
683  return parser.emitError(loc, "expected spirv.ptr type");
684  result.addTypes(ptrType);
685 
686  // Resolve the initializer operand
687  if (initInfo) {
688  if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
689  result.operands))
690  return failure();
691  }
692 
693  auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
694  ptrType.getStorageClass());
695  result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
696 
697  return success();
698 }
699 
700 void VariableOp::print(OpAsmPrinter &printer) {
701  SmallVector<StringRef, 4> elidedAttrs{
702  spirv::attributeName<spirv::StorageClass>()};
703  // Print optional initializer
704  if (getNumOperands() != 0)
705  printer << " init(" << getInitializer() << ")";
706 
707  printVariableDecorations(*this, printer, elidedAttrs);
708  printer << " : " << getType();
709 }
710 
711 LogicalResult VariableOp::verify() {
712  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
713  // object. It cannot be Generic. It must be the same as the Storage Class
714  // operand of the Result Type."
715  if (getStorageClass() != spirv::StorageClass::Function) {
716  return emitOpError(
717  "can only be used to model function-level variables. Use "
718  "spirv.GlobalVariable for module-level variables.");
719  }
720 
721  auto pointerType = llvm::cast<spirv::PointerType>(getPointer().getType());
722  if (getStorageClass() != pointerType.getStorageClass())
723  return emitOpError(
724  "storage class must match result pointer's storage class");
725 
726  if (getNumOperands() != 0) {
727  // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
728  // a global (module scope) OpVariable instruction".
729  auto *initOp = getOperand(0).getDefiningOp();
730  if (!initOp || !isa<spirv::ConstantOp, // for normal constant
731  spirv::ReferenceOfOp, // for spec constant
732  spirv::AddressOfOp>(initOp))
733  return emitOpError("initializer must be the result of a "
734  "constant or spirv.GlobalVariable op");
735  }
736 
737  auto getDecorationAttr = [op = getOperation()](spirv::Decoration decoration) {
738  return op->getAttr(
739  llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)));
740  };
741 
742  // TODO: generate these strings using ODS.
743  for (auto decoration :
744  {spirv::Decoration::DescriptorSet, spirv::Decoration::Binding,
745  spirv::Decoration::BuiltIn}) {
746  if (auto attr = getDecorationAttr(decoration))
747  return emitOpError("cannot have '")
748  << llvm::convertToSnakeFromCamelCase(
749  stringifyDecoration(decoration))
750  << "' attribute (only allowed in spirv.GlobalVariable)";
751  }
752 
753  // From SPV_KHR_physical_storage_buffer:
754  // > If an OpVariable's pointee type is a pointer (or array of pointers) in
755  // > PhysicalStorageBuffer storage class, then the variable must be decorated
756  // > with exactly one of AliasedPointer or RestrictPointer.
757  auto pointeePtrType = dyn_cast<spirv::PointerType>(getPointeeType());
758  if (!pointeePtrType) {
759  if (auto pointeeArrayType = dyn_cast<spirv::ArrayType>(getPointeeType())) {
760  pointeePtrType =
761  dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
762  }
763  }
764 
765  if (pointeePtrType && pointeePtrType.getStorageClass() ==
766  spirv::StorageClass::PhysicalStorageBuffer) {
767  bool hasAliasedPtr =
768  getDecorationAttr(spirv::Decoration::AliasedPointer) != nullptr;
769  bool hasRestrictPtr =
770  getDecorationAttr(spirv::Decoration::RestrictPointer) != nullptr;
771 
772  if (!hasAliasedPtr && !hasRestrictPtr)
773  return emitOpError() << " with physical buffer pointer must be decorated "
774  "either 'AliasedPointer' or 'RestrictPointer'";
775 
776  if (hasAliasedPtr && hasRestrictPtr)
777  return emitOpError()
778  << " with physical buffer pointer must have exactly one "
779  "aliasing decoration";
780  }
781 
782  return success();
783 }
784 
785 } // namespace mlir::spirv
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Square
Square brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:72
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
Definition: Builders.h:206
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:781
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:125
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents success/failure for parsing-like operations that find it important to chain tog...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:546
@ Type
An inlay hint that for a type annotation.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:19
constexpr char kMemoryAccessAttrName[]
constexpr char kSourceMemoryAccessAttrName[]
constexpr char kSourceAlignmentAttrName[]
constexpr char kAlignmentAttrName[]
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)
Definition: MemoryOps.cpp:178
static ParseResult parsePtrAccessChainOpImpl(StringRef opName, OpAsmParser &parser, OperationState &state)
Definition: MemoryOps.cpp:554
ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state, StringRef attrName)
Parses optional memory access (a.k.a.
static void printSourceMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
Definition: MemoryOps.cpp:64
static auto concatElemAndIndices(Op op)
Definition: MemoryOps.cpp:597
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Definition: MemoryOps.cpp:221
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:95
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, Value val)
Definition: MemoryOps.cpp:119
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp)
Definition: MemoryOps.cpp:134
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer)
Definition: MemoryOps.cpp:328
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Definition: MemoryOps.cpp:32
static void printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
Definition: MemoryOps.cpp:93
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:51
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices)
Definition: MemoryOps.cpp:338
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.