MLIR  19.0.0git
MemoryOps.cpp
Go to the documentation of this file.
1 //===- MemoryOps.cpp - MLIR SPIR-V Memory Ops ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the memory operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 
16 #include "SPIRVOpUtils.h"
17 #include "SPIRVParsingUtils.h"
19 #include "mlir/IR/Diagnostics.h"
20 
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Casting.h"
23 
24 using namespace mlir::spirv::AttrNames;
25 
26 namespace mlir::spirv {
27 
28 /// Parses optional memory access (a.k.a. memory operand) attributes attached to
29 /// a memory access operand/pointer. Specifically, parses the following syntax:
30 /// (`[` memory-access `]`)?
31 /// where:
32 /// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
33 /// integer-literal | `"NonTemporal"`
34 template <typename MemoryOpTy>
36  OperationState &state) {
37  // Parse an optional list of attributes staring with '['
38  if (parser.parseOptionalLSquare()) {
39  // Nothing to do
40  return success();
41  }
42 
43  spirv::MemoryAccess memoryAccessAttr;
44  StringAttr memoryAccessAttrName =
45  MemoryOpTy::getMemoryAccessAttrName(state.name);
46  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
47  memoryAccessAttr, parser, state, memoryAccessAttrName))
48  return failure();
49 
50  if (spirv::bitEnumContainsAll(memoryAccessAttr,
51  spirv::MemoryAccess::Aligned)) {
52  // Parse integer attribute for alignment.
53  Attribute alignmentAttr;
54  StringAttr alignmentAttrName = MemoryOpTy::getAlignmentAttrName(state.name);
55  Type i32Type = parser.getBuilder().getIntegerType(32);
56  if (parser.parseComma() ||
57  parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
58  state.attributes)) {
59  return failure();
60  }
61  }
62  return parser.parseRSquare();
63 }
64 
65 // TODO Make sure to merge this and the previous function into one template
66 // parameterized by memory access attribute name and alignment. Doing so now
67 // results in VS2017 in producing an internal error (at the call site) that's
68 // not detailed enough to understand what is happening.
69 template <typename MemoryOpTy>
71  OperationState &state) {
72  // Parse an optional list of attributes staring with '['
73  if (parser.parseOptionalLSquare()) {
74  // Nothing to do
75  return success();
76  }
77 
78  spirv::MemoryAccess memoryAccessAttr;
79  StringRef memoryAccessAttrName =
80  MemoryOpTy::getSourceMemoryAccessAttrName(state.name);
81  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
82  memoryAccessAttr, parser, state, memoryAccessAttrName))
83  return failure();
84 
85  if (spirv::bitEnumContainsAll(memoryAccessAttr,
86  spirv::MemoryAccess::Aligned)) {
87  // Parse integer attribute for alignment.
88  Attribute alignmentAttr;
89  StringAttr alignmentAttrName =
90  MemoryOpTy::getSourceAlignmentAttrName(state.name);
91  Type i32Type = parser.getBuilder().getIntegerType(32);
92  if (parser.parseComma() ||
93  parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
94  state.attributes)) {
95  return failure();
96  }
97  }
98  return parser.parseRSquare();
99 }
100 
101 // TODO Make sure to merge this and the previous function into one template
102 // parameterized by memory access attribute name and alignment. Doing so now
103 // results in VS2017 in producing an internal error (at the call site) that's
104 // not detailed enough to understand what is happening.
105 template <typename MemoryOpTy>
107  MemoryOpTy memoryOp, OpAsmPrinter &printer,
108  SmallVectorImpl<StringRef> &elidedAttrs,
109  std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
110  std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
111 
112  printer << ", ";
113 
114  // Print optional memory access attribute.
115  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
116  : memoryOp.getMemoryAccess())) {
117  elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName());
118 
119  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
120 
121  if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
122  // Print integer alignment attribute.
123  if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
124  : memoryOp.getAlignment())) {
125  elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName());
126  printer << ", " << *alignment;
127  }
128  }
129  printer << "]";
130  }
131  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
132 }
133 
134 template <typename MemoryOpTy>
136  MemoryOpTy memoryOp, OpAsmPrinter &printer,
137  SmallVectorImpl<StringRef> &elidedAttrs,
138  std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
139  std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
140  // Print optional memory access attribute.
141  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
142  : memoryOp.getMemoryAccess())) {
143  elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName());
144 
145  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
146 
147  if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
148  // Print integer alignment attribute.
149  if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
150  : memoryOp.getAlignment())) {
151  elidedAttrs.push_back(memoryOp.getAlignmentAttrName());
152  printer << ", " << *alignment;
153  }
154  }
155  printer << "]";
156  }
157  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
158 }
159 
160 template <typename LoadStoreOpTy>
161 static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
162  Value val) {
163  // ODS already checks ptr is spirv::PointerType. Just check that the pointee
164  // type of the pointer and the type of the value are the same
165  //
166  // TODO: Check that the value type satisfies restrictions of
167  // SPIR-V OpLoad/OpStore operations
168  if (val.getType() !=
169  llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
170  return op.emitOpError("mismatch in result type and pointer type");
171  }
172  return success();
173 }
174 
175 template <typename MemoryOpTy>
176 static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
177  // ODS checks for attributes values. Just need to verify that if the
178  // memory-access attribute is Aligned, then the alignment attribute must be
179  // present.
180  auto *op = memoryOp.getOperation();
181  auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName());
182  if (!memAccessAttr) {
183  // Alignment attribute shouldn't be present if memory access attribute is
184  // not present.
185  if (op->getAttr(memoryOp.getAlignmentAttrName())) {
186  return memoryOp.emitOpError(
187  "invalid alignment specification without aligned memory access "
188  "specification");
189  }
190  return success();
191  }
192 
193  auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
194 
195  if (!memAccess) {
196  return memoryOp.emitOpError("invalid memory access specifier: ")
197  << memAccessAttr;
198  }
199 
200  if (spirv::bitEnumContainsAll(memAccess.getValue(),
201  spirv::MemoryAccess::Aligned)) {
202  if (!op->getAttr(memoryOp.getAlignmentAttrName())) {
203  return memoryOp.emitOpError("missing alignment value");
204  }
205  } else {
206  if (op->getAttr(memoryOp.getAlignmentAttrName())) {
207  return memoryOp.emitOpError(
208  "invalid alignment specification with non-aligned memory access "
209  "specification");
210  }
211  }
212  return success();
213 }
214 
215 // TODO Make sure to merge this and the previous function into one template
216 // parameterized by memory access attribute name and alignment. Doing so now
217 // results in VS2017 in producing an internal error (at the call site) that's
218 // not detailed enough to understand what is happening.
219 template <typename MemoryOpTy>
221  // ODS checks for attributes values. Just need to verify that if the
222  // memory-access attribute is Aligned, then the alignment attribute must be
223  // present.
224  auto *op = memoryOp.getOperation();
225  auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName());
226  if (!memAccessAttr) {
227  // Alignment attribute shouldn't be present if memory access attribute is
228  // not present.
229  if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
230  return memoryOp.emitOpError(
231  "invalid alignment specification without aligned memory access "
232  "specification");
233  }
234  return success();
235  }
236 
237  auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
238 
239  if (!memAccess) {
240  return memoryOp.emitOpError("invalid memory access specifier: ")
241  << memAccess;
242  }
243 
244  if (spirv::bitEnumContainsAll(memAccess.getValue(),
245  spirv::MemoryAccess::Aligned)) {
246  if (!op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
247  return memoryOp.emitOpError("missing alignment value");
248  }
249  } else {
250  if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
251  return memoryOp.emitOpError(
252  "invalid alignment specification with non-aligned memory access "
253  "specification");
254  }
255  }
256  return success();
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // spirv.AccessChainOp
261 //===----------------------------------------------------------------------===//
262 
263 static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
264  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
265  if (!ptrType) {
266  emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
267  "to composite type, but provided ")
268  << type;
269  return nullptr;
270  }
271 
272  auto resultType = ptrType.getPointeeType();
273  auto resultStorageClass = ptrType.getStorageClass();
274  int32_t index = 0;
275 
276  for (auto indexSSA : indices) {
277  auto cType = llvm::dyn_cast<spirv::CompositeType>(resultType);
278  if (!cType) {
279  emitError(
280  baseLoc,
281  "'spirv.AccessChain' op cannot extract from non-composite type ")
282  << resultType << " with index " << index;
283  return nullptr;
284  }
285  index = 0;
286  if (llvm::isa<spirv::StructType>(resultType)) {
287  Operation *op = indexSSA.getDefiningOp();
288  if (!op) {
289  emitError(baseLoc, "'spirv.AccessChain' op index must be an "
290  "integer spirv.Constant to access "
291  "element of spirv.struct");
292  return nullptr;
293  }
294 
295  // TODO: this should be relaxed to allow
296  // integer literals of other bitwidths.
297  if (failed(spirv::extractValueFromConstOp(op, index))) {
298  emitError(
299  baseLoc,
300  "'spirv.AccessChain' index must be an integer spirv.Constant to "
301  "access element of spirv.struct, but provided ")
302  << op->getName();
303  return nullptr;
304  }
305  if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
306  emitError(baseLoc, "'spirv.AccessChain' op index ")
307  << index << " out of bounds for " << resultType;
308  return nullptr;
309  }
310  }
311  resultType = cType.getElementType(index);
312  }
313  return spirv::PointerType::get(resultType, resultStorageClass);
314 }
315 
316 void AccessChainOp::build(OpBuilder &builder, OperationState &state,
317  Value basePtr, ValueRange indices) {
318  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
319  assert(type && "Unable to deduce return type based on basePtr and indices");
320  build(builder, state, type, basePtr, indices);
321 }
322 
323 ParseResult AccessChainOp::parse(OpAsmParser &parser, OperationState &result) {
324  OpAsmParser::UnresolvedOperand ptrInfo;
325  SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
326  Type type;
327  auto loc = parser.getCurrentLocation();
328  SmallVector<Type, 4> indicesTypes;
329 
330  if (parser.parseOperand(ptrInfo) ||
331  parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
332  parser.parseColonType(type) ||
333  parser.resolveOperand(ptrInfo, type, result.operands)) {
334  return failure();
335  }
336 
337  // Check that the provided indices list is not empty before parsing their
338  // type list.
339  if (indicesInfo.empty()) {
340  return mlir::emitError(result.location,
341  "'spirv.AccessChain' op expected at "
342  "least one index ");
343  }
344 
345  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
346  return failure();
347 
348  // Check that the indices types list is not empty and that it has a one-to-one
349  // mapping to the provided indices.
350  if (indicesTypes.size() != indicesInfo.size()) {
351  return mlir::emitError(
352  result.location, "'spirv.AccessChain' op indices types' count must be "
353  "equal to indices info count");
354  }
355 
356  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands))
357  return failure();
358 
359  auto resultType = getElementPtrType(
360  type, llvm::ArrayRef(result.operands).drop_front(), result.location);
361  if (!resultType) {
362  return failure();
363  }
364 
365  result.addTypes(resultType);
366  return success();
367 }
368 
369 template <typename Op>
370 static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
371  printer << ' ' << op.getBasePtr() << '[' << indices
372  << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
373 }
374 
376  printAccessChain(*this, getIndices(), printer);
377 }
378 
379 template <typename Op>
380 static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
381  auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
382  indices, accessChainOp.getLoc());
383  if (!resultType)
384  return failure();
385 
386  auto providedResultType =
387  llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType());
388  if (!providedResultType)
389  return accessChainOp.emitOpError(
390  "result type must be a pointer, but provided")
391  << providedResultType;
392 
393  if (resultType != providedResultType)
394  return accessChainOp.emitOpError("invalid result type: expected ")
395  << resultType << ", but provided " << providedResultType;
396 
397  return success();
398 }
399 
401  return verifyAccessChain(*this, getIndices());
402 }
403 
404 //===----------------------------------------------------------------------===//
405 // spirv.LoadOp
406 //===----------------------------------------------------------------------===//
407 
408 void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr,
409  MemoryAccessAttr memoryAccess, IntegerAttr alignment) {
410  auto ptrType = llvm::cast<spirv::PointerType>(basePtr.getType());
411  build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
412  alignment);
413 }
414 
415 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
416  // Parse the storage class specification
417  spirv::StorageClass storageClass;
418  OpAsmParser::UnresolvedOperand ptrInfo;
419  Type elementType;
420  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
421  parseMemoryAccessAttributes<LoadOp>(parser, result) ||
422  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
423  parser.parseType(elementType)) {
424  return failure();
425  }
426 
427  auto ptrType = spirv::PointerType::get(elementType, storageClass);
428  if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
429  return failure();
430  }
431 
432  result.addTypes(elementType);
433  return success();
434 }
435 
436 void LoadOp::print(OpAsmPrinter &printer) {
437  SmallVector<StringRef, 4> elidedAttrs;
438  StringRef sc = stringifyStorageClass(
439  llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
440  printer << " \"" << sc << "\" " << getPtr();
441 
442  printMemoryAccessAttribute(*this, printer, elidedAttrs);
443 
444  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
445  printer << " : " << getType();
446 }
447 
448 LogicalResult LoadOp::verify() {
449  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
450  // type with fixed size; i.e., it cannot be, nor include, any
451  // OpTypeRuntimeArray types."
452  if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) {
453  return failure();
454  }
455  return verifyMemoryAccessAttribute(*this);
456 }
457 
458 //===----------------------------------------------------------------------===//
459 // spirv.StoreOp
460 //===----------------------------------------------------------------------===//
461 
462 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
463  // Parse the storage class specification
464  spirv::StorageClass storageClass;
465  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
466  auto loc = parser.getCurrentLocation();
467  Type elementType;
468  if (parseEnumStrAttr(storageClass, parser) ||
469  parser.parseOperandList(operandInfo, 2) ||
470  parseMemoryAccessAttributes<StoreOp>(parser, result) ||
471  parser.parseColon() || parser.parseType(elementType)) {
472  return failure();
473  }
474 
475  auto ptrType = spirv::PointerType::get(elementType, storageClass);
476  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
477  result.operands)) {
478  return failure();
479  }
480  return success();
481 }
482 
483 void StoreOp::print(OpAsmPrinter &printer) {
484  SmallVector<StringRef, 4> elidedAttrs;
485  StringRef sc = stringifyStorageClass(
486  llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
487  printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
488 
489  printMemoryAccessAttribute(*this, printer, elidedAttrs);
490 
491  printer << " : " << getValue().getType();
492  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
493 }
494 
495 LogicalResult StoreOp::verify() {
496  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
497  // OpTypePointer whose Type operand is the same as the type of Object."
498  if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
499  return failure();
500  return verifyMemoryAccessAttribute(*this);
501 }
502 
503 //===----------------------------------------------------------------------===//
504 // spirv.CopyMemory
505 //===----------------------------------------------------------------------===//
506 
507 void CopyMemoryOp::print(OpAsmPrinter &printer) {
508  printer << ' ';
509 
510  StringRef targetStorageClass = stringifyStorageClass(
511  llvm::cast<spirv::PointerType>(getTarget().getType()).getStorageClass());
512  printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
513 
514  StringRef sourceStorageClass = stringifyStorageClass(
515  llvm::cast<spirv::PointerType>(getSource().getType()).getStorageClass());
516  printer << " \"" << sourceStorageClass << "\" " << getSource();
517 
518  SmallVector<StringRef, 4> elidedAttrs;
519  printMemoryAccessAttribute(*this, printer, elidedAttrs);
520  printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
521  getSourceMemoryAccess(),
522  getSourceAlignment());
523 
524  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
525 
526  Type pointeeType =
527  llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
528  printer << " : " << pointeeType;
529 }
530 
531 ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
532  spirv::StorageClass targetStorageClass;
533  OpAsmParser::UnresolvedOperand targetPtrInfo;
534 
535  spirv::StorageClass sourceStorageClass;
536  OpAsmParser::UnresolvedOperand sourcePtrInfo;
537 
538  Type elementType;
539 
540  if (parseEnumStrAttr(targetStorageClass, parser) ||
541  parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
542  parseEnumStrAttr(sourceStorageClass, parser) ||
543  parser.parseOperand(sourcePtrInfo) ||
544  parseMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
545  return failure();
546  }
547 
548  if (!parser.parseOptionalComma()) {
549  // Parse 2nd memory access attributes.
550  if (parseSourceMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
551  return failure();
552  }
553  }
554 
555  if (parser.parseColon() || parser.parseType(elementType))
556  return failure();
557 
558  if (parser.parseOptionalAttrDict(result.attributes))
559  return failure();
560 
561  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
562  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
563 
564  if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
565  parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
566  return failure();
567  }
568 
569  return success();
570 }
571 
572 LogicalResult CopyMemoryOp::verify() {
573  Type targetType =
574  llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
575 
576  Type sourceType =
577  llvm::cast<spirv::PointerType>(getSource().getType()).getPointeeType();
578 
579  if (targetType != sourceType)
580  return emitOpError("both operands must be pointers to the same type");
581 
583  return failure();
584 
585  // TODO - According to the spec:
586  //
587  // If two masks are present, the first applies to Target and cannot include
588  // MakePointerVisible, and the second applies to Source and cannot include
589  // MakePointerAvailable.
590  //
591  // Add such verification here.
592 
593  return verifySourceMemoryAccessAttribute(*this);
594 }
595 
596 static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
597  OpAsmParser &parser,
598  OperationState &state) {
601  Type type;
602  auto loc = parser.getCurrentLocation();
603  SmallVector<Type, 4> indicesTypes;
604 
605  if (parser.parseOperand(ptrInfo) ||
606  parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
607  parser.parseColonType(type) ||
608  parser.resolveOperand(ptrInfo, type, state.operands))
609  return failure();
610 
611  // Check that the provided indices list is not empty before parsing their
612  // type list.
613  if (indicesInfo.empty())
614  return emitError(state.location) << opName << " expected element";
615 
616  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
617  return failure();
618 
619  // Check that the indices types list is not empty and that it has a one-to-one
620  // mapping to the provided indices.
621  if (indicesTypes.size() != indicesInfo.size())
622  return emitError(state.location)
623  << opName
624  << " indices types' count must be equal to indices info count";
625 
626  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
627  return failure();
628 
629  auto resultType = getElementPtrType(
630  type, llvm::ArrayRef(state.operands).drop_front(2), state.location);
631  if (!resultType)
632  return failure();
633 
634  state.addTypes(resultType);
635  return success();
636 }
637 
638 template <typename Op>
639 static auto concatElemAndIndices(Op op) {
640  SmallVector<Value> ret(op.getIndices().size() + 1);
641  ret[0] = op.getElement();
642  llvm::copy(op.getIndices(), ret.begin() + 1);
643  return ret;
644 }
645 
646 //===----------------------------------------------------------------------===//
647 // spirv.InBoundsPtrAccessChainOp
648 //===----------------------------------------------------------------------===//
649 
650 void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
651  Value basePtr, Value element,
652  ValueRange indices) {
653  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
654  assert(type && "Unable to deduce return type based on basePtr and indices");
655  build(builder, state, type, basePtr, element, indices);
656 }
657 
658 ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
659  OperationState &result) {
661  spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
662 }
663 
664 void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
665  printAccessChain(*this, concatElemAndIndices(*this), printer);
666 }
667 
668 LogicalResult InBoundsPtrAccessChainOp::verify() {
669  return verifyAccessChain(*this, getIndices());
670 }
671 
672 //===----------------------------------------------------------------------===//
673 // spirv.PtrAccessChainOp
674 //===----------------------------------------------------------------------===//
675 
676 void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
677  Value basePtr, Value element, ValueRange indices) {
678  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
679  assert(type && "Unable to deduce return type based on basePtr and indices");
680  build(builder, state, type, basePtr, element, indices);
681 }
682 
683 ParseResult PtrAccessChainOp::parse(OpAsmParser &parser,
684  OperationState &result) {
685  return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
686  parser, result);
687 }
688 
689 void PtrAccessChainOp::print(OpAsmPrinter &printer) {
690  printAccessChain(*this, concatElemAndIndices(*this), printer);
691 }
692 
693 LogicalResult PtrAccessChainOp::verify() {
694  return verifyAccessChain(*this, getIndices());
695 }
696 
697 //===----------------------------------------------------------------------===//
698 // spirv.Variable
699 //===----------------------------------------------------------------------===//
700 
701 ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
702  // Parse optional initializer
703  std::optional<OpAsmParser::UnresolvedOperand> initInfo;
704  if (succeeded(parser.parseOptionalKeyword("init"))) {
705  initInfo = OpAsmParser::UnresolvedOperand();
706  if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
707  parser.parseRParen())
708  return failure();
709  }
710 
711  if (parseVariableDecorations(parser, result)) {
712  return failure();
713  }
714 
715  // Parse result pointer type
716  Type type;
717  if (parser.parseColon())
718  return failure();
719  auto loc = parser.getCurrentLocation();
720  if (parser.parseType(type))
721  return failure();
722 
723  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
724  if (!ptrType)
725  return parser.emitError(loc, "expected spirv.ptr type");
726  result.addTypes(ptrType);
727 
728  // Resolve the initializer operand
729  if (initInfo) {
730  if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
731  result.operands))
732  return failure();
733  }
734 
735  auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
736  ptrType.getStorageClass());
737  result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
738 
739  return success();
740 }
741 
742 void VariableOp::print(OpAsmPrinter &printer) {
743  SmallVector<StringRef, 4> elidedAttrs{
744  spirv::attributeName<spirv::StorageClass>()};
745  // Print optional initializer
746  if (getNumOperands() != 0)
747  printer << " init(" << getInitializer() << ")";
748 
749  printVariableDecorations(*this, printer, elidedAttrs);
750  printer << " : " << getType();
751 }
752 
753 LogicalResult VariableOp::verify() {
754  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
755  // object. It cannot be Generic. It must be the same as the Storage Class
756  // operand of the Result Type."
757  if (getStorageClass() != spirv::StorageClass::Function) {
758  return emitOpError(
759  "can only be used to model function-level variables. Use "
760  "spirv.GlobalVariable for module-level variables.");
761  }
762 
763  auto pointerType = llvm::cast<spirv::PointerType>(getPointer().getType());
764  if (getStorageClass() != pointerType.getStorageClass())
765  return emitOpError(
766  "storage class must match result pointer's storage class");
767 
768  if (getNumOperands() != 0) {
769  // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
770  // a global (module scope) OpVariable instruction".
771  auto *initOp = getOperand(0).getDefiningOp();
772  if (!initOp || !isa<spirv::ConstantOp, // for normal constant
773  spirv::ReferenceOfOp, // for spec constant
774  spirv::AddressOfOp>(initOp))
775  return emitOpError("initializer must be the result of a "
776  "constant or spirv.GlobalVariable op");
777  }
778 
779  auto getDecorationAttr = [op = getOperation()](spirv::Decoration decoration) {
780  return op->getAttr(
781  llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)));
782  };
783 
784  // TODO: generate these strings using ODS.
785  for (auto decoration :
786  {spirv::Decoration::DescriptorSet, spirv::Decoration::Binding,
787  spirv::Decoration::BuiltIn}) {
788  if (auto attr = getDecorationAttr(decoration))
789  return emitOpError("cannot have '")
790  << llvm::convertToSnakeFromCamelCase(
791  stringifyDecoration(decoration))
792  << "' attribute (only allowed in spirv.GlobalVariable)";
793  }
794 
795  // From SPV_KHR_physical_storage_buffer:
796  // > If an OpVariable's pointee type is a pointer (or array of pointers) in
797  // > PhysicalStorageBuffer storage class, then the variable must be decorated
798  // > with exactly one of AliasedPointer or RestrictPointer.
799  auto pointeePtrType = dyn_cast<spirv::PointerType>(getPointeeType());
800  if (!pointeePtrType) {
801  if (auto pointeeArrayType = dyn_cast<spirv::ArrayType>(getPointeeType())) {
802  pointeePtrType =
803  dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
804  }
805  }
806 
807  if (pointeePtrType && pointeePtrType.getStorageClass() ==
808  spirv::StorageClass::PhysicalStorageBuffer) {
809  bool hasAliasedPtr =
810  getDecorationAttr(spirv::Decoration::AliasedPointer) != nullptr;
811  bool hasRestrictPtr =
812  getDecorationAttr(spirv::Decoration::RestrictPointer) != nullptr;
813 
814  if (!hasAliasedPtr && !hasRestrictPtr)
815  return emitOpError() << " with physical buffer pointer must be decorated "
816  "either 'AliasedPointer' or 'RestrictPointer'";
817 
818  if (hasAliasedPtr && hasRestrictPtr)
819  return emitOpError()
820  << " with physical buffer pointer must have exactly one "
821  "aliasing decoration";
822  }
823 
824  return success();
825 }
826 
827 } // namespace mlir::spirv
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Square
Square brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:76
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
Definition: Builders.h:209
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:830
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:125
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents success/failure for parsing-like operations that find it important to chain tog...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:481
@ Type
An inlay hint for a type annotation.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Definition: MemoryOps.cpp:70
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)
Definition: MemoryOps.cpp:220
static ParseResult parsePtrAccessChainOpImpl(StringRef opName, OpAsmParser &parser, OperationState &state)
Definition: MemoryOps.cpp:596
static void printSourceMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
Definition: MemoryOps.cpp:106
static auto concatElemAndIndices(Op op)
Definition: MemoryOps.cpp:639
ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Parses optional memory access (a.k.a. memory operand) attributes attached to a memory access operand/pointer.
Definition: MemoryOps.cpp:35
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Definition: MemoryOps.cpp:263
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:95
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, Value val)
Definition: MemoryOps.cpp:161
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp)
Definition: MemoryOps.cpp:176
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer)
Definition: MemoryOps.cpp:370
static void printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
Definition: MemoryOps.cpp:135
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:51
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices)
Definition: MemoryOps.cpp:380
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.