MLIR  18.0.0git
GroupOps.cpp
Go to the documentation of this file.
1 //===- GroupOps.cpp - MLIR SPIR-V Group Ops ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the group operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 
16 #include "SPIRVOpUtils.h"
17 #include "SPIRVParsingUtils.h"
18 
19 using namespace mlir::spirv::AttrNames;
20 
21 namespace mlir::spirv {
22 
24  OperationState &state) {
25  spirv::Scope executionScope;
26  GroupOperation groupOperation;
28  if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
30  spirv::parseEnumStrAttr<GroupOperationAttr>(groupOperation, parser, state,
32  parser.parseOperand(valueInfo))
33  return failure();
34 
35  std::optional<OpAsmParser::UnresolvedOperand> clusterSizeInfo;
37  clusterSizeInfo = OpAsmParser::UnresolvedOperand();
38  if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
39  parser.parseRParen())
40  return failure();
41  }
42 
43  Type resultType;
44  if (parser.parseColonType(resultType))
45  return failure();
46 
47  if (parser.resolveOperand(valueInfo, resultType, state.operands))
48  return failure();
49 
50  if (clusterSizeInfo) {
51  Type i32Type = parser.getBuilder().getIntegerType(32);
52  if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
53  return failure();
54  }
55 
56  return parser.addTypeToList(resultType, state.types);
57 }
58 
60  OpAsmPrinter &printer) {
61  printer
62  << " \""
63  << stringifyScope(
64  groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
65  .getValue())
66  << "\" \""
67  << stringifyGroupOperation(
68  groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
69  .getValue())
70  << "\" " << groupOp->getOperand(0);
71 
72  if (groupOp->getNumOperands() > 1)
73  printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
74  printer << " : " << groupOp->getResult(0).getType();
75 }
76 
78  spirv::Scope scope =
79  groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
80  .getValue();
81  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
82  return groupOp->emitOpError(
83  "execution scope must be 'Workgroup' or 'Subgroup'");
84 
85  GroupOperation operation =
86  groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
87  .getValue();
88  if (operation == GroupOperation::ClusteredReduce &&
89  groupOp->getNumOperands() == 1)
90  return groupOp->emitOpError("cluster size operand must be provided for "
91  "'ClusteredReduce' group operation");
92  if (groupOp->getNumOperands() > 1) {
93  Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
94  int32_t clusterSize = 0;
95 
96  // TODO: support specialization constant here.
97  if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
98  return groupOp->emitOpError(
99  "cluster size operand must come from a constant op");
100 
101  if (!llvm::isPowerOf2_32(clusterSize))
102  return groupOp->emitOpError(
103  "cluster size operand must be a power of two");
104  }
105  return success();
106 }
107 
108 //===----------------------------------------------------------------------===//
109 // spirv.GroupBroadcast
110 //===----------------------------------------------------------------------===//
111 
113  spirv::Scope scope = getExecutionScope();
114  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
115  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
116 
117  if (auto localIdTy = llvm::dyn_cast<VectorType>(getLocalid().getType()))
118  if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
119  return emitOpError("localid is a vector and can be with only "
120  " 2 or 3 components, actual number is ")
121  << localIdTy.getNumElements();
122 
123  return success();
124 }
125 
126 //===----------------------------------------------------------------------===//
127 // spirv.GroupNonUniformBallotOp
128 //===----------------------------------------------------------------------===//
129 
130 LogicalResult GroupNonUniformBallotOp::verify() {
131  spirv::Scope scope = getExecutionScope();
132  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
133  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
134 
135  return success();
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // spirv.GroupNonUniformBroadcast
140 //===----------------------------------------------------------------------===//
141 
142 LogicalResult GroupNonUniformBroadcastOp::verify() {
143  spirv::Scope scope = getExecutionScope();
144  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
145  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
146 
147  // SPIR-V spec: "Before version 1.5, Id must come from a
148  // constant instruction.
149  auto targetEnv = spirv::getDefaultTargetEnv(getContext());
150  if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
151  targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
152 
153  if (targetEnv.getVersion() < spirv::Version::V_1_5) {
154  auto *idOp = getId().getDefiningOp();
155  if (!idOp || !isa<spirv::ConstantOp, // for normal constant
156  spirv::ReferenceOfOp>(idOp)) // for spec constant
157  return emitOpError("id must be the result of a constant op");
158  }
159 
160  return success();
161 }
162 
163 //===----------------------------------------------------------------------===//
164 // spirv.GroupNonUniformShuffle*
165 //===----------------------------------------------------------------------===//
166 
167 template <typename OpTy>
169  spirv::Scope scope = op.getExecutionScope();
170  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
171  return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
172 
173  if (op.getOperands().back().getType().isSignedInteger())
174  return op.emitOpError("second operand must be a singless/unsigned integer");
175 
176  return success();
177 }
178 
180  return verifyGroupNonUniformShuffleOp(*this);
181 }
182 LogicalResult GroupNonUniformShuffleDownOp::verify() {
183  return verifyGroupNonUniformShuffleOp(*this);
184 }
185 LogicalResult GroupNonUniformShuffleUpOp::verify() {
186  return verifyGroupNonUniformShuffleOp(*this);
187 }
188 LogicalResult GroupNonUniformShuffleXorOp::verify() {
189  return verifyGroupNonUniformShuffleOp(*this);
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // spirv.GroupNonUniformElectOp
194 //===----------------------------------------------------------------------===//
195 
196 LogicalResult GroupNonUniformElectOp::verify() {
197  spirv::Scope scope = getExecutionScope();
198  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
199  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
200 
201  return success();
202 }
203 
204 //===----------------------------------------------------------------------===//
205 // spirv.GroupNonUniformFAddOp
206 //===----------------------------------------------------------------------===//
207 
208 LogicalResult GroupNonUniformFAddOp::verify() {
209  return verifyGroupNonUniformArithmeticOp(*this);
210 }
211 
212 ParseResult GroupNonUniformFAddOp::parse(OpAsmParser &parser,
213  OperationState &result) {
214  return parseGroupNonUniformArithmeticOp(parser, result);
215 }
216 
217 void GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
219 }
220 
221 //===----------------------------------------------------------------------===//
222 // spirv.GroupNonUniformFMaxOp
223 //===----------------------------------------------------------------------===//
224 
225 LogicalResult GroupNonUniformFMaxOp::verify() {
226  return verifyGroupNonUniformArithmeticOp(*this);
227 }
228 
229 ParseResult GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
230  OperationState &result) {
231  return parseGroupNonUniformArithmeticOp(parser, result);
232 }
233 
234 void GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // spirv.GroupNonUniformFMinOp
240 //===----------------------------------------------------------------------===//
241 
242 LogicalResult GroupNonUniformFMinOp::verify() {
243  return verifyGroupNonUniformArithmeticOp(*this);
244 }
245 
246 ParseResult GroupNonUniformFMinOp::parse(OpAsmParser &parser,
247  OperationState &result) {
248  return parseGroupNonUniformArithmeticOp(parser, result);
249 }
250 
251 void GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // spirv.GroupNonUniformFMulOp
257 //===----------------------------------------------------------------------===//
258 
259 LogicalResult GroupNonUniformFMulOp::verify() {
260  return verifyGroupNonUniformArithmeticOp(*this);
261 }
262 
263 ParseResult GroupNonUniformFMulOp::parse(OpAsmParser &parser,
264  OperationState &result) {
265  return parseGroupNonUniformArithmeticOp(parser, result);
266 }
267 
268 void GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
270 }
271 
272 //===----------------------------------------------------------------------===//
273 // spirv.GroupNonUniformIAddOp
274 //===----------------------------------------------------------------------===//
275 
276 LogicalResult GroupNonUniformIAddOp::verify() {
277  return verifyGroupNonUniformArithmeticOp(*this);
278 }
279 
280 ParseResult GroupNonUniformIAddOp::parse(OpAsmParser &parser,
281  OperationState &result) {
282  return parseGroupNonUniformArithmeticOp(parser, result);
283 }
284 
285 void GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
287 }
288 
289 //===----------------------------------------------------------------------===//
290 // spirv.GroupNonUniformIMulOp
291 //===----------------------------------------------------------------------===//
292 
293 LogicalResult GroupNonUniformIMulOp::verify() {
294  return verifyGroupNonUniformArithmeticOp(*this);
295 }
296 
297 ParseResult GroupNonUniformIMulOp::parse(OpAsmParser &parser,
298  OperationState &result) {
299  return parseGroupNonUniformArithmeticOp(parser, result);
300 }
301 
302 void GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // spirv.GroupNonUniformSMaxOp
308 //===----------------------------------------------------------------------===//
309 
310 LogicalResult GroupNonUniformSMaxOp::verify() {
311  return verifyGroupNonUniformArithmeticOp(*this);
312 }
313 
314 ParseResult GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
315  OperationState &result) {
316  return parseGroupNonUniformArithmeticOp(parser, result);
317 }
318 
319 void GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
321 }
322 
323 //===----------------------------------------------------------------------===//
324 // spirv.GroupNonUniformSMinOp
325 //===----------------------------------------------------------------------===//
326 
327 LogicalResult GroupNonUniformSMinOp::verify() {
328  return verifyGroupNonUniformArithmeticOp(*this);
329 }
330 
331 ParseResult GroupNonUniformSMinOp::parse(OpAsmParser &parser,
332  OperationState &result) {
333  return parseGroupNonUniformArithmeticOp(parser, result);
334 }
335 
336 void GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
338 }
339 
340 //===----------------------------------------------------------------------===//
341 // spirv.GroupNonUniformUMaxOp
342 //===----------------------------------------------------------------------===//
343 
344 LogicalResult GroupNonUniformUMaxOp::verify() {
345  return verifyGroupNonUniformArithmeticOp(*this);
346 }
347 
348 ParseResult GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
349  OperationState &result) {
350  return parseGroupNonUniformArithmeticOp(parser, result);
351 }
352 
353 void GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // spirv.GroupNonUniformUMinOp
359 //===----------------------------------------------------------------------===//
360 
361 LogicalResult GroupNonUniformUMinOp::verify() {
362  return verifyGroupNonUniformArithmeticOp(*this);
363 }
364 
365 ParseResult GroupNonUniformUMinOp::parse(OpAsmParser &parser,
366  OperationState &result) {
367  return parseGroupNonUniformArithmeticOp(parser, result);
368 }
369 
370 void GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
372 }
373 
374 //===----------------------------------------------------------------------===//
375 // spirv.GroupNonUniformBitwiseAnd
376 //===----------------------------------------------------------------------===//
377 
378 LogicalResult GroupNonUniformBitwiseAndOp::verify() {
379  return verifyGroupNonUniformArithmeticOp(*this);
380 }
381 
382 ParseResult GroupNonUniformBitwiseAndOp::parse(OpAsmParser &parser,
383  OperationState &result) {
384  return parseGroupNonUniformArithmeticOp(parser, result);
385 }
386 
387 void GroupNonUniformBitwiseAndOp::print(OpAsmPrinter &p) {
389 }
390 
391 //===----------------------------------------------------------------------===//
392 // spirv.GroupNonUniformBitwiseOr
393 //===----------------------------------------------------------------------===//
394 
395 LogicalResult GroupNonUniformBitwiseOrOp::verify() {
396  return verifyGroupNonUniformArithmeticOp(*this);
397 }
398 
399 ParseResult GroupNonUniformBitwiseOrOp::parse(OpAsmParser &parser,
400  OperationState &result) {
401  return parseGroupNonUniformArithmeticOp(parser, result);
402 }
403 
404 void GroupNonUniformBitwiseOrOp::print(OpAsmPrinter &p) {
406 }
407 
408 //===----------------------------------------------------------------------===//
409 // spirv.GroupNonUniformBitwiseXor
410 //===----------------------------------------------------------------------===//
411 
412 LogicalResult GroupNonUniformBitwiseXorOp::verify() {
413  return verifyGroupNonUniformArithmeticOp(*this);
414 }
415 
416 ParseResult GroupNonUniformBitwiseXorOp::parse(OpAsmParser &parser,
417  OperationState &result) {
418  return parseGroupNonUniformArithmeticOp(parser, result);
419 }
420 
421 void GroupNonUniformBitwiseXorOp::print(OpAsmPrinter &p) {
423 }
424 
425 //===----------------------------------------------------------------------===//
426 // spirv.GroupNonUniformLogicalAnd
427 //===----------------------------------------------------------------------===//
428 
429 LogicalResult GroupNonUniformLogicalAndOp::verify() {
430  return verifyGroupNonUniformArithmeticOp(*this);
431 }
432 
433 ParseResult GroupNonUniformLogicalAndOp::parse(OpAsmParser &parser,
434  OperationState &result) {
435  return parseGroupNonUniformArithmeticOp(parser, result);
436 }
437 
438 void GroupNonUniformLogicalAndOp::print(OpAsmPrinter &p) {
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // spirv.GroupNonUniformLogicalOr
444 //===----------------------------------------------------------------------===//
445 
446 LogicalResult GroupNonUniformLogicalOrOp::verify() {
447  return verifyGroupNonUniformArithmeticOp(*this);
448 }
449 
450 ParseResult GroupNonUniformLogicalOrOp::parse(OpAsmParser &parser,
451  OperationState &result) {
452  return parseGroupNonUniformArithmeticOp(parser, result);
453 }
454 
455 void GroupNonUniformLogicalOrOp::print(OpAsmPrinter &p) {
457 }
458 
459 //===----------------------------------------------------------------------===//
460 // spirv.GroupNonUniformLogicalXor
461 //===----------------------------------------------------------------------===//
462 
463 LogicalResult GroupNonUniformLogicalXorOp::verify() {
464  return verifyGroupNonUniformArithmeticOp(*this);
465 }
466 
467 ParseResult GroupNonUniformLogicalXorOp::parse(OpAsmParser &parser,
468  OperationState &result) {
469  return parseGroupNonUniformArithmeticOp(parser, result);
470 }
471 
472 void GroupNonUniformLogicalXorOp::print(OpAsmPrinter &p) {
474 }
475 
476 //===----------------------------------------------------------------------===//
477 // Group op verification
478 //===----------------------------------------------------------------------===//
479 
480 template <typename Op>
482  spirv::Scope scope = op.getExecutionScope();
483  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
484  return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
485 
486  return success();
487 }
488 
490 
491 LogicalResult GroupFAddOp::verify() { return verifyGroupOp(*this); }
492 
493 LogicalResult GroupFMinOp::verify() { return verifyGroupOp(*this); }
494 
495 LogicalResult GroupUMinOp::verify() { return verifyGroupOp(*this); }
496 
497 LogicalResult GroupSMinOp::verify() { return verifyGroupOp(*this); }
498 
499 LogicalResult GroupFMaxOp::verify() { return verifyGroupOp(*this); }
500 
501 LogicalResult GroupUMaxOp::verify() { return verifyGroupOp(*this); }
502 
503 LogicalResult GroupSMaxOp::verify() { return verifyGroupOp(*this); }
504 
505 LogicalResult GroupIMulKHROp::verify() { return verifyGroupOp(*this); }
506 
507 LogicalResult GroupFMulKHROp::verify() { return verifyGroupOp(*this); }
508 
509 } // namespace mlir::spirv
static MLIRContext * getContext(OpFoldResult val)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseLParen()=0
Parse a ( token.
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:528
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumOperands()
Definition: Operation.h:341
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
This class represents success/failure for parsing-like operations that find it important to chain tog...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Type getType() const
Return the type of this value.
Definition: Value.h:125
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:19
constexpr char kExecutionScopeAttrName[]
constexpr char kGroupOperationAttrName[]
constexpr char kClusterSize[]
static void printGroupNonUniformArithmeticOp(Operation *groupOp, OpAsmPrinter &printer)
Definition: GroupOps.cpp:59
static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser, OperationState &state)
Definition: GroupOps.cpp:23
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp)
Definition: GroupOps.cpp:77
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op)
Definition: GroupOps.cpp:168
static LogicalResult verifyGroupOp(Op op)
Definition: GroupOps.cpp:481
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:51
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context)
Returns the default target environment: SPIR-V 1.0 with Shader capability and no extra extensions.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.