21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Casting.h"
34 template <
typename MemoryOpTy>
43 spirv::MemoryAccess memoryAccessAttr;
44 StringAttr memoryAccessAttrName =
45 MemoryOpTy::getMemoryAccessAttrName(state.name);
46 if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
47 memoryAccessAttr, parser, state, memoryAccessAttrName))
50 if (spirv::bitEnumContainsAll(memoryAccessAttr,
51 spirv::MemoryAccess::Aligned)) {
54 StringAttr alignmentAttrName = MemoryOpTy::getAlignmentAttrName(state.name);
69 template <
typename MemoryOpTy>
78 spirv::MemoryAccess memoryAccessAttr;
79 StringRef memoryAccessAttrName =
80 MemoryOpTy::getSourceMemoryAccessAttrName(state.name);
81 if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
82 memoryAccessAttr, parser, state, memoryAccessAttrName))
85 if (spirv::bitEnumContainsAll(memoryAccessAttr,
86 spirv::MemoryAccess::Aligned)) {
89 StringAttr alignmentAttrName =
90 MemoryOpTy::getSourceAlignmentAttrName(state.name);
105 template <
typename MemoryOpTy>
109 std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
110 std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
115 if (
auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
116 : memoryOp.getMemoryAccess())) {
117 elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName());
119 printer <<
" [\"" << stringifyMemoryAccess(*memAccess) <<
"\"";
121 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
123 if (
auto alignment = (alignmentAttrValue ? alignmentAttrValue
124 : memoryOp.getAlignment())) {
125 elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName());
126 printer <<
", " << *alignment;
131 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
134 template <
typename MemoryOpTy>
138 std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
139 std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
141 if (
auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
142 : memoryOp.getMemoryAccess())) {
143 elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName());
145 printer <<
" [\"" << stringifyMemoryAccess(*memAccess) <<
"\"";
147 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
149 if (
auto alignment = (alignmentAttrValue ? alignmentAttrValue
150 : memoryOp.getAlignment())) {
151 elidedAttrs.push_back(memoryOp.getAlignmentAttrName());
152 printer <<
", " << *alignment;
157 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
160 template <
typename LoadStoreOpTy>
169 llvm::cast<spirv::PointerType>(ptr.
getType()).getPointeeType()) {
170 return op.emitOpError(
"mismatch in result type and pointer type");
175 template <
typename MemoryOpTy>
180 auto *op = memoryOp.getOperation();
181 auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName());
182 if (!memAccessAttr) {
185 if (op->getAttr(memoryOp.getAlignmentAttrName())) {
186 return memoryOp.emitOpError(
187 "invalid alignment specification without aligned memory access "
193 auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
196 return memoryOp.emitOpError(
"invalid memory access specifier: ")
200 if (spirv::bitEnumContainsAll(memAccess.getValue(),
201 spirv::MemoryAccess::Aligned)) {
202 if (!op->getAttr(memoryOp.getAlignmentAttrName())) {
203 return memoryOp.emitOpError(
"missing alignment value");
206 if (op->getAttr(memoryOp.getAlignmentAttrName())) {
207 return memoryOp.emitOpError(
208 "invalid alignment specification with non-aligned memory access "
219 template <
typename MemoryOpTy>
224 auto *op = memoryOp.getOperation();
225 auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName());
226 if (!memAccessAttr) {
229 if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
230 return memoryOp.emitOpError(
231 "invalid alignment specification without aligned memory access "
237 auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
240 return memoryOp.emitOpError(
"invalid memory access specifier: ")
244 if (spirv::bitEnumContainsAll(memAccess.getValue(),
245 spirv::MemoryAccess::Aligned)) {
246 if (!op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
247 return memoryOp.emitOpError(
"missing alignment value");
250 if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
251 return memoryOp.emitOpError(
252 "invalid alignment specification with non-aligned memory access "
264 auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
266 emitError(baseLoc,
"'spirv.AccessChain' op expected a pointer "
267 "to composite type, but provided ")
272 auto resultType = ptrType.getPointeeType();
273 auto resultStorageClass = ptrType.getStorageClass();
276 for (
auto indexSSA : indices) {
277 auto cType = llvm::dyn_cast<spirv::CompositeType>(resultType);
281 "'spirv.AccessChain' op cannot extract from non-composite type ")
282 << resultType <<
" with index " << index;
286 if (llvm::isa<spirv::StructType>(resultType)) {
287 Operation *op = indexSSA.getDefiningOp();
289 emitError(baseLoc,
"'spirv.AccessChain' op index must be an "
290 "integer spirv.Constant to access "
291 "element of spirv.struct");
300 "'spirv.AccessChain' index must be an integer spirv.Constant to "
301 "access element of spirv.struct, but provided ")
305 if (index < 0 ||
static_cast<uint64_t
>(index) >= cType.getNumElements()) {
306 emitError(baseLoc,
"'spirv.AccessChain' op index ")
307 << index <<
" out of bounds for " << resultType;
311 resultType = cType.getElementType(index);
319 assert(type &&
"Unable to deduce return type based on basePtr and indices");
320 build(builder, state, type, basePtr, indices);
324 OpAsmParser::UnresolvedOperand ptrInfo;
325 SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
327 auto loc = parser.getCurrentLocation();
328 SmallVector<Type, 4> indicesTypes;
330 if (parser.parseOperand(ptrInfo) ||
332 parser.parseColonType(type) ||
333 parser.resolveOperand(ptrInfo, type, result.operands)) {
339 if (indicesInfo.empty()) {
341 "'spirv.AccessChain' op expected at "
345 if (parser.parseComma() || parser.parseTypeList(indicesTypes))
350 if (indicesTypes.size() != indicesInfo.size()) {
352 result.location,
"'spirv.AccessChain' op indices types' count must be "
353 "equal to indices info count");
356 if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands))
360 type,
llvm::ArrayRef(result.operands).drop_front(), result.location);
365 result.addTypes(resultType);
369 template <
typename Op>
371 printer <<
' ' << op.getBasePtr() <<
'[' << indices
372 <<
"] : " << op.getBasePtr().getType() <<
", " << indices.
getTypes();
379 template <
typename Op>
382 indices, accessChainOp.
getLoc());
386 auto providedResultType =
387 llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType());
388 if (!providedResultType)
390 "result type must be a pointer, but provided")
391 << providedResultType;
393 if (resultType != providedResultType)
394 return accessChainOp.
emitOpError(
"invalid result type: expected ")
395 << resultType <<
", but provided " << providedResultType;
408 void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr,
409 MemoryAccessAttr memoryAccess, IntegerAttr alignment) {
410 auto ptrType = llvm::cast<spirv::PointerType>(basePtr.getType());
411 build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
415 ParseResult
LoadOp::parse(OpAsmParser &parser, OperationState &result) {
417 spirv::StorageClass storageClass;
418 OpAsmParser::UnresolvedOperand ptrInfo;
420 if (
parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
421 parseMemoryAccessAttributes<LoadOp>(parser, result) ||
422 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
423 parser.parseType(elementType)) {
428 if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
432 result.addTypes(elementType);
437 SmallVector<StringRef, 4> elidedAttrs;
438 StringRef sc = stringifyStorageClass(
439 llvm::cast<spirv::PointerType>(getPtr().
getType()).getStorageClass());
440 printer <<
" \"" << sc <<
"\" " << getPtr();
444 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
462 ParseResult
StoreOp::parse(OpAsmParser &parser, OperationState &result) {
464 spirv::StorageClass storageClass;
465 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
466 auto loc = parser.getCurrentLocation();
469 parser.parseOperandList(operandInfo, 2) ||
470 parseMemoryAccessAttributes<StoreOp>(parser, result) ||
471 parser.parseColon() || parser.parseType(elementType)) {
476 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
484 SmallVector<StringRef, 4> elidedAttrs;
485 StringRef sc = stringifyStorageClass(
486 llvm::cast<spirv::PointerType>(getPtr().
getType()).getStorageClass());
487 printer <<
" \"" << sc <<
"\" " << getPtr() <<
", " << getValue();
491 printer <<
" : " << getValue().getType();
492 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
510 StringRef targetStorageClass = stringifyStorageClass(
511 llvm::cast<spirv::PointerType>(getTarget().
getType()).getStorageClass());
512 printer <<
" \"" << targetStorageClass <<
"\" " << getTarget() <<
", ";
514 StringRef sourceStorageClass = stringifyStorageClass(
515 llvm::cast<spirv::PointerType>(getSource().
getType()).getStorageClass());
516 printer <<
" \"" << sourceStorageClass <<
"\" " << getSource();
518 SmallVector<StringRef, 4> elidedAttrs;
521 getSourceMemoryAccess(),
522 getSourceAlignment());
524 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
527 llvm::cast<spirv::PointerType>(getTarget().
getType()).getPointeeType();
528 printer <<
" : " << pointeeType;
532 spirv::StorageClass targetStorageClass;
533 OpAsmParser::UnresolvedOperand targetPtrInfo;
535 spirv::StorageClass sourceStorageClass;
536 OpAsmParser::UnresolvedOperand sourcePtrInfo;
541 parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
543 parser.parseOperand(sourcePtrInfo) ||
544 parseMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
548 if (!parser.parseOptionalComma()) {
550 if (parseSourceMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
555 if (parser.parseColon() || parser.parseType(elementType))
558 if (parser.parseOptionalAttrDict(result.attributes))
564 if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
565 parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
574 llvm::cast<spirv::PointerType>(getTarget().
getType()).getPointeeType();
577 llvm::cast<spirv::PointerType>(getSource().
getType()).getPointeeType();
579 if (targetType != sourceType)
580 return emitOpError(
"both operands must be pointers to the same type");
613 if (indicesInfo.empty())
614 return emitError(state.location) << opName <<
" expected element";
621 if (indicesTypes.size() != indicesInfo.size())
624 <<
" indices types' count must be equal to indices info count";
626 if (parser.
resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
630 type,
llvm::ArrayRef(state.operands).drop_front(2), state.location);
634 state.addTypes(resultType);
638 template <
typename Op>
641 ret[0] = op.getElement();
654 assert(type &&
"Unable to deduce return type based on basePtr and indices");
655 build(builder, state, type, basePtr, element, indices);
659 OperationState &result) {
661 spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
676 void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
677 Value basePtr, Value element, ValueRange indices) {
679 assert(type &&
"Unable to deduce return type based on basePtr and indices");
680 build(builder, state, type, basePtr, element, indices);
684 OperationState &result) {
703 std::optional<OpAsmParser::UnresolvedOperand> initInfo;
704 if (succeeded(parser.parseOptionalKeyword(
"init"))) {
705 initInfo = OpAsmParser::UnresolvedOperand();
706 if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
707 parser.parseRParen())
717 if (parser.parseColon())
719 auto loc = parser.getCurrentLocation();
720 if (parser.parseType(type))
723 auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
725 return parser.emitError(loc,
"expected spirv.ptr type");
726 result.addTypes(ptrType);
730 if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
735 auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
736 ptrType.getStorageClass());
737 result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
743 SmallVector<StringRef, 4> elidedAttrs{
744 spirv::attributeName<spirv::StorageClass>()};
746 if (getNumOperands() != 0)
747 printer <<
" init(" << getInitializer() <<
")";
757 if (getStorageClass() != spirv::StorageClass::Function) {
759 "can only be used to model function-level variables. Use "
760 "spirv.GlobalVariable for module-level variables.");
763 auto pointerType = llvm::cast<spirv::PointerType>(getPointer().
getType());
764 if (getStorageClass() != pointerType.getStorageClass())
766 "storage class must match result pointer's storage class");
768 if (getNumOperands() != 0) {
771 auto *initOp = getOperand(0).getDefiningOp();
772 if (!initOp || !isa<spirv::ConstantOp,
773 spirv::ReferenceOfOp,
774 spirv::AddressOfOp>(initOp))
775 return emitOpError(
"initializer must be the result of a "
776 "constant or spirv.GlobalVariable op");
779 auto getDecorationAttr = [op = getOperation()](spirv::Decoration decoration) {
781 llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)));
785 for (
auto decoration :
786 {spirv::Decoration::DescriptorSet, spirv::Decoration::Binding,
787 spirv::Decoration::BuiltIn}) {
788 if (
auto attr = getDecorationAttr(decoration))
789 return emitOpError(
"cannot have '")
790 << llvm::convertToSnakeFromCamelCase(
791 stringifyDecoration(decoration))
792 <<
"' attribute (only allowed in spirv.GlobalVariable)";
799 auto pointeePtrType = dyn_cast<spirv::PointerType>(getPointeeType());
800 if (!pointeePtrType) {
801 if (
auto pointeeArrayType = dyn_cast<spirv::ArrayType>(getPointeeType())) {
803 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
807 if (pointeePtrType && pointeePtrType.getStorageClass() ==
808 spirv::StorageClass::PhysicalStorageBuffer) {
810 getDecorationAttr(spirv::Decoration::AliasedPointer) !=
nullptr;
811 bool hasRestrictPtr =
812 getDecorationAttr(spirv::Decoration::RestrictPointer) !=
nullptr;
814 if (!hasAliasedPtr && !hasRestrictPtr)
815 return emitOpError() <<
" with physical buffer pointer must be decorated "
816 "either 'AliasedPointer' or 'RestrictPointer'";
818 if (hasAliasedPtr && hasRestrictPtr)
820 <<
" with physical buffer pointer must have exactly one "
821 "aliasing decoration";
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Square
Square brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Attributes are known-constant values of operations.
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static PointerType get(Type pointeeType, StorageClass storageClass)
@ Type
An inlay hint that for a type annotation.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)
static ParseResult parsePtrAccessChainOpImpl(StringRef opName, OpAsmParser &parser, OperationState &state)
static void printSourceMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
static auto concatElemAndIndices(Op op)
ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Parses optional memory access (a.k.a.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, Value val)
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp)
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer)
static void printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.