MLIR 22.0.0git
MemoryOps.cpp
Go to the documentation of this file.
1//===- MemoryOps.cpp - MLIR SPIR-V Memory Ops ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Defines the memory operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
15
16#include "SPIRVOpUtils.h"
17#include "SPIRVParsingUtils.h"
19#include "mlir/IR/Diagnostics.h"
20
21#include "llvm/ADT/StringExtras.h"
22#include "llvm/Support/Casting.h"
23
24using namespace mlir::spirv::AttrNames;
25
26namespace mlir::spirv {
27
28/// Parses optional memory access (a.k.a. memory operand) attributes attached to
29/// a memory access operand/pointer. Specifically, parses the following syntax:
30/// (`[` memory-access `]`)?
31/// where:
32/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
33/// integer-literal | `"NonTemporal"`
34template <typename MemoryOpTy>
36 OperationState &state) {
37 // Parse an optional list of attributes staring with '['
38 if (parser.parseOptionalLSquare()) {
39 // Nothing to do
40 return success();
41 }
42
43 spirv::MemoryAccess memoryAccessAttr;
44 StringAttr memoryAccessAttrName =
45 MemoryOpTy::getMemoryAccessAttrName(state.name);
47 memoryAccessAttr, parser, state, memoryAccessAttrName))
48 return failure();
49
50 if (spirv::bitEnumContainsAll(memoryAccessAttr,
51 spirv::MemoryAccess::Aligned)) {
52 // Parse integer attribute for alignment.
53 Attribute alignmentAttr;
54 StringAttr alignmentAttrName = MemoryOpTy::getAlignmentAttrName(state.name);
55 Type i32Type = parser.getBuilder().getIntegerType(32);
56 if (parser.parseComma() ||
57 parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
58 state.attributes)) {
59 return failure();
60 }
61 }
62 return parser.parseRSquare();
63}
64
// TODO: Merge this and the previous function into one template parameterized
// by the memory access attribute name and alignment. Doing so currently causes
// VS2017 to produce an internal error (at the call site) that is not detailed
// enough to understand what is happening.
69template <typename MemoryOpTy>
71 OperationState &state) {
72 // Parse an optional list of attributes staring with '['
73 if (parser.parseOptionalLSquare()) {
74 // Nothing to do
75 return success();
76 }
77
78 spirv::MemoryAccess memoryAccessAttr;
79 StringRef memoryAccessAttrName =
80 MemoryOpTy::getSourceMemoryAccessAttrName(state.name);
82 memoryAccessAttr, parser, state, memoryAccessAttrName))
83 return failure();
84
85 if (spirv::bitEnumContainsAll(memoryAccessAttr,
86 spirv::MemoryAccess::Aligned)) {
87 // Parse integer attribute for alignment.
88 Attribute alignmentAttr;
89 StringAttr alignmentAttrName =
90 MemoryOpTy::getSourceAlignmentAttrName(state.name);
91 Type i32Type = parser.getBuilder().getIntegerType(32);
92 if (parser.parseComma() ||
93 parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
94 state.attributes)) {
95 return failure();
96 }
97 }
98 return parser.parseRSquare();
99}
100
// TODO: Merge this and the following function (printMemoryAccessAttribute) into
// one template parameterized by the memory access attribute name and alignment.
// Doing so currently causes VS2017 to produce an internal error (at the call
// site) that is not detailed enough to understand what is happening.
105template <typename MemoryOpTy>
107 MemoryOpTy memoryOp, OpAsmPrinter &printer,
108 SmallVectorImpl<StringRef> &elidedAttrs,
109 std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
110 std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
111
112 printer << ", ";
113
114 // Print optional memory access attribute.
115 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
116 : memoryOp.getMemoryAccess())) {
117 elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName());
118
119 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
120
121 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
122 // Print integer alignment attribute.
123 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
124 : memoryOp.getAlignment())) {
125 elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName());
126 printer << ", " << *alignment;
127 }
128 }
129 printer << "]";
130 }
131 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
132}
133
134template <typename MemoryOpTy>
136 MemoryOpTy memoryOp, OpAsmPrinter &printer,
137 SmallVectorImpl<StringRef> &elidedAttrs,
138 std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
139 std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
140 // Print optional memory access attribute.
141 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
142 : memoryOp.getMemoryAccess())) {
143 elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName());
144
145 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
146
147 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
148 // Print integer alignment attribute.
149 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
150 : memoryOp.getAlignment())) {
151 elidedAttrs.push_back(memoryOp.getAlignmentAttrName());
152 printer << ", " << *alignment;
153 }
154 }
155 printer << "]";
156 }
157 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
158}
159
160template <typename LoadStoreOpTy>
161static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
162 Value val) {
163 // ODS already checks ptr is spirv::PointerType. Just check that the pointee
164 // type of the pointer and the type of the value are the same
165 //
166 // TODO: Check that the value type satisfies restrictions of
167 // SPIR-V OpLoad/OpStore operations
168 if (val.getType() !=
169 llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
170 return op.emitOpError("mismatch in result type and pointer type");
171 }
172 return success();
173}
174
175template <typename MemoryOpTy>
176static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
177 // ODS checks for attributes values. Just need to verify that if the
178 // memory-access attribute is Aligned, then the alignment attribute must be
179 // present.
180 auto *op = memoryOp.getOperation();
181 auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName());
182 if (!memAccessAttr) {
183 // Alignment attribute shouldn't be present if memory access attribute is
184 // not present.
185 if (op->getAttr(memoryOp.getAlignmentAttrName())) {
186 return memoryOp.emitOpError(
187 "invalid alignment specification without aligned memory access "
188 "specification");
189 }
190 return success();
191 }
192
193 auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
194
195 if (!memAccess) {
196 return memoryOp.emitOpError("invalid memory access specifier: ")
197 << memAccessAttr;
198 }
199
200 if (spirv::bitEnumContainsAll(memAccess.getValue(),
201 spirv::MemoryAccess::Aligned)) {
202 if (!op->getAttr(memoryOp.getAlignmentAttrName())) {
203 return memoryOp.emitOpError("missing alignment value");
204 }
205 } else {
206 if (op->getAttr(memoryOp.getAlignmentAttrName())) {
207 return memoryOp.emitOpError(
208 "invalid alignment specification with non-aligned memory access "
209 "specification");
210 }
211 }
212 return success();
213}
214
// TODO: Merge this and the previous function into one template parameterized
// by the memory access attribute name and alignment. Doing so currently causes
// VS2017 to produce an internal error (at the call site) that is not detailed
// enough to understand what is happening.
219template <typename MemoryOpTy>
220static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
221 // ODS checks for attributes values. Just need to verify that if the
222 // memory-access attribute is Aligned, then the alignment attribute must be
223 // present.
224 auto *op = memoryOp.getOperation();
225 auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName());
226 if (!memAccessAttr) {
227 // Alignment attribute shouldn't be present if memory access attribute is
228 // not present.
229 if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
230 return memoryOp.emitOpError(
231 "invalid alignment specification without aligned memory access "
232 "specification");
233 }
234 return success();
235 }
236
237 auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
238
239 if (!memAccess) {
240 return memoryOp.emitOpError("invalid memory access specifier: ")
241 << memAccess;
242 }
243
244 if (spirv::bitEnumContainsAll(memAccess.getValue(),
245 spirv::MemoryAccess::Aligned)) {
246 if (!op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
247 return memoryOp.emitOpError("missing alignment value");
248 }
249 } else {
250 if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
251 return memoryOp.emitOpError(
252 "invalid alignment specification with non-aligned memory access "
253 "specification");
254 }
255 }
256 return success();
257}
258
259//===----------------------------------------------------------------------===//
260// spirv.AccessChainOp
261//===----------------------------------------------------------------------===//
262
264 auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
265 if (!ptrType) {
266 emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
267 "to composite type, but provided ")
268 << type;
269 return nullptr;
270 }
271
272 auto resultType = ptrType.getPointeeType();
273 auto resultStorageClass = ptrType.getStorageClass();
274 int32_t index = 0;
275
276 for (auto indexSSA : indices) {
277 auto cType = llvm::dyn_cast<spirv::CompositeType>(resultType);
278 if (!cType) {
279 emitError(
280 baseLoc,
281 "'spirv.AccessChain' op cannot extract from non-composite type ")
282 << resultType << " with index " << index;
283 return nullptr;
284 }
285 index = 0;
286 if (llvm::isa<spirv::StructType>(resultType)) {
287 Operation *op = indexSSA.getDefiningOp();
288 if (!op) {
289 emitError(baseLoc, "'spirv.AccessChain' op index must be an "
290 "integer spirv.Constant to access "
291 "element of spirv.struct");
292 return nullptr;
293 }
294
295 // TODO: this should be relaxed to allow
296 // integer literals of other bitwidths.
297 if (failed(spirv::extractValueFromConstOp(op, index))) {
298 emitError(
299 baseLoc,
300 "'spirv.AccessChain' index must be an integer spirv.Constant to "
301 "access element of spirv.struct, but provided ")
302 << op->getName();
303 return nullptr;
304 }
305 if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
306 emitError(baseLoc, "'spirv.AccessChain' op index ")
307 << index << " out of bounds for " << resultType;
308 return nullptr;
309 }
310 }
311 resultType = cType.getElementType(index);
312 }
313 return spirv::PointerType::get(resultType, resultStorageClass);
314}
315
316void AccessChainOp::build(OpBuilder &builder, OperationState &state,
317 Value basePtr, ValueRange indices) {
318 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
319 assert(type && "Unable to deduce return type based on basePtr and indices");
320 build(builder, state, type, basePtr, indices);
321}
322
323template <typename Op>
325 printer << ' ' << op.getBasePtr() << '[' << indices
326 << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
327}
328
329template <typename Op>
330static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
331 auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
332 indices, accessChainOp.getLoc());
333 if (!resultType)
334 return failure();
335
336 auto providedResultType =
337 llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType());
338 if (!providedResultType)
339 return accessChainOp.emitOpError(
340 "result type must be a pointer, but provided")
341 << providedResultType;
342
343 if (resultType != providedResultType)
344 return accessChainOp.emitOpError("invalid result type: expected ")
345 << resultType << ", but provided " << providedResultType;
346
347 return success();
348}
349
350LogicalResult AccessChainOp::verify() {
351 return verifyAccessChain(*this, getIndices());
352}
353
354//===----------------------------------------------------------------------===//
355// spirv.LoadOp
356//===----------------------------------------------------------------------===//
357
358void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr,
359 MemoryAccessAttr memoryAccess, IntegerAttr alignment) {
360 auto ptrType = llvm::cast<spirv::PointerType>(basePtr.getType());
361 build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
362 alignment);
363}
364
365ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
366 // Parse the storage class specification
367 spirv::StorageClass storageClass;
368 OpAsmParser::UnresolvedOperand ptrInfo;
369 Type elementType;
370 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
372 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
373 parser.parseType(elementType)) {
374 return failure();
375 }
376
377 auto ptrType = spirv::PointerType::get(elementType, storageClass);
378 if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
379 return failure();
380 }
381
382 result.addTypes(elementType);
383 return success();
384}
385
386void LoadOp::print(OpAsmPrinter &printer) {
387 SmallVector<StringRef, 4> elidedAttrs;
388 StringRef sc = stringifyStorageClass(
389 llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
390 printer << " \"" << sc << "\" " << getPtr();
391
392 printMemoryAccessAttribute(*this, printer, elidedAttrs);
393
394 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
395 printer << " : " << getType();
396}
397
398LogicalResult LoadOp::verify() {
399 // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
400 // type with fixed size; i.e., it cannot be, nor include, any
401 // OpTypeRuntimeArray types."
402 if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) {
403 return failure();
404 }
405 return verifyMemoryAccessAttribute(*this);
406}
407
408//===----------------------------------------------------------------------===//
409// spirv.StoreOp
410//===----------------------------------------------------------------------===//
411
412ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
413 // Parse the storage class specification
414 spirv::StorageClass storageClass;
415 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
416 auto loc = parser.getCurrentLocation();
417 Type elementType;
418 if (parseEnumStrAttr(storageClass, parser) ||
419 parser.parseOperandList(operandInfo, 2) ||
421 parser.parseColon() || parser.parseType(elementType)) {
422 return failure();
423 }
424
425 auto ptrType = spirv::PointerType::get(elementType, storageClass);
426 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
427 result.operands)) {
428 return failure();
429 }
430 return success();
431}
432
433void StoreOp::print(OpAsmPrinter &printer) {
434 SmallVector<StringRef, 4> elidedAttrs;
435 StringRef sc = stringifyStorageClass(
436 llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
437 printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
438
439 printMemoryAccessAttribute(*this, printer, elidedAttrs);
440
441 printer << " : " << getValue().getType();
442 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
443}
444
445LogicalResult StoreOp::verify() {
446 // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
447 // OpTypePointer whose Type operand is the same as the type of Object."
448 if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
449 return failure();
450 return verifyMemoryAccessAttribute(*this);
451}
452
453//===----------------------------------------------------------------------===//
454// spirv.CopyMemory
455//===----------------------------------------------------------------------===//
456
457void CopyMemoryOp::print(OpAsmPrinter &printer) {
458 printer << ' ';
459
460 StringRef targetStorageClass = stringifyStorageClass(
461 llvm::cast<spirv::PointerType>(getTarget().getType()).getStorageClass());
462 printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
463
464 StringRef sourceStorageClass = stringifyStorageClass(
465 llvm::cast<spirv::PointerType>(getSource().getType()).getStorageClass());
466 printer << " \"" << sourceStorageClass << "\" " << getSource();
467
468 SmallVector<StringRef, 4> elidedAttrs;
469 printMemoryAccessAttribute(*this, printer, elidedAttrs);
470 printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
471 getSourceMemoryAccess(),
472 getSourceAlignment());
473
474 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
475
476 Type pointeeType =
477 llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
478 printer << " : " << pointeeType;
479}
480
481ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
482 spirv::StorageClass targetStorageClass;
483 OpAsmParser::UnresolvedOperand targetPtrInfo;
484
485 spirv::StorageClass sourceStorageClass;
486 OpAsmParser::UnresolvedOperand sourcePtrInfo;
487
488 Type elementType;
489
490 if (parseEnumStrAttr(targetStorageClass, parser) ||
491 parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
492 parseEnumStrAttr(sourceStorageClass, parser) ||
493 parser.parseOperand(sourcePtrInfo) ||
495 return failure();
496 }
497
498 if (!parser.parseOptionalComma()) {
499 // Parse 2nd memory access attributes.
501 return failure();
502 }
503 }
504
505 if (parser.parseColon() || parser.parseType(elementType))
506 return failure();
507
508 if (parser.parseOptionalAttrDict(result.attributes))
509 return failure();
510
511 auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
512 auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
513
514 if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
515 parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
516 return failure();
517 }
518
519 return success();
520}
521
522LogicalResult CopyMemoryOp::verify() {
523 Type targetType =
524 llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
525
526 Type sourceType =
527 llvm::cast<spirv::PointerType>(getSource().getType()).getPointeeType();
528
529 if (targetType != sourceType)
530 return emitOpError("both operands must be pointers to the same type");
531
533 return failure();
534
535 // TODO - According to the spec:
536 //
537 // If two masks are present, the first applies to Target and cannot include
538 // MakePointerVisible, and the second applies to Source and cannot include
539 // MakePointerAvailable.
540 //
541 // Add such verification here.
542
544}
545
546//===----------------------------------------------------------------------===//
547// spirv.InBoundsPtrAccessChainOp
548//===----------------------------------------------------------------------===//
549
550void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
551 Value basePtr, Value element,
553 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
554 assert(type && "Unable to deduce return type based on basePtr and indices");
555 build(builder, state, type, basePtr, element, indices);
556}
557
558LogicalResult InBoundsPtrAccessChainOp::verify() {
559 return verifyAccessChain(*this, getIndices());
560}
561
562//===----------------------------------------------------------------------===//
563// spirv.PtrAccessChainOp
564//===----------------------------------------------------------------------===//
565
566void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
567 Value basePtr, Value element, ValueRange indices) {
568 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
569 assert(type && "Unable to deduce return type based on basePtr and indices");
570 build(builder, state, type, basePtr, element, indices);
571}
572
573LogicalResult PtrAccessChainOp::verify() {
574 return verifyAccessChain(*this, getIndices());
575}
576
577//===----------------------------------------------------------------------===//
578// spirv.Variable
579//===----------------------------------------------------------------------===//
580
581ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
582 // Parse optional initializer
583 std::optional<OpAsmParser::UnresolvedOperand> initInfo;
584 if (succeeded(parser.parseOptionalKeyword("init"))) {
585 initInfo = OpAsmParser::UnresolvedOperand();
586 if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
587 parser.parseRParen())
588 return failure();
589 }
590
591 if (parseVariableDecorations(parser, result)) {
592 return failure();
593 }
594
595 // Parse result pointer type
596 Type type;
597 if (parser.parseColon())
598 return failure();
599 auto loc = parser.getCurrentLocation();
600 if (parser.parseType(type))
601 return failure();
602
603 auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
604 if (!ptrType)
605 return parser.emitError(loc, "expected spirv.ptr type");
606 result.addTypes(ptrType);
607
608 // Resolve the initializer operand
609 if (initInfo) {
610 if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
611 result.operands))
612 return failure();
613 }
614
615 auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
616 ptrType.getStorageClass());
618
619 return success();
620}
621
622void VariableOp::print(OpAsmPrinter &printer) {
623 SmallVector<StringRef, 4> elidedAttrs{
625 // Print optional initializer
626 if (getNumOperands() != 0)
627 printer << " init(" << getInitializer() << ")";
628
629 printVariableDecorations(*this, printer, elidedAttrs);
630 printer << " : " << getType();
631}
632
633LogicalResult VariableOp::verify() {
634 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
635 // object. It cannot be Generic. It must be the same as the Storage Class
636 // operand of the Result Type."
637 if (getStorageClass() != spirv::StorageClass::Function) {
638 return emitOpError(
639 "can only be used to model function-level variables. Use "
640 "spirv.GlobalVariable for module-level variables.");
641 }
642
643 auto pointerType = llvm::cast<spirv::PointerType>(getPointer().getType());
644 if (getStorageClass() != pointerType.getStorageClass())
645 return emitOpError(
646 "storage class must match result pointer's storage class");
647
648 if (getNumOperands() != 0) {
649 // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
650 // a global (module scope) OpVariable instruction".
651 auto *initOp = getOperand(0).getDefiningOp();
652 if (!initOp || !isa<spirv::ConstantOp, // for normal constant
653 spirv::ReferenceOfOp, // for spec constant
654 spirv::AddressOfOp>(initOp))
655 return emitOpError("initializer must be the result of a "
656 "constant or spirv.GlobalVariable op");
657 }
658
659 auto getDecorationAttr = [op = getOperation()](spirv::Decoration decoration) {
660 return op->getAttr(
661 llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)));
662 };
663
664 // TODO: generate these strings using ODS.
665 for (auto decoration :
666 {spirv::Decoration::DescriptorSet, spirv::Decoration::Binding,
667 spirv::Decoration::BuiltIn}) {
668 if (auto attr = getDecorationAttr(decoration))
669 return emitOpError("cannot have '")
670 << llvm::convertToSnakeFromCamelCase(
671 stringifyDecoration(decoration))
672 << "' attribute (only allowed in spirv.GlobalVariable)";
673 }
674
675 // From SPV_KHR_physical_storage_buffer:
676 // > If an OpVariable's pointee type is a pointer (or array of pointers) in
677 // > PhysicalStorageBuffer storage class, then the variable must be decorated
678 // > with exactly one of AliasedPointer or RestrictPointer.
679 auto pointeePtrType = dyn_cast<spirv::PointerType>(getPointeeType());
680 if (!pointeePtrType) {
681 if (auto pointeeArrayType = dyn_cast<spirv::ArrayType>(getPointeeType())) {
682 pointeePtrType =
683 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
684 }
685 }
686
687 if (pointeePtrType && pointeePtrType.getStorageClass() ==
688 spirv::StorageClass::PhysicalStorageBuffer) {
689 bool hasAliasedPtr =
690 getDecorationAttr(spirv::Decoration::AliasedPointer) != nullptr;
691 bool hasRestrictPtr =
692 getDecorationAttr(spirv::Decoration::RestrictPointer) != nullptr;
693
694 if (!hasAliasedPtr && !hasRestrictPtr)
695 return emitOpError() << " with physical buffer pointer must be decorated "
696 "either 'AliasedPointer' or 'RestrictPointer'";
697
698 if (hasAliasedPtr && hasRestrictPtr)
699 return emitOpError()
700 << " with physical buffer pointer must have exactly one "
701 "aliasing decoration";
702 }
703
704 return success();
705}
706
707} // namespace mlir::spirv
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Attributes are known-constant values of operations.
Definition Attributes.h:25
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:98
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition Builders.h:207
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static PointerType get(Type pointeeType, StorageClass storageClass)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Definition MemoryOps.cpp:70
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)
static void printSourceMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Parses optional memory access (a.k.a.
Definition MemoryOps.cpp:35
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition SPIRVOps.cpp:93
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, Value val)
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp)
constexpr StringRef attributeName()
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer)
static void printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition SPIRVOps.cpp:49
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This represents an operation in an abstracted form, suitable for use with the builder APIs.