MLIR  20.0.0git
MemoryOps.cpp
Go to the documentation of this file.
1 //===- MemoryOps.cpp - MLIR SPIR-V Memory Ops ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the memory operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 
16 #include "SPIRVOpUtils.h"
17 #include "SPIRVParsingUtils.h"
19 #include "mlir/IR/Diagnostics.h"
20 
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Casting.h"
23 
24 using namespace mlir::spirv::AttrNames;
25 
26 namespace mlir::spirv {
27 
28 /// Parses optional memory access (a.k.a. memory operand) attributes attached to
29 /// a memory access operand/pointer. Specifically, parses the following syntax:
30 /// (`[` memory-access `]`)?
31 /// where:
32 /// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
33 /// integer-literal | `"NonTemporal"`
34 template <typename MemoryOpTy>
36  OperationState &state) {
37  // Parse an optional list of attributes staring with '['
38  if (parser.parseOptionalLSquare()) {
39  // Nothing to do
40  return success();
41  }
42 
43  spirv::MemoryAccess memoryAccessAttr;
44  StringAttr memoryAccessAttrName =
45  MemoryOpTy::getMemoryAccessAttrName(state.name);
46  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
47  memoryAccessAttr, parser, state, memoryAccessAttrName))
48  return failure();
49 
50  if (spirv::bitEnumContainsAll(memoryAccessAttr,
51  spirv::MemoryAccess::Aligned)) {
52  // Parse integer attribute for alignment.
53  Attribute alignmentAttr;
54  StringAttr alignmentAttrName = MemoryOpTy::getAlignmentAttrName(state.name);
55  Type i32Type = parser.getBuilder().getIntegerType(32);
56  if (parser.parseComma() ||
57  parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
58  state.attributes)) {
59  return failure();
60  }
61  }
62  return parser.parseRSquare();
63 }
64 
65 // TODO Make sure to merge this and the previous function into one template
66 // parameterized by memory access attribute name and alignment. Doing so now
67 // results in VS2017 in producing an internal error (at the call site) that's
68 // not detailed enough to understand what is happening.
69 template <typename MemoryOpTy>
71  OperationState &state) {
72  // Parse an optional list of attributes staring with '['
73  if (parser.parseOptionalLSquare()) {
74  // Nothing to do
75  return success();
76  }
77 
78  spirv::MemoryAccess memoryAccessAttr;
79  StringRef memoryAccessAttrName =
80  MemoryOpTy::getSourceMemoryAccessAttrName(state.name);
81  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
82  memoryAccessAttr, parser, state, memoryAccessAttrName))
83  return failure();
84 
85  if (spirv::bitEnumContainsAll(memoryAccessAttr,
86  spirv::MemoryAccess::Aligned)) {
87  // Parse integer attribute for alignment.
88  Attribute alignmentAttr;
89  StringAttr alignmentAttrName =
90  MemoryOpTy::getSourceAlignmentAttrName(state.name);
91  Type i32Type = parser.getBuilder().getIntegerType(32);
92  if (parser.parseComma() ||
93  parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
94  state.attributes)) {
95  return failure();
96  }
97  }
98  return parser.parseRSquare();
99 }
100 
101 // TODO Make sure to merge this and the previous function into one template
102 // parameterized by memory access attribute name and alignment. Doing so now
103 // results in VS2017 in producing an internal error (at the call site) that's
104 // not detailed enough to understand what is happening.
105 template <typename MemoryOpTy>
107  MemoryOpTy memoryOp, OpAsmPrinter &printer,
108  SmallVectorImpl<StringRef> &elidedAttrs,
109  std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
110  std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
111 
112  printer << ", ";
113 
114  // Print optional memory access attribute.
115  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
116  : memoryOp.getMemoryAccess())) {
117  elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName());
118 
119  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
120 
121  if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
122  // Print integer alignment attribute.
123  if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
124  : memoryOp.getAlignment())) {
125  elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName());
126  printer << ", " << *alignment;
127  }
128  }
129  printer << "]";
130  }
131  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
132 }
133 
134 template <typename MemoryOpTy>
136  MemoryOpTy memoryOp, OpAsmPrinter &printer,
137  SmallVectorImpl<StringRef> &elidedAttrs,
138  std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
139  std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
140  // Print optional memory access attribute.
141  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
142  : memoryOp.getMemoryAccess())) {
143  elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName());
144 
145  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
146 
147  if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
148  // Print integer alignment attribute.
149  if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
150  : memoryOp.getAlignment())) {
151  elidedAttrs.push_back(memoryOp.getAlignmentAttrName());
152  printer << ", " << *alignment;
153  }
154  }
155  printer << "]";
156  }
157  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
158 }
159 
160 template <typename LoadStoreOpTy>
161 static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
162  Value val) {
163  // ODS already checks ptr is spirv::PointerType. Just check that the pointee
164  // type of the pointer and the type of the value are the same
165  //
166  // TODO: Check that the value type satisfies restrictions of
167  // SPIR-V OpLoad/OpStore operations
168  if (val.getType() !=
169  llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
170  return op.emitOpError("mismatch in result type and pointer type");
171  }
172  return success();
173 }
174 
175 template <typename MemoryOpTy>
176 static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
177  // ODS checks for attributes values. Just need to verify that if the
178  // memory-access attribute is Aligned, then the alignment attribute must be
179  // present.
180  auto *op = memoryOp.getOperation();
181  auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName());
182  if (!memAccessAttr) {
183  // Alignment attribute shouldn't be present if memory access attribute is
184  // not present.
185  if (op->getAttr(memoryOp.getAlignmentAttrName())) {
186  return memoryOp.emitOpError(
187  "invalid alignment specification without aligned memory access "
188  "specification");
189  }
190  return success();
191  }
192 
193  auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
194 
195  if (!memAccess) {
196  return memoryOp.emitOpError("invalid memory access specifier: ")
197  << memAccessAttr;
198  }
199 
200  if (spirv::bitEnumContainsAll(memAccess.getValue(),
201  spirv::MemoryAccess::Aligned)) {
202  if (!op->getAttr(memoryOp.getAlignmentAttrName())) {
203  return memoryOp.emitOpError("missing alignment value");
204  }
205  } else {
206  if (op->getAttr(memoryOp.getAlignmentAttrName())) {
207  return memoryOp.emitOpError(
208  "invalid alignment specification with non-aligned memory access "
209  "specification");
210  }
211  }
212  return success();
213 }
214 
215 // TODO Make sure to merge this and the previous function into one template
216 // parameterized by memory access attribute name and alignment. Doing so now
217 // results in VS2017 in producing an internal error (at the call site) that's
218 // not detailed enough to understand what is happening.
219 template <typename MemoryOpTy>
220 static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
221  // ODS checks for attributes values. Just need to verify that if the
222  // memory-access attribute is Aligned, then the alignment attribute must be
223  // present.
224  auto *op = memoryOp.getOperation();
225  auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName());
226  if (!memAccessAttr) {
227  // Alignment attribute shouldn't be present if memory access attribute is
228  // not present.
229  if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
230  return memoryOp.emitOpError(
231  "invalid alignment specification without aligned memory access "
232  "specification");
233  }
234  return success();
235  }
236 
237  auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
238 
239  if (!memAccess) {
240  return memoryOp.emitOpError("invalid memory access specifier: ")
241  << memAccess;
242  }
243 
244  if (spirv::bitEnumContainsAll(memAccess.getValue(),
245  spirv::MemoryAccess::Aligned)) {
246  if (!op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
247  return memoryOp.emitOpError("missing alignment value");
248  }
249  } else {
250  if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
251  return memoryOp.emitOpError(
252  "invalid alignment specification with non-aligned memory access "
253  "specification");
254  }
255  }
256  return success();
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // spirv.AccessChainOp
261 //===----------------------------------------------------------------------===//
262 
263 static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
264  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
265  if (!ptrType) {
266  emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
267  "to composite type, but provided ")
268  << type;
269  return nullptr;
270  }
271 
272  auto resultType = ptrType.getPointeeType();
273  auto resultStorageClass = ptrType.getStorageClass();
274  int32_t index = 0;
275 
276  for (auto indexSSA : indices) {
277  auto cType = llvm::dyn_cast<spirv::CompositeType>(resultType);
278  if (!cType) {
279  emitError(
280  baseLoc,
281  "'spirv.AccessChain' op cannot extract from non-composite type ")
282  << resultType << " with index " << index;
283  return nullptr;
284  }
285  index = 0;
286  if (llvm::isa<spirv::StructType>(resultType)) {
287  Operation *op = indexSSA.getDefiningOp();
288  if (!op) {
289  emitError(baseLoc, "'spirv.AccessChain' op index must be an "
290  "integer spirv.Constant to access "
291  "element of spirv.struct");
292  return nullptr;
293  }
294 
295  // TODO: this should be relaxed to allow
296  // integer literals of other bitwidths.
297  if (failed(spirv::extractValueFromConstOp(op, index))) {
298  emitError(
299  baseLoc,
300  "'spirv.AccessChain' index must be an integer spirv.Constant to "
301  "access element of spirv.struct, but provided ")
302  << op->getName();
303  return nullptr;
304  }
305  if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
306  emitError(baseLoc, "'spirv.AccessChain' op index ")
307  << index << " out of bounds for " << resultType;
308  return nullptr;
309  }
310  }
311  resultType = cType.getElementType(index);
312  }
313  return spirv::PointerType::get(resultType, resultStorageClass);
314 }
315 
316 void AccessChainOp::build(OpBuilder &builder, OperationState &state,
317  Value basePtr, ValueRange indices) {
318  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
319  assert(type && "Unable to deduce return type based on basePtr and indices");
320  build(builder, state, type, basePtr, indices);
321 }
322 
323 template <typename Op>
324 static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
325  printer << ' ' << op.getBasePtr() << '[' << indices
326  << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
327 }
328 
329 template <typename Op>
330 static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
331  auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
332  indices, accessChainOp.getLoc());
333  if (!resultType)
334  return failure();
335 
336  auto providedResultType =
337  llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType());
338  if (!providedResultType)
339  return accessChainOp.emitOpError(
340  "result type must be a pointer, but provided")
341  << providedResultType;
342 
343  if (resultType != providedResultType)
344  return accessChainOp.emitOpError("invalid result type: expected ")
345  << resultType << ", but provided " << providedResultType;
346 
347  return success();
348 }
349 
350 LogicalResult AccessChainOp::verify() {
351  return verifyAccessChain(*this, getIndices());
352 }
353 
354 //===----------------------------------------------------------------------===//
355 // spirv.LoadOp
356 //===----------------------------------------------------------------------===//
357 
358 void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr,
359  MemoryAccessAttr memoryAccess, IntegerAttr alignment) {
360  auto ptrType = llvm::cast<spirv::PointerType>(basePtr.getType());
361  build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
362  alignment);
363 }
364 
365 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
366  // Parse the storage class specification
367  spirv::StorageClass storageClass;
368  OpAsmParser::UnresolvedOperand ptrInfo;
369  Type elementType;
370  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
371  parseMemoryAccessAttributes<LoadOp>(parser, result) ||
372  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
373  parser.parseType(elementType)) {
374  return failure();
375  }
376 
377  auto ptrType = spirv::PointerType::get(elementType, storageClass);
378  if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
379  return failure();
380  }
381 
382  result.addTypes(elementType);
383  return success();
384 }
385 
386 void LoadOp::print(OpAsmPrinter &printer) {
387  SmallVector<StringRef, 4> elidedAttrs;
388  StringRef sc = stringifyStorageClass(
389  llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
390  printer << " \"" << sc << "\" " << getPtr();
391 
392  printMemoryAccessAttribute(*this, printer, elidedAttrs);
393 
394  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
395  printer << " : " << getType();
396 }
397 
398 LogicalResult LoadOp::verify() {
399  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
400  // type with fixed size; i.e., it cannot be, nor include, any
401  // OpTypeRuntimeArray types."
402  if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) {
403  return failure();
404  }
405  return verifyMemoryAccessAttribute(*this);
406 }
407 
408 //===----------------------------------------------------------------------===//
409 // spirv.StoreOp
410 //===----------------------------------------------------------------------===//
411 
412 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
413  // Parse the storage class specification
414  spirv::StorageClass storageClass;
415  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
416  auto loc = parser.getCurrentLocation();
417  Type elementType;
418  if (parseEnumStrAttr(storageClass, parser) ||
419  parser.parseOperandList(operandInfo, 2) ||
420  parseMemoryAccessAttributes<StoreOp>(parser, result) ||
421  parser.parseColon() || parser.parseType(elementType)) {
422  return failure();
423  }
424 
425  auto ptrType = spirv::PointerType::get(elementType, storageClass);
426  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
427  result.operands)) {
428  return failure();
429  }
430  return success();
431 }
432 
433 void StoreOp::print(OpAsmPrinter &printer) {
434  SmallVector<StringRef, 4> elidedAttrs;
435  StringRef sc = stringifyStorageClass(
436  llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
437  printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
438 
439  printMemoryAccessAttribute(*this, printer, elidedAttrs);
440 
441  printer << " : " << getValue().getType();
442  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
443 }
444 
445 LogicalResult StoreOp::verify() {
446  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
447  // OpTypePointer whose Type operand is the same as the type of Object."
448  if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
449  return failure();
450  return verifyMemoryAccessAttribute(*this);
451 }
452 
453 //===----------------------------------------------------------------------===//
454 // spirv.CopyMemory
455 //===----------------------------------------------------------------------===//
456 
457 void CopyMemoryOp::print(OpAsmPrinter &printer) {
458  printer << ' ';
459 
460  StringRef targetStorageClass = stringifyStorageClass(
461  llvm::cast<spirv::PointerType>(getTarget().getType()).getStorageClass());
462  printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
463 
464  StringRef sourceStorageClass = stringifyStorageClass(
465  llvm::cast<spirv::PointerType>(getSource().getType()).getStorageClass());
466  printer << " \"" << sourceStorageClass << "\" " << getSource();
467 
468  SmallVector<StringRef, 4> elidedAttrs;
469  printMemoryAccessAttribute(*this, printer, elidedAttrs);
470  printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
471  getSourceMemoryAccess(),
472  getSourceAlignment());
473 
474  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
475 
476  Type pointeeType =
477  llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
478  printer << " : " << pointeeType;
479 }
480 
481 ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
482  spirv::StorageClass targetStorageClass;
483  OpAsmParser::UnresolvedOperand targetPtrInfo;
484 
485  spirv::StorageClass sourceStorageClass;
486  OpAsmParser::UnresolvedOperand sourcePtrInfo;
487 
488  Type elementType;
489 
490  if (parseEnumStrAttr(targetStorageClass, parser) ||
491  parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
492  parseEnumStrAttr(sourceStorageClass, parser) ||
493  parser.parseOperand(sourcePtrInfo) ||
494  parseMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
495  return failure();
496  }
497 
498  if (!parser.parseOptionalComma()) {
499  // Parse 2nd memory access attributes.
500  if (parseSourceMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
501  return failure();
502  }
503  }
504 
505  if (parser.parseColon() || parser.parseType(elementType))
506  return failure();
507 
508  if (parser.parseOptionalAttrDict(result.attributes))
509  return failure();
510 
511  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
512  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
513 
514  if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
515  parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
516  return failure();
517  }
518 
519  return success();
520 }
521 
522 LogicalResult CopyMemoryOp::verify() {
523  Type targetType =
524  llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
525 
526  Type sourceType =
527  llvm::cast<spirv::PointerType>(getSource().getType()).getPointeeType();
528 
529  if (targetType != sourceType)
530  return emitOpError("both operands must be pointers to the same type");
531 
532  if (failed(verifyMemoryAccessAttribute(*this)))
533  return failure();
534 
535  // TODO - According to the spec:
536  //
537  // If two masks are present, the first applies to Target and cannot include
538  // MakePointerVisible, and the second applies to Source and cannot include
539  // MakePointerAvailable.
540  //
541  // Add such verification here.
542 
543  return verifySourceMemoryAccessAttribute(*this);
544 }
545 
546 static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
547  OpAsmParser &parser,
548  OperationState &state) {
551  Type type;
552  auto loc = parser.getCurrentLocation();
553  SmallVector<Type, 4> indicesTypes;
554 
555  if (parser.parseOperand(ptrInfo) ||
556  parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
557  parser.parseColonType(type) ||
558  parser.resolveOperand(ptrInfo, type, state.operands))
559  return failure();
560 
561  // Check that the provided indices list is not empty before parsing their
562  // type list.
563  if (indicesInfo.empty())
564  return emitError(state.location) << opName << " expected element";
565 
566  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
567  return failure();
568 
569  // Check that the indices types list is not empty and that it has a one-to-one
570  // mapping to the provided indices.
571  if (indicesTypes.size() != indicesInfo.size())
572  return emitError(state.location)
573  << opName
574  << " indices types' count must be equal to indices info count";
575 
576  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
577  return failure();
578 
579  auto resultType = getElementPtrType(
580  type, llvm::ArrayRef(state.operands).drop_front(2), state.location);
581  if (!resultType)
582  return failure();
583 
584  state.addTypes(resultType);
585  return success();
586 }
587 
588 template <typename Op>
589 static auto concatElemAndIndices(Op op) {
590  SmallVector<Value> ret(op.getIndices().size() + 1);
591  ret[0] = op.getElement();
592  llvm::copy(op.getIndices(), ret.begin() + 1);
593  return ret;
594 }
595 
596 //===----------------------------------------------------------------------===//
597 // spirv.InBoundsPtrAccessChainOp
598 //===----------------------------------------------------------------------===//
599 
600 void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
601  Value basePtr, Value element,
602  ValueRange indices) {
603  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
604  assert(type && "Unable to deduce return type based on basePtr and indices");
605  build(builder, state, type, basePtr, element, indices);
606 }
607 
608 ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
609  OperationState &result) {
611  spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
612 }
613 
614 void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
615  printAccessChain(*this, concatElemAndIndices(*this), printer);
616 }
617 
618 LogicalResult InBoundsPtrAccessChainOp::verify() {
619  return verifyAccessChain(*this, getIndices());
620 }
621 
622 //===----------------------------------------------------------------------===//
623 // spirv.PtrAccessChainOp
624 //===----------------------------------------------------------------------===//
625 
626 void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
627  Value basePtr, Value element, ValueRange indices) {
628  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
629  assert(type && "Unable to deduce return type based on basePtr and indices");
630  build(builder, state, type, basePtr, element, indices);
631 }
632 
633 ParseResult PtrAccessChainOp::parse(OpAsmParser &parser,
634  OperationState &result) {
635  return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
636  parser, result);
637 }
638 
639 void PtrAccessChainOp::print(OpAsmPrinter &printer) {
640  printAccessChain(*this, concatElemAndIndices(*this), printer);
641 }
642 
643 LogicalResult PtrAccessChainOp::verify() {
644  return verifyAccessChain(*this, getIndices());
645 }
646 
647 //===----------------------------------------------------------------------===//
648 // spirv.Variable
649 //===----------------------------------------------------------------------===//
650 
651 ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
652  // Parse optional initializer
653  std::optional<OpAsmParser::UnresolvedOperand> initInfo;
654  if (succeeded(parser.parseOptionalKeyword("init"))) {
655  initInfo = OpAsmParser::UnresolvedOperand();
656  if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
657  parser.parseRParen())
658  return failure();
659  }
660 
661  if (parseVariableDecorations(parser, result)) {
662  return failure();
663  }
664 
665  // Parse result pointer type
666  Type type;
667  if (parser.parseColon())
668  return failure();
669  auto loc = parser.getCurrentLocation();
670  if (parser.parseType(type))
671  return failure();
672 
673  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
674  if (!ptrType)
675  return parser.emitError(loc, "expected spirv.ptr type");
676  result.addTypes(ptrType);
677 
678  // Resolve the initializer operand
679  if (initInfo) {
680  if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
681  result.operands))
682  return failure();
683  }
684 
685  auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
686  ptrType.getStorageClass());
687  result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
688 
689  return success();
690 }
691 
692 void VariableOp::print(OpAsmPrinter &printer) {
693  SmallVector<StringRef, 4> elidedAttrs{
694  spirv::attributeName<spirv::StorageClass>()};
695  // Print optional initializer
696  if (getNumOperands() != 0)
697  printer << " init(" << getInitializer() << ")";
698 
699  printVariableDecorations(*this, printer, elidedAttrs);
700  printer << " : " << getType();
701 }
702 
703 LogicalResult VariableOp::verify() {
704  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
705  // object. It cannot be Generic. It must be the same as the Storage Class
706  // operand of the Result Type."
707  if (getStorageClass() != spirv::StorageClass::Function) {
708  return emitOpError(
709  "can only be used to model function-level variables. Use "
710  "spirv.GlobalVariable for module-level variables.");
711  }
712 
713  auto pointerType = llvm::cast<spirv::PointerType>(getPointer().getType());
714  if (getStorageClass() != pointerType.getStorageClass())
715  return emitOpError(
716  "storage class must match result pointer's storage class");
717 
718  if (getNumOperands() != 0) {
719  // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
720  // a global (module scope) OpVariable instruction".
721  auto *initOp = getOperand(0).getDefiningOp();
722  if (!initOp || !isa<spirv::ConstantOp, // for normal constant
723  spirv::ReferenceOfOp, // for spec constant
724  spirv::AddressOfOp>(initOp))
725  return emitOpError("initializer must be the result of a "
726  "constant or spirv.GlobalVariable op");
727  }
728 
729  auto getDecorationAttr = [op = getOperation()](spirv::Decoration decoration) {
730  return op->getAttr(
731  llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)));
732  };
733 
734  // TODO: generate these strings using ODS.
735  for (auto decoration :
736  {spirv::Decoration::DescriptorSet, spirv::Decoration::Binding,
737  spirv::Decoration::BuiltIn}) {
738  if (auto attr = getDecorationAttr(decoration))
739  return emitOpError("cannot have '")
740  << llvm::convertToSnakeFromCamelCase(
741  stringifyDecoration(decoration))
742  << "' attribute (only allowed in spirv.GlobalVariable)";
743  }
744 
745  // From SPV_KHR_physical_storage_buffer:
746  // > If an OpVariable's pointee type is a pointer (or array of pointers) in
747  // > PhysicalStorageBuffer storage class, then the variable must be decorated
748  // > with exactly one of AliasedPointer or RestrictPointer.
749  auto pointeePtrType = dyn_cast<spirv::PointerType>(getPointeeType());
750  if (!pointeePtrType) {
751  if (auto pointeeArrayType = dyn_cast<spirv::ArrayType>(getPointeeType())) {
752  pointeePtrType =
753  dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
754  }
755  }
756 
757  if (pointeePtrType && pointeePtrType.getStorageClass() ==
758  spirv::StorageClass::PhysicalStorageBuffer) {
759  bool hasAliasedPtr =
760  getDecorationAttr(spirv::Decoration::AliasedPointer) != nullptr;
761  bool hasRestrictPtr =
762  getDecorationAttr(spirv::Decoration::RestrictPointer) != nullptr;
763 
764  if (!hasAliasedPtr && !hasRestrictPtr)
765  return emitOpError() << " with physical buffer pointer must be decorated "
766  "either 'AliasedPointer' or 'RestrictPointer'";
767 
768  if (hasAliasedPtr && hasRestrictPtr)
769  return emitOpError()
770  << " with physical buffer pointer must have exactly one "
771  "aliasing decoration";
772  }
773 
774  return success();
775 }
776 
777 } // namespace mlir::spirv
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Square
Square brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:77
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
Definition: Builders.h:215
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:832
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:125
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:406
@ Type
An inlay hint for a type annotation.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Definition: MemoryOps.cpp:70
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)
Definition: MemoryOps.cpp:220
static ParseResult parsePtrAccessChainOpImpl(StringRef opName, OpAsmParser &parser, OperationState &state)
Definition: MemoryOps.cpp:546
static void printSourceMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
Definition: MemoryOps.cpp:106
static auto concatElemAndIndices(Op op)
Definition: MemoryOps.cpp:589
ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Parses optional memory access (a.k.a.
Definition: MemoryOps.cpp:35
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Definition: MemoryOps.cpp:263
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:94
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, Value val)
Definition: MemoryOps.cpp:161
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp)
Definition: MemoryOps.cpp:176
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer)
Definition: MemoryOps.cpp:324
static void printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
Definition: MemoryOps.cpp:135
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:50
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices)
Definition: MemoryOps.cpp:330
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.