23 template <
typename OpTy>
26 spirv::Scope executionScope;
27 GroupOperation groupOperation;
29 if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(
30 executionScope, parser, state,
31 OpTy::getExecutionScopeAttrName(state.name)) ||
32 spirv::parseEnumStrAttr<GroupOperationAttr>(
33 groupOperation, parser, state,
34 OpTy::getGroupOperationAttrName(state.name)) ||
38 std::optional<OpAsmParser::UnresolvedOperand> clusterSizeInfo;
53 if (clusterSizeInfo) {
55 if (parser.
resolveOperand(*clusterSizeInfo, i32Type, state.operands))
62 template <
typename GroupNonUniformArithmeticOpTy>
69 ->getAttrOfType<spirv::ScopeAttr>(
70 GroupNonUniformArithmeticOpTy::getExecutionScopeAttrName(
74 << stringifyGroupOperation(
76 ->getAttrOfType<GroupOperationAttr>(
77 GroupNonUniformArithmeticOpTy::getGroupOperationAttrName(
87 template <
typename OpTy>
92 OpTy::getExecutionScopeAttrName(groupOp->
getName()))
94 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
96 "execution scope must be 'Workgroup' or 'Subgroup'");
98 GroupOperation operation =
101 OpTy::getGroupOperationAttrName(groupOp->
getName()))
103 if (operation == GroupOperation::ClusteredReduce &&
105 return groupOp->
emitOpError(
"cluster size operand must be provided for "
106 "'ClusteredReduce' group operation");
109 int32_t clusterSize = 0;
114 "cluster size operand must come from a constant op");
116 if (!llvm::isPowerOf2_32(clusterSize))
118 "cluster size operand must be a power of two");
128 spirv::Scope scope = getExecutionScope();
129 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
130 return emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
132 if (
auto localIdTy = llvm::dyn_cast<VectorType>(getLocalid().
getType()))
133 if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
134 return emitOpError(
"localid is a vector and can be with only "
135 " 2 or 3 components, actual number is ")
136 << localIdTy.getNumElements();
146 spirv::Scope scope = getExecutionScope();
147 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
148 return emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
158 spirv::Scope scope = getExecutionScope();
159 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
160 return emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
170 spirv::Scope scope = getExecutionScope();
171 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
172 return emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
182 spirv::Scope scope = getExecutionScope();
183 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
184 return emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
189 if (
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
192 if (targetEnv.getVersion() < spirv::Version::V_1_5) {
193 auto *idOp = getId().getDefiningOp();
194 if (!idOp || !isa<spirv::ConstantOp,
195 spirv::ReferenceOfOp>(idOp))
196 return emitOpError(
"id must be the result of a constant op");
206 template <
typename OpTy>
208 spirv::Scope scope = op.getExecutionScope();
209 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
210 return op.emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
212 if (op.getOperands().back().getType().isSignedInteger())
213 return op.emitOpError(
"second operand must be a singless/unsigned integer");
236 spirv::Scope scope = getExecutionScope();
237 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
238 return emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
248 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*
this);
252 OperationState &result) {
253 return parseGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(parser,
258 printGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*
this, p);
266 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*
this);
270 OperationState &result) {
271 return parseGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(parser,
276 printGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*
this, p);
284 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*
this);
288 OperationState &result) {
289 return parseGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(parser,
294 printGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*
this, p);
302 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*
this);
306 OperationState &result) {
307 return parseGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(parser,
312 printGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*
this, p);
320 return verifyGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*
this);
324 OperationState &result) {
325 return parseGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(parser,
330 printGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*
this, p);
338 return verifyGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*
this);
342 OperationState &result) {
343 return parseGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(parser,
348 printGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*
this, p);
356 return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*
this);
360 OperationState &result) {
361 return parseGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(parser,
366 printGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*
this, p);
374 return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*
this);
378 OperationState &result) {
379 return parseGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(parser,
384 printGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*
this, p);
392 return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*
this);
396 OperationState &result) {
397 return parseGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(parser,
402 printGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*
this, p);
410 return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*
this);
414 OperationState &result) {
415 return parseGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(parser,
420 printGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*
this, p);
428 return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*
this);
432 OperationState &result) {
433 return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(parser,
438 printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*
this, p);
446 return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*
this);
450 OperationState &result) {
451 return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(parser,
456 printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*
this, p);
464 return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*
this);
468 OperationState &result) {
469 return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(parser,
474 printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*
this, p);
482 return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*
this);
486 OperationState &result) {
487 return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(parser,
492 printGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*
this, p);
500 return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*
this);
504 OperationState &result) {
505 return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(parser,
510 printGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*
this, p);
518 return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*
this);
522 OperationState &result) {
523 return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(parser,
528 printGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*
this, p);
535 template <
typename Op>
537 spirv::Scope scope = op.getExecutionScope();
538 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
539 return op.emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
static MLIRContext * getContext(OpFoldResult val)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attributes.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseLParen()=0
Parse a ( token.
IntegerType getIntegerType(unsigned width)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it, emitting errors, etc.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print() method.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
OperationName getName()
The name of an operation is the key identifier for it.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
constexpr char kClusterSize[]
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op, or returns the default target environment (as returned by getDefaultTargetEnv()) if none is provided.
static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op)
static LogicalResult verifyGroupOp(Op op)
static void printGroupNonUniformArithmeticOp(Operation *groupOp, OpAsmPrinter &printer)
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context)
Returns the default target environment: SPIR-V 1.0 with Shader capability and no extra extensions.
static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser, OperationState &state)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.