23 template <
typename OpTy>
26 spirv::Scope executionScope;
27 GroupOperation groupOperation;
29 if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(
30 executionScope, parser, state,
31 OpTy::getExecutionScopeAttrName(state.name)) ||
32 spirv::parseEnumStrAttr<GroupOperationAttr>(
33 groupOperation, parser, state,
34 OpTy::getGroupOperationAttrName(state.name)) ||
38 std::optional<OpAsmParser::UnresolvedOperand> clusterSizeInfo;
53 if (clusterSizeInfo) {
55 if (parser.
resolveOperand(*clusterSizeInfo, i32Type, state.operands))
62 template <
typename GroupNonUniformArithmeticOpTy>
69 ->getAttrOfType<spirv::ScopeAttr>(
70 GroupNonUniformArithmeticOpTy::getExecutionScopeAttrName(
74 << stringifyGroupOperation(
76 ->getAttrOfType<GroupOperationAttr>(
77 GroupNonUniformArithmeticOpTy::getGroupOperationAttrName(
87 template <
typename OpTy>
92 OpTy::getExecutionScopeAttrName(groupOp->
getName()))
94 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
96 "execution scope must be 'Workgroup' or 'Subgroup'");
98 GroupOperation operation =
101 OpTy::getGroupOperationAttrName(groupOp->
getName()))
103 if (operation == GroupOperation::ClusteredReduce &&
105 return groupOp->
emitOpError(
"cluster size operand must be provided for "
106 "'ClusteredReduce' group operation");
109 int32_t clusterSize = 0;
114 "cluster size operand must come from a constant op");
116 if (!llvm::isPowerOf2_32(clusterSize))
118 "cluster size operand must be a power of two");
128 spirv::Scope scope = getExecutionScope();
129 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
130 return emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
132 if (
auto localIdTy = llvm::dyn_cast<VectorType>(getLocalid().getType()))
133 if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
134 return emitOpError(
"localid is a vector and can be with only "
135 " 2 or 3 components, actual number is ")
136 << localIdTy.getNumElements();
146 spirv::Scope scope = getExecutionScope();
147 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
148 return emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
158 spirv::Scope scope = getExecutionScope();
159 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
160 return emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
165 if (
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
168 if (targetEnv.getVersion() < spirv::Version::V_1_5) {
169 auto *idOp = getId().getDefiningOp();
170 if (!idOp || !isa<spirv::ConstantOp,
171 spirv::ReferenceOfOp>(idOp))
172 return emitOpError(
"id must be the result of a constant op");
182 template <
typename OpTy>
184 spirv::Scope scope = op.getExecutionScope();
185 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
186 return op.emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
188 if (op.getOperands().back().getType().isSignedInteger())
189 return op.emitOpError(
"second operand must be a singless/unsigned integer");
212 spirv::Scope scope = getExecutionScope();
213 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
214 return emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
224 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*
this);
228 OperationState &result) {
229 return parseGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(parser,
234 printGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*
this, p);
242 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*
this);
246 OperationState &result) {
247 return parseGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(parser,
252 printGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*
this, p);
260 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*
this);
264 OperationState &result) {
265 return parseGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(parser,
270 printGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*
this, p);
278 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*
this);
282 OperationState &result) {
283 return parseGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(parser,
288 printGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*
this, p);
296 return verifyGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*
this);
300 OperationState &result) {
301 return parseGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(parser,
306 printGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*
this, p);
314 return verifyGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*
this);
318 OperationState &result) {
319 return parseGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(parser,
324 printGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*
this, p);
332 return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*
this);
336 OperationState &result) {
337 return parseGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(parser,
342 printGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*
this, p);
350 return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*
this);
354 OperationState &result) {
355 return parseGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(parser,
360 printGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*
this, p);
368 return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*
this);
372 OperationState &result) {
373 return parseGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(parser,
378 printGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*
this, p);
386 return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*
this);
390 OperationState &result) {
391 return parseGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(parser,
396 printGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*
this, p);
404 return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*
this);
408 OperationState &result) {
409 return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(parser,
414 printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*
this, p);
422 return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*
this);
426 OperationState &result) {
427 return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(parser,
432 printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*
this, p);
440 return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*
this);
444 OperationState &result) {
445 return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(parser,
450 printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*
this, p);
458 return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*
this);
462 OperationState &result) {
463 return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(parser,
468 printGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*
this, p);
476 return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*
this);
480 OperationState &result) {
481 return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(parser,
486 printGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*
this, p);
494 return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*
this);
498 OperationState &result) {
499 return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(parser,
504 printGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*
this, p);
511 template <
typename Op>
513 spirv::Scope scope = op.getExecutionScope();
514 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
515 return op.emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
static MLIRContext * getContext(OpFoldResult val)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseLParen()=0
Parse a ( token.
IntegerType getIntegerType(unsigned width)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
OperationName getName()
The name of an operation is the key identifier for it.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class represents success/failure for parsing-like operations that find it important to chain tog...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
constexpr char kClusterSize[]
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op)
static LogicalResult verifyGroupOp(Op op)
static void printGroupNonUniformArithmeticOp(Operation *groupOp, OpAsmPrinter &printer)
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context)
Returns the default target environment: SPIR-V 1.0 with Shader capability and no extra extensions.
static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser, OperationState &state)
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.