21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Casting.h"
34 template <
typename MemoryOpTy>
43 spirv::MemoryAccess memoryAccessAttr;
44 StringAttr memoryAccessAttrName =
45 MemoryOpTy::getMemoryAccessAttrName(state.name);
46 if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
47 memoryAccessAttr, parser, state, memoryAccessAttrName))
50 if (spirv::bitEnumContainsAll(memoryAccessAttr,
51 spirv::MemoryAccess::Aligned)) {
54 StringAttr alignmentAttrName = MemoryOpTy::getAlignmentAttrName(state.name);
69 template <
typename MemoryOpTy>
78 spirv::MemoryAccess memoryAccessAttr;
79 StringRef memoryAccessAttrName =
80 MemoryOpTy::getSourceMemoryAccessAttrName(state.name);
81 if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
82 memoryAccessAttr, parser, state, memoryAccessAttrName))
85 if (spirv::bitEnumContainsAll(memoryAccessAttr,
86 spirv::MemoryAccess::Aligned)) {
89 StringAttr alignmentAttrName =
90 MemoryOpTy::getSourceAlignmentAttrName(state.name);
105 template <
typename MemoryOpTy>
109 std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
110 std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
115 if (
auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
116 : memoryOp.getMemoryAccess())) {
117 elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName());
119 printer <<
" [\"" << stringifyMemoryAccess(*memAccess) <<
"\"";
121 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
123 if (
auto alignment = (alignmentAttrValue ? alignmentAttrValue
124 : memoryOp.getAlignment())) {
125 elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName());
126 printer <<
", " << *alignment;
131 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
134 template <
typename MemoryOpTy>
138 std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
139 std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
141 if (
auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
142 : memoryOp.getMemoryAccess())) {
143 elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName());
145 printer <<
" [\"" << stringifyMemoryAccess(*memAccess) <<
"\"";
147 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
149 if (
auto alignment = (alignmentAttrValue ? alignmentAttrValue
150 : memoryOp.getAlignment())) {
151 elidedAttrs.push_back(memoryOp.getAlignmentAttrName());
152 printer <<
", " << *alignment;
157 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
160 template <
typename LoadStoreOpTy>
169 llvm::cast<spirv::PointerType>(ptr.
getType()).getPointeeType()) {
170 return op.emitOpError(
"mismatch in result type and pointer type");
175 template <
typename MemoryOpTy>
180 auto *op = memoryOp.getOperation();
181 auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName());
182 if (!memAccessAttr) {
185 if (op->getAttr(memoryOp.getAlignmentAttrName())) {
186 return memoryOp.emitOpError(
187 "invalid alignment specification without aligned memory access "
193 auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
196 return memoryOp.emitOpError(
"invalid memory access specifier: ")
200 if (spirv::bitEnumContainsAll(memAccess.getValue(),
201 spirv::MemoryAccess::Aligned)) {
202 if (!op->getAttr(memoryOp.getAlignmentAttrName())) {
203 return memoryOp.emitOpError(
"missing alignment value");
206 if (op->getAttr(memoryOp.getAlignmentAttrName())) {
207 return memoryOp.emitOpError(
208 "invalid alignment specification with non-aligned memory access "
219 template <
typename MemoryOpTy>
224 auto *op = memoryOp.getOperation();
225 auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName());
226 if (!memAccessAttr) {
229 if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
230 return memoryOp.emitOpError(
231 "invalid alignment specification without aligned memory access "
237 auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
240 return memoryOp.emitOpError(
"invalid memory access specifier: ")
244 if (spirv::bitEnumContainsAll(memAccess.getValue(),
245 spirv::MemoryAccess::Aligned)) {
246 if (!op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
247 return memoryOp.emitOpError(
"missing alignment value");
250 if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
251 return memoryOp.emitOpError(
252 "invalid alignment specification with non-aligned memory access "
264 auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
266 emitError(baseLoc,
"'spirv.AccessChain' op expected a pointer "
267 "to composite type, but provided ")
272 auto resultType = ptrType.getPointeeType();
273 auto resultStorageClass = ptrType.getStorageClass();
276 for (
auto indexSSA : indices) {
277 auto cType = llvm::dyn_cast<spirv::CompositeType>(resultType);
281 "'spirv.AccessChain' op cannot extract from non-composite type ")
282 << resultType <<
" with index " << index;
286 if (llvm::isa<spirv::StructType>(resultType)) {
287 Operation *op = indexSSA.getDefiningOp();
289 emitError(baseLoc,
"'spirv.AccessChain' op index must be an "
290 "integer spirv.Constant to access "
291 "element of spirv.struct");
300 "'spirv.AccessChain' index must be an integer spirv.Constant to "
301 "access element of spirv.struct, but provided ")
305 if (index < 0 ||
static_cast<uint64_t
>(index) >= cType.getNumElements()) {
306 emitError(baseLoc,
"'spirv.AccessChain' op index ")
307 << index <<
" out of bounds for " << resultType;
311 resultType = cType.getElementType(index);
319 assert(type &&
"Unable to deduce return type based on basePtr and indices");
320 build(builder, state, type, basePtr, indices);
323 template <
typename Op>
325 printer <<
' ' << op.getBasePtr() <<
'[' << indices
326 <<
"] : " << op.getBasePtr().getType() <<
", " << indices.
getTypes();
329 template <
typename Op>
332 indices, accessChainOp.
getLoc());
336 auto providedResultType =
337 llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType());
338 if (!providedResultType)
340 "result type must be a pointer, but provided")
341 << providedResultType;
343 if (resultType != providedResultType)
344 return accessChainOp.
emitOpError(
"invalid result type: expected ")
345 << resultType <<
", but provided " << providedResultType;
358 void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr,
359 MemoryAccessAttr memoryAccess, IntegerAttr alignment) {
360 auto ptrType = llvm::cast<spirv::PointerType>(basePtr.getType());
361 build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
365 ParseResult
LoadOp::parse(OpAsmParser &parser, OperationState &result) {
367 spirv::StorageClass storageClass;
368 OpAsmParser::UnresolvedOperand ptrInfo;
370 if (
parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
371 parseMemoryAccessAttributes<LoadOp>(parser, result) ||
372 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
373 parser.parseType(elementType)) {
378 if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
382 result.addTypes(elementType);
387 SmallVector<StringRef, 4> elidedAttrs;
388 StringRef sc = stringifyStorageClass(
389 llvm::cast<spirv::PointerType>(getPtr().
getType()).getStorageClass());
390 printer <<
" \"" << sc <<
"\" " << getPtr();
394 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
412 ParseResult
StoreOp::parse(OpAsmParser &parser, OperationState &result) {
414 spirv::StorageClass storageClass;
415 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
416 auto loc = parser.getCurrentLocation();
419 parser.parseOperandList(operandInfo, 2) ||
420 parseMemoryAccessAttributes<StoreOp>(parser, result) ||
421 parser.parseColon() || parser.parseType(elementType)) {
426 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
434 SmallVector<StringRef, 4> elidedAttrs;
435 StringRef sc = stringifyStorageClass(
436 llvm::cast<spirv::PointerType>(getPtr().
getType()).getStorageClass());
437 printer <<
" \"" << sc <<
"\" " << getPtr() <<
", " << getValue();
441 printer <<
" : " << getValue().getType();
442 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
460 StringRef targetStorageClass = stringifyStorageClass(
461 llvm::cast<spirv::PointerType>(getTarget().
getType()).getStorageClass());
462 printer <<
" \"" << targetStorageClass <<
"\" " << getTarget() <<
", ";
464 StringRef sourceStorageClass = stringifyStorageClass(
465 llvm::cast<spirv::PointerType>(getSource().
getType()).getStorageClass());
466 printer <<
" \"" << sourceStorageClass <<
"\" " << getSource();
468 SmallVector<StringRef, 4> elidedAttrs;
471 getSourceMemoryAccess(),
472 getSourceAlignment());
474 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
477 llvm::cast<spirv::PointerType>(getTarget().
getType()).getPointeeType();
478 printer <<
" : " << pointeeType;
482 spirv::StorageClass targetStorageClass;
483 OpAsmParser::UnresolvedOperand targetPtrInfo;
485 spirv::StorageClass sourceStorageClass;
486 OpAsmParser::UnresolvedOperand sourcePtrInfo;
491 parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
493 parser.parseOperand(sourcePtrInfo) ||
494 parseMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
498 if (!parser.parseOptionalComma()) {
500 if (parseSourceMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
505 if (parser.parseColon() || parser.parseType(elementType))
508 if (parser.parseOptionalAttrDict(result.attributes))
514 if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
515 parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
524 llvm::cast<spirv::PointerType>(getTarget().
getType()).getPointeeType();
527 llvm::cast<spirv::PointerType>(getSource().
getType()).getPointeeType();
529 if (targetType != sourceType)
530 return emitOpError(
"both operands must be pointers to the same type");
550 void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
551 Value basePtr, Value element,
552 ValueRange indices) {
554 assert(type &&
"Unable to deduce return type based on basePtr and indices");
555 build(builder, state, type, basePtr, element, indices);
566 void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
567 Value basePtr, Value element, ValueRange indices) {
569 assert(type &&
"Unable to deduce return type based on basePtr and indices");
570 build(builder, state, type, basePtr, element, indices);
583 std::optional<OpAsmParser::UnresolvedOperand> initInfo;
584 if (succeeded(parser.parseOptionalKeyword(
"init"))) {
585 initInfo = OpAsmParser::UnresolvedOperand();
586 if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
587 parser.parseRParen())
597 if (parser.parseColon())
599 auto loc = parser.getCurrentLocation();
600 if (parser.parseType(type))
603 auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
605 return parser.emitError(loc,
"expected spirv.ptr type");
606 result.addTypes(ptrType);
610 if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
615 auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
616 ptrType.getStorageClass());
617 result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
623 SmallVector<StringRef, 4> elidedAttrs{
624 spirv::attributeName<spirv::StorageClass>()};
626 if (getNumOperands() != 0)
627 printer <<
" init(" << getInitializer() <<
")";
637 if (getStorageClass() != spirv::StorageClass::Function) {
639 "can only be used to model function-level variables. Use "
640 "spirv.GlobalVariable for module-level variables.");
643 auto pointerType = llvm::cast<spirv::PointerType>(getPointer().
getType());
644 if (getStorageClass() != pointerType.getStorageClass())
646 "storage class must match result pointer's storage class");
648 if (getNumOperands() != 0) {
651 auto *initOp = getOperand(0).getDefiningOp();
652 if (!initOp || !isa<spirv::ConstantOp,
653 spirv::ReferenceOfOp,
654 spirv::AddressOfOp>(initOp))
655 return emitOpError(
"initializer must be the result of a "
656 "constant or spirv.GlobalVariable op");
659 auto getDecorationAttr = [op = getOperation()](spirv::Decoration decoration) {
661 llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)));
665 for (
auto decoration :
666 {spirv::Decoration::DescriptorSet, spirv::Decoration::Binding,
667 spirv::Decoration::BuiltIn}) {
668 if (
auto attr = getDecorationAttr(decoration))
669 return emitOpError(
"cannot have '")
670 << llvm::convertToSnakeFromCamelCase(
671 stringifyDecoration(decoration))
672 <<
"' attribute (only allowed in spirv.GlobalVariable)";
679 auto pointeePtrType = dyn_cast<spirv::PointerType>(getPointeeType());
680 if (!pointeePtrType) {
681 if (
auto pointeeArrayType = dyn_cast<spirv::ArrayType>(getPointeeType())) {
683 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
687 if (pointeePtrType && pointeePtrType.getStorageClass() ==
688 spirv::StorageClass::PhysicalStorageBuffer) {
690 getDecorationAttr(spirv::Decoration::AliasedPointer) !=
nullptr;
691 bool hasRestrictPtr =
692 getDecorationAttr(spirv::Decoration::RestrictPointer) !=
nullptr;
694 if (!hasAliasedPtr && !hasRestrictPtr)
695 return emitOpError() <<
" with physical buffer pointer must be decorated "
696 "either 'AliasedPointer' or 'RestrictPointer'";
698 if (hasAliasedPtr && hasRestrictPtr)
700 <<
" with physical buffer pointer must have exactly one "
701 "aliasing decoration";
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Attributes are known-constant values of operations.
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static PointerType get(Type pointeeType, StorageClass storageClass)
@ Type
An inlay hint for a type annotation.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)
static void printSourceMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Parses optional memory access (a.k.a.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, Value val)
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp)
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer)
static void printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.