MLIR  20.0.0git
MemoryOps.cpp
Go to the documentation of this file.
1 //===- MemoryOps.cpp - MLIR SPIR-V Memory Ops ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the memory operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 
16 #include "SPIRVOpUtils.h"
17 #include "SPIRVParsingUtils.h"
19 #include "mlir/IR/Diagnostics.h"
20 
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Casting.h"
23 
24 using namespace mlir::spirv::AttrNames;
25 
26 namespace mlir::spirv {
27 
28 /// Parses optional memory access (a.k.a. memory operand) attributes attached to
29 /// a memory access operand/pointer. Specifically, parses the following syntax:
30 /// (`[` memory-access `]`)?
31 /// where:
32 /// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
33 /// integer-literal | `"NonTemporal"`
34 template <typename MemoryOpTy>
36  OperationState &state) {
37  // Parse an optional list of attributes staring with '['
38  if (parser.parseOptionalLSquare()) {
39  // Nothing to do
40  return success();
41  }
42 
43  spirv::MemoryAccess memoryAccessAttr;
44  StringAttr memoryAccessAttrName =
45  MemoryOpTy::getMemoryAccessAttrName(state.name);
46  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
47  memoryAccessAttr, parser, state, memoryAccessAttrName))
48  return failure();
49 
50  if (spirv::bitEnumContainsAll(memoryAccessAttr,
51  spirv::MemoryAccess::Aligned)) {
52  // Parse integer attribute for alignment.
53  Attribute alignmentAttr;
54  StringAttr alignmentAttrName = MemoryOpTy::getAlignmentAttrName(state.name);
55  Type i32Type = parser.getBuilder().getIntegerType(32);
56  if (parser.parseComma() ||
57  parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
58  state.attributes)) {
59  return failure();
60  }
61  }
62  return parser.parseRSquare();
63 }
64 
65 // TODO Make sure to merge this and the previous function into one template
66 // parameterized by memory access attribute name and alignment. Doing so now
67 // results in VS2017 in producing an internal error (at the call site) that's
68 // not detailed enough to understand what is happening.
69 template <typename MemoryOpTy>
71  OperationState &state) {
72  // Parse an optional list of attributes staring with '['
73  if (parser.parseOptionalLSquare()) {
74  // Nothing to do
75  return success();
76  }
77 
78  spirv::MemoryAccess memoryAccessAttr;
79  StringRef memoryAccessAttrName =
80  MemoryOpTy::getSourceMemoryAccessAttrName(state.name);
81  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
82  memoryAccessAttr, parser, state, memoryAccessAttrName))
83  return failure();
84 
85  if (spirv::bitEnumContainsAll(memoryAccessAttr,
86  spirv::MemoryAccess::Aligned)) {
87  // Parse integer attribute for alignment.
88  Attribute alignmentAttr;
89  StringAttr alignmentAttrName =
90  MemoryOpTy::getSourceAlignmentAttrName(state.name);
91  Type i32Type = parser.getBuilder().getIntegerType(32);
92  if (parser.parseComma() ||
93  parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
94  state.attributes)) {
95  return failure();
96  }
97  }
98  return parser.parseRSquare();
99 }
100 
101 // TODO Make sure to merge this and the previous function into one template
102 // parameterized by memory access attribute name and alignment. Doing so now
103 // results in VS2017 in producing an internal error (at the call site) that's
104 // not detailed enough to understand what is happening.
105 template <typename MemoryOpTy>
107  MemoryOpTy memoryOp, OpAsmPrinter &printer,
108  SmallVectorImpl<StringRef> &elidedAttrs,
109  std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
110  std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
111 
112  printer << ", ";
113 
114  // Print optional memory access attribute.
115  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
116  : memoryOp.getMemoryAccess())) {
117  elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName());
118 
119  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
120 
121  if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
122  // Print integer alignment attribute.
123  if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
124  : memoryOp.getAlignment())) {
125  elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName());
126  printer << ", " << *alignment;
127  }
128  }
129  printer << "]";
130  }
131  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
132 }
133 
134 template <typename MemoryOpTy>
136  MemoryOpTy memoryOp, OpAsmPrinter &printer,
137  SmallVectorImpl<StringRef> &elidedAttrs,
138  std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
139  std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
140  // Print optional memory access attribute.
141  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
142  : memoryOp.getMemoryAccess())) {
143  elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName());
144 
145  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
146 
147  if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
148  // Print integer alignment attribute.
149  if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
150  : memoryOp.getAlignment())) {
151  elidedAttrs.push_back(memoryOp.getAlignmentAttrName());
152  printer << ", " << *alignment;
153  }
154  }
155  printer << "]";
156  }
157  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
158 }
159 
160 template <typename LoadStoreOpTy>
161 static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
162  Value val) {
163  // ODS already checks ptr is spirv::PointerType. Just check that the pointee
164  // type of the pointer and the type of the value are the same
165  //
166  // TODO: Check that the value type satisfies restrictions of
167  // SPIR-V OpLoad/OpStore operations
168  if (val.getType() !=
169  llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
170  return op.emitOpError("mismatch in result type and pointer type");
171  }
172  return success();
173 }
174 
175 template <typename MemoryOpTy>
176 static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
177  // ODS checks for attributes values. Just need to verify that if the
178  // memory-access attribute is Aligned, then the alignment attribute must be
179  // present.
180  auto *op = memoryOp.getOperation();
181  auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName());
182  if (!memAccessAttr) {
183  // Alignment attribute shouldn't be present if memory access attribute is
184  // not present.
185  if (op->getAttr(memoryOp.getAlignmentAttrName())) {
186  return memoryOp.emitOpError(
187  "invalid alignment specification without aligned memory access "
188  "specification");
189  }
190  return success();
191  }
192 
193  auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
194 
195  if (!memAccess) {
196  return memoryOp.emitOpError("invalid memory access specifier: ")
197  << memAccessAttr;
198  }
199 
200  if (spirv::bitEnumContainsAll(memAccess.getValue(),
201  spirv::MemoryAccess::Aligned)) {
202  if (!op->getAttr(memoryOp.getAlignmentAttrName())) {
203  return memoryOp.emitOpError("missing alignment value");
204  }
205  } else {
206  if (op->getAttr(memoryOp.getAlignmentAttrName())) {
207  return memoryOp.emitOpError(
208  "invalid alignment specification with non-aligned memory access "
209  "specification");
210  }
211  }
212  return success();
213 }
214 
215 // TODO Make sure to merge this and the previous function into one template
216 // parameterized by memory access attribute name and alignment. Doing so now
217 // results in VS2017 in producing an internal error (at the call site) that's
218 // not detailed enough to understand what is happening.
219 template <typename MemoryOpTy>
220 static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
221  // ODS checks for attributes values. Just need to verify that if the
222  // memory-access attribute is Aligned, then the alignment attribute must be
223  // present.
224  auto *op = memoryOp.getOperation();
225  auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName());
226  if (!memAccessAttr) {
227  // Alignment attribute shouldn't be present if memory access attribute is
228  // not present.
229  if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
230  return memoryOp.emitOpError(
231  "invalid alignment specification without aligned memory access "
232  "specification");
233  }
234  return success();
235  }
236 
237  auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
238 
239  if (!memAccess) {
240  return memoryOp.emitOpError("invalid memory access specifier: ")
241  << memAccess;
242  }
243 
244  if (spirv::bitEnumContainsAll(memAccess.getValue(),
245  spirv::MemoryAccess::Aligned)) {
246  if (!op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
247  return memoryOp.emitOpError("missing alignment value");
248  }
249  } else {
250  if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
251  return memoryOp.emitOpError(
252  "invalid alignment specification with non-aligned memory access "
253  "specification");
254  }
255  }
256  return success();
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // spirv.AccessChainOp
261 //===----------------------------------------------------------------------===//
262 
263 static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
264  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
265  if (!ptrType) {
266  emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
267  "to composite type, but provided ")
268  << type;
269  return nullptr;
270  }
271 
272  auto resultType = ptrType.getPointeeType();
273  auto resultStorageClass = ptrType.getStorageClass();
274  int32_t index = 0;
275 
276  for (auto indexSSA : indices) {
277  auto cType = llvm::dyn_cast<spirv::CompositeType>(resultType);
278  if (!cType) {
279  emitError(
280  baseLoc,
281  "'spirv.AccessChain' op cannot extract from non-composite type ")
282  << resultType << " with index " << index;
283  return nullptr;
284  }
285  index = 0;
286  if (llvm::isa<spirv::StructType>(resultType)) {
287  Operation *op = indexSSA.getDefiningOp();
288  if (!op) {
289  emitError(baseLoc, "'spirv.AccessChain' op index must be an "
290  "integer spirv.Constant to access "
291  "element of spirv.struct");
292  return nullptr;
293  }
294 
295  // TODO: this should be relaxed to allow
296  // integer literals of other bitwidths.
297  if (failed(spirv::extractValueFromConstOp(op, index))) {
298  emitError(
299  baseLoc,
300  "'spirv.AccessChain' index must be an integer spirv.Constant to "
301  "access element of spirv.struct, but provided ")
302  << op->getName();
303  return nullptr;
304  }
305  if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
306  emitError(baseLoc, "'spirv.AccessChain' op index ")
307  << index << " out of bounds for " << resultType;
308  return nullptr;
309  }
310  }
311  resultType = cType.getElementType(index);
312  }
313  return spirv::PointerType::get(resultType, resultStorageClass);
314 }
315 
316 void AccessChainOp::build(OpBuilder &builder, OperationState &state,
317  Value basePtr, ValueRange indices) {
318  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
319  assert(type && "Unable to deduce return type based on basePtr and indices");
320  build(builder, state, type, basePtr, indices);
321 }
322 
323 template <typename Op>
324 static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
325  printer << ' ' << op.getBasePtr() << '[' << indices
326  << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
327 }
328 
329 template <typename Op>
330 static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
331  auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
332  indices, accessChainOp.getLoc());
333  if (!resultType)
334  return failure();
335 
336  auto providedResultType =
337  llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType());
338  if (!providedResultType)
339  return accessChainOp.emitOpError(
340  "result type must be a pointer, but provided")
341  << providedResultType;
342 
343  if (resultType != providedResultType)
344  return accessChainOp.emitOpError("invalid result type: expected ")
345  << resultType << ", but provided " << providedResultType;
346 
347  return success();
348 }
349 
350 LogicalResult AccessChainOp::verify() {
351  return verifyAccessChain(*this, getIndices());
352 }
353 
354 //===----------------------------------------------------------------------===//
355 // spirv.LoadOp
356 //===----------------------------------------------------------------------===//
357 
358 void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr,
359  MemoryAccessAttr memoryAccess, IntegerAttr alignment) {
360  auto ptrType = llvm::cast<spirv::PointerType>(basePtr.getType());
361  build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
362  alignment);
363 }
364 
365 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
366  // Parse the storage class specification
367  spirv::StorageClass storageClass;
368  OpAsmParser::UnresolvedOperand ptrInfo;
369  Type elementType;
370  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
371  parseMemoryAccessAttributes<LoadOp>(parser, result) ||
372  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
373  parser.parseType(elementType)) {
374  return failure();
375  }
376 
377  auto ptrType = spirv::PointerType::get(elementType, storageClass);
378  if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
379  return failure();
380  }
381 
382  result.addTypes(elementType);
383  return success();
384 }
385 
386 void LoadOp::print(OpAsmPrinter &printer) {
387  SmallVector<StringRef, 4> elidedAttrs;
388  StringRef sc = stringifyStorageClass(
389  llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
390  printer << " \"" << sc << "\" " << getPtr();
391 
392  printMemoryAccessAttribute(*this, printer, elidedAttrs);
393 
394  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
395  printer << " : " << getType();
396 }
397 
398 LogicalResult LoadOp::verify() {
399  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
400  // type with fixed size; i.e., it cannot be, nor include, any
401  // OpTypeRuntimeArray types."
402  if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) {
403  return failure();
404  }
405  return verifyMemoryAccessAttribute(*this);
406 }
407 
408 //===----------------------------------------------------------------------===//
409 // spirv.StoreOp
410 //===----------------------------------------------------------------------===//
411 
412 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
413  // Parse the storage class specification
414  spirv::StorageClass storageClass;
415  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
416  auto loc = parser.getCurrentLocation();
417  Type elementType;
418  if (parseEnumStrAttr(storageClass, parser) ||
419  parser.parseOperandList(operandInfo, 2) ||
420  parseMemoryAccessAttributes<StoreOp>(parser, result) ||
421  parser.parseColon() || parser.parseType(elementType)) {
422  return failure();
423  }
424 
425  auto ptrType = spirv::PointerType::get(elementType, storageClass);
426  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
427  result.operands)) {
428  return failure();
429  }
430  return success();
431 }
432 
433 void StoreOp::print(OpAsmPrinter &printer) {
434  SmallVector<StringRef, 4> elidedAttrs;
435  StringRef sc = stringifyStorageClass(
436  llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
437  printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
438 
439  printMemoryAccessAttribute(*this, printer, elidedAttrs);
440 
441  printer << " : " << getValue().getType();
442  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
443 }
444 
445 LogicalResult StoreOp::verify() {
446  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
447  // OpTypePointer whose Type operand is the same as the type of Object."
448  if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
449  return failure();
450  return verifyMemoryAccessAttribute(*this);
451 }
452 
453 //===----------------------------------------------------------------------===//
454 // spirv.CopyMemory
455 //===----------------------------------------------------------------------===//
456 
457 void CopyMemoryOp::print(OpAsmPrinter &printer) {
458  printer << ' ';
459 
460  StringRef targetStorageClass = stringifyStorageClass(
461  llvm::cast<spirv::PointerType>(getTarget().getType()).getStorageClass());
462  printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
463 
464  StringRef sourceStorageClass = stringifyStorageClass(
465  llvm::cast<spirv::PointerType>(getSource().getType()).getStorageClass());
466  printer << " \"" << sourceStorageClass << "\" " << getSource();
467 
468  SmallVector<StringRef, 4> elidedAttrs;
469  printMemoryAccessAttribute(*this, printer, elidedAttrs);
470  printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
471  getSourceMemoryAccess(),
472  getSourceAlignment());
473 
474  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
475 
476  Type pointeeType =
477  llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
478  printer << " : " << pointeeType;
479 }
480 
481 ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
482  spirv::StorageClass targetStorageClass;
483  OpAsmParser::UnresolvedOperand targetPtrInfo;
484 
485  spirv::StorageClass sourceStorageClass;
486  OpAsmParser::UnresolvedOperand sourcePtrInfo;
487 
488  Type elementType;
489 
490  if (parseEnumStrAttr(targetStorageClass, parser) ||
491  parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
492  parseEnumStrAttr(sourceStorageClass, parser) ||
493  parser.parseOperand(sourcePtrInfo) ||
494  parseMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
495  return failure();
496  }
497 
498  if (!parser.parseOptionalComma()) {
499  // Parse 2nd memory access attributes.
500  if (parseSourceMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
501  return failure();
502  }
503  }
504 
505  if (parser.parseColon() || parser.parseType(elementType))
506  return failure();
507 
508  if (parser.parseOptionalAttrDict(result.attributes))
509  return failure();
510 
511  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
512  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
513 
514  if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
515  parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
516  return failure();
517  }
518 
519  return success();
520 }
521 
522 LogicalResult CopyMemoryOp::verify() {
523  Type targetType =
524  llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
525 
526  Type sourceType =
527  llvm::cast<spirv::PointerType>(getSource().getType()).getPointeeType();
528 
529  if (targetType != sourceType)
530  return emitOpError("both operands must be pointers to the same type");
531 
532  if (failed(verifyMemoryAccessAttribute(*this)))
533  return failure();
534 
535  // TODO - According to the spec:
536  //
537  // If two masks are present, the first applies to Target and cannot include
538  // MakePointerVisible, and the second applies to Source and cannot include
539  // MakePointerAvailable.
540  //
541  // Add such verification here.
542 
543  return verifySourceMemoryAccessAttribute(*this);
544 }
545 
546 //===----------------------------------------------------------------------===//
547 // spirv.InBoundsPtrAccessChainOp
548 //===----------------------------------------------------------------------===//
549 
550 void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
551  Value basePtr, Value element,
552  ValueRange indices) {
553  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
554  assert(type && "Unable to deduce return type based on basePtr and indices");
555  build(builder, state, type, basePtr, element, indices);
556 }
557 
558 LogicalResult InBoundsPtrAccessChainOp::verify() {
559  return verifyAccessChain(*this, getIndices());
560 }
561 
562 //===----------------------------------------------------------------------===//
563 // spirv.PtrAccessChainOp
564 //===----------------------------------------------------------------------===//
565 
566 void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
567  Value basePtr, Value element, ValueRange indices) {
568  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
569  assert(type && "Unable to deduce return type based on basePtr and indices");
570  build(builder, state, type, basePtr, element, indices);
571 }
572 
573 LogicalResult PtrAccessChainOp::verify() {
574  return verifyAccessChain(*this, getIndices());
575 }
576 
577 //===----------------------------------------------------------------------===//
578 // spirv.Variable
579 //===----------------------------------------------------------------------===//
580 
581 ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
582  // Parse optional initializer
583  std::optional<OpAsmParser::UnresolvedOperand> initInfo;
584  if (succeeded(parser.parseOptionalKeyword("init"))) {
585  initInfo = OpAsmParser::UnresolvedOperand();
586  if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
587  parser.parseRParen())
588  return failure();
589  }
590 
591  if (parseVariableDecorations(parser, result)) {
592  return failure();
593  }
594 
595  // Parse result pointer type
596  Type type;
597  if (parser.parseColon())
598  return failure();
599  auto loc = parser.getCurrentLocation();
600  if (parser.parseType(type))
601  return failure();
602 
603  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
604  if (!ptrType)
605  return parser.emitError(loc, "expected spirv.ptr type");
606  result.addTypes(ptrType);
607 
608  // Resolve the initializer operand
609  if (initInfo) {
610  if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
611  result.operands))
612  return failure();
613  }
614 
615  auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
616  ptrType.getStorageClass());
617  result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
618 
619  return success();
620 }
621 
622 void VariableOp::print(OpAsmPrinter &printer) {
623  SmallVector<StringRef, 4> elidedAttrs{
624  spirv::attributeName<spirv::StorageClass>()};
625  // Print optional initializer
626  if (getNumOperands() != 0)
627  printer << " init(" << getInitializer() << ")";
628 
629  printVariableDecorations(*this, printer, elidedAttrs);
630  printer << " : " << getType();
631 }
632 
633 LogicalResult VariableOp::verify() {
634  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
635  // object. It cannot be Generic. It must be the same as the Storage Class
636  // operand of the Result Type."
637  if (getStorageClass() != spirv::StorageClass::Function) {
638  return emitOpError(
639  "can only be used to model function-level variables. Use "
640  "spirv.GlobalVariable for module-level variables.");
641  }
642 
643  auto pointerType = llvm::cast<spirv::PointerType>(getPointer().getType());
644  if (getStorageClass() != pointerType.getStorageClass())
645  return emitOpError(
646  "storage class must match result pointer's storage class");
647 
648  if (getNumOperands() != 0) {
649  // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
650  // a global (module scope) OpVariable instruction".
651  auto *initOp = getOperand(0).getDefiningOp();
652  if (!initOp || !isa<spirv::ConstantOp, // for normal constant
653  spirv::ReferenceOfOp, // for spec constant
654  spirv::AddressOfOp>(initOp))
655  return emitOpError("initializer must be the result of a "
656  "constant or spirv.GlobalVariable op");
657  }
658 
659  auto getDecorationAttr = [op = getOperation()](spirv::Decoration decoration) {
660  return op->getAttr(
661  llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)));
662  };
663 
664  // TODO: generate these strings using ODS.
665  for (auto decoration :
666  {spirv::Decoration::DescriptorSet, spirv::Decoration::Binding,
667  spirv::Decoration::BuiltIn}) {
668  if (auto attr = getDecorationAttr(decoration))
669  return emitOpError("cannot have '")
670  << llvm::convertToSnakeFromCamelCase(
671  stringifyDecoration(decoration))
672  << "' attribute (only allowed in spirv.GlobalVariable)";
673  }
674 
675  // From SPV_KHR_physical_storage_buffer:
676  // > If an OpVariable's pointee type is a pointer (or array of pointers) in
677  // > PhysicalStorageBuffer storage class, then the variable must be decorated
678  // > with exactly one of AliasedPointer or RestrictPointer.
679  auto pointeePtrType = dyn_cast<spirv::PointerType>(getPointeeType());
680  if (!pointeePtrType) {
681  if (auto pointeeArrayType = dyn_cast<spirv::ArrayType>(getPointeeType())) {
682  pointeePtrType =
683  dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
684  }
685  }
686 
687  if (pointeePtrType && pointeePtrType.getStorageClass() ==
688  spirv::StorageClass::PhysicalStorageBuffer) {
689  bool hasAliasedPtr =
690  getDecorationAttr(spirv::Decoration::AliasedPointer) != nullptr;
691  bool hasRestrictPtr =
692  getDecorationAttr(spirv::Decoration::RestrictPointer) != nullptr;
693 
694  if (!hasAliasedPtr && !hasRestrictPtr)
695  return emitOpError() << " with physical buffer pointer must be decorated "
696  "either 'AliasedPointer' or 'RestrictPointer'";
697 
698  if (hasAliasedPtr && hasRestrictPtr)
699  return emitOpError()
700  << " with physical buffer pointer must have exactly one "
701  "aliasing decoration";
702  }
703 
704  return success();
705 }
706 
707 } // namespace mlir::spirv
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
Definition: Builders.h:216
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:832
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:125
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:406
@ Type
An inlay hint for a type annotation.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Definition: MemoryOps.cpp:70
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)
Definition: MemoryOps.cpp:220
static void printSourceMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
Definition: MemoryOps.cpp:106
ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Parses optional memory access (a.k.a. memory operand) attributes attached to a memory access operand/pointer.
Definition: MemoryOps.cpp:35
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Definition: MemoryOps.cpp:263
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:94
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, Value val)
Definition: MemoryOps.cpp:161
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp)
Definition: MemoryOps.cpp:176
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer)
Definition: MemoryOps.cpp:324
static void printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
Definition: MemoryOps.cpp:135
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:50
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices)
Definition: MemoryOps.cpp:330
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
This represents an operation in an abstracted form, suitable for use with the builder APIs.