MLIR  20.0.0git
GroupOps.cpp
Go to the documentation of this file.
1 //===- GroupOps.cpp - MLIR SPIR-V Group Ops ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the group operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 
16 #include "SPIRVOpUtils.h"
17 #include "SPIRVParsingUtils.h"
18 
19 using namespace mlir::spirv::AttrNames;
20 
21 namespace mlir::spirv {
22 
23 template <typename OpTy>
24 static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
25  spirv::Scope scope =
26  groupOp
27  ->getAttrOfType<spirv::ScopeAttr>(
28  OpTy::getExecutionScopeAttrName(groupOp->getName()))
29  .getValue();
30  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
31  return groupOp->emitOpError(
32  "execution scope must be 'Workgroup' or 'Subgroup'");
33 
34  GroupOperation operation =
35  groupOp
36  ->getAttrOfType<GroupOperationAttr>(
37  OpTy::getGroupOperationAttrName(groupOp->getName()))
38  .getValue();
39  if (operation == GroupOperation::ClusteredReduce &&
40  groupOp->getNumOperands() == 1)
41  return groupOp->emitOpError("cluster size operand must be provided for "
42  "'ClusteredReduce' group operation");
43  if (groupOp->getNumOperands() > 1) {
44  Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
45  int32_t clusterSize = 0;
46 
47  // TODO: support specialization constant here.
48  if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
49  return groupOp->emitOpError(
50  "cluster size operand must come from a constant op");
51 
52  if (!llvm::isPowerOf2_32(clusterSize))
53  return groupOp->emitOpError(
54  "cluster size operand must be a power of two");
55  }
56  return success();
57 }
58 
59 //===----------------------------------------------------------------------===//
60 // spirv.GroupBroadcast
61 //===----------------------------------------------------------------------===//
62 
63 LogicalResult GroupBroadcastOp::verify() {
64  spirv::Scope scope = getExecutionScope();
65  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
66  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
67 
68  if (auto localIdTy = llvm::dyn_cast<VectorType>(getLocalid().getType()))
69  if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
70  return emitOpError("localid is a vector and can be with only "
71  " 2 or 3 components, actual number is ")
72  << localIdTy.getNumElements();
73 
74  return success();
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // spirv.GroupNonUniformBallotOp
79 //===----------------------------------------------------------------------===//
80 
81 LogicalResult GroupNonUniformBallotOp::verify() {
82  spirv::Scope scope = getExecutionScope();
83  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
84  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
85 
86  return success();
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // spirv.GroupNonUniformBallotFindLSBOp
91 //===----------------------------------------------------------------------===//
92 
94  spirv::Scope scope = getExecutionScope();
95  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
96  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
97 
98  return success();
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // spirv.GroupNonUniformBallotFindLSBOp
103 //===----------------------------------------------------------------------===//
104 
106  spirv::Scope scope = getExecutionScope();
107  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
108  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
109 
110  return success();
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // spirv.GroupNonUniformBroadcast
115 //===----------------------------------------------------------------------===//
116 
117 LogicalResult GroupNonUniformBroadcastOp::verify() {
118  spirv::Scope scope = getExecutionScope();
119  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
120  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
121 
122  // SPIR-V spec: "Before version 1.5, Id must come from a
123  // constant instruction.
124  auto targetEnv = spirv::getDefaultTargetEnv(getContext());
125  if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
126  targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
127 
128  if (targetEnv.getVersion() < spirv::Version::V_1_5) {
129  auto *idOp = getId().getDefiningOp();
130  if (!idOp || !isa<spirv::ConstantOp, // for normal constant
131  spirv::ReferenceOfOp>(idOp)) // for spec constant
132  return emitOpError("id must be the result of a constant op");
133  }
134 
135  return success();
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // spirv.GroupNonUniformShuffle*
140 //===----------------------------------------------------------------------===//
141 
142 template <typename OpTy>
143 static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) {
144  spirv::Scope scope = op.getExecutionScope();
145  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
146  return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
147 
148  if (op.getOperands().back().getType().isSignedInteger())
149  return op.emitOpError("second operand must be a singless/unsigned integer");
150 
151  return success();
152 }
153 
154 LogicalResult GroupNonUniformShuffleOp::verify() {
155  return verifyGroupNonUniformShuffleOp(*this);
156 }
157 LogicalResult GroupNonUniformShuffleDownOp::verify() {
158  return verifyGroupNonUniformShuffleOp(*this);
159 }
160 LogicalResult GroupNonUniformShuffleUpOp::verify() {
161  return verifyGroupNonUniformShuffleOp(*this);
162 }
163 LogicalResult GroupNonUniformShuffleXorOp::verify() {
164  return verifyGroupNonUniformShuffleOp(*this);
165 }
166 
167 //===----------------------------------------------------------------------===//
168 // spirv.GroupNonUniformElectOp
169 //===----------------------------------------------------------------------===//
170 
171 LogicalResult GroupNonUniformElectOp::verify() {
172  spirv::Scope scope = getExecutionScope();
173  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
174  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
175 
176  return success();
177 }
178 
179 //===----------------------------------------------------------------------===//
180 // spirv.GroupNonUniformFAddOp
181 //===----------------------------------------------------------------------===//
182 
183 LogicalResult GroupNonUniformFAddOp::verify() {
184  return verifyGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*this);
185 }
186 
187 //===----------------------------------------------------------------------===//
188 // spirv.GroupNonUniformFMaxOp
189 //===----------------------------------------------------------------------===//
190 
191 LogicalResult GroupNonUniformFMaxOp::verify() {
192  return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*this);
193 }
194 
195 //===----------------------------------------------------------------------===//
196 // spirv.GroupNonUniformFMinOp
197 //===----------------------------------------------------------------------===//
198 
199 LogicalResult GroupNonUniformFMinOp::verify() {
200  return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*this);
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // spirv.GroupNonUniformFMulOp
205 //===----------------------------------------------------------------------===//
206 
207 LogicalResult GroupNonUniformFMulOp::verify() {
208  return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*this);
209 }
210 
211 //===----------------------------------------------------------------------===//
212 // spirv.GroupNonUniformIAddOp
213 //===----------------------------------------------------------------------===//
214 
215 LogicalResult GroupNonUniformIAddOp::verify() {
216  return verifyGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*this);
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // spirv.GroupNonUniformIMulOp
221 //===----------------------------------------------------------------------===//
222 
223 LogicalResult GroupNonUniformIMulOp::verify() {
224  return verifyGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*this);
225 }
226 
227 //===----------------------------------------------------------------------===//
228 // spirv.GroupNonUniformSMaxOp
229 //===----------------------------------------------------------------------===//
230 
231 LogicalResult GroupNonUniformSMaxOp::verify() {
232  return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*this);
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // spirv.GroupNonUniformSMinOp
237 //===----------------------------------------------------------------------===//
238 
239 LogicalResult GroupNonUniformSMinOp::verify() {
240  return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*this);
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // spirv.GroupNonUniformUMaxOp
245 //===----------------------------------------------------------------------===//
246 
247 LogicalResult GroupNonUniformUMaxOp::verify() {
248  return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*this);
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // spirv.GroupNonUniformUMinOp
253 //===----------------------------------------------------------------------===//
254 
255 LogicalResult GroupNonUniformUMinOp::verify() {
256  return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*this);
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // spirv.GroupNonUniformBitwiseAnd
261 //===----------------------------------------------------------------------===//
262 
263 LogicalResult GroupNonUniformBitwiseAndOp::verify() {
264  return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*this);
265 }
266 
267 //===----------------------------------------------------------------------===//
268 // spirv.GroupNonUniformBitwiseOr
269 //===----------------------------------------------------------------------===//
270 
271 LogicalResult GroupNonUniformBitwiseOrOp::verify() {
272  return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*this);
273 }
274 
275 //===----------------------------------------------------------------------===//
276 // spirv.GroupNonUniformBitwiseXor
277 //===----------------------------------------------------------------------===//
278 
279 LogicalResult GroupNonUniformBitwiseXorOp::verify() {
280  return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*this);
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // spirv.GroupNonUniformLogicalAnd
285 //===----------------------------------------------------------------------===//
286 
287 LogicalResult GroupNonUniformLogicalAndOp::verify() {
288  return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*this);
289 }
290 
291 //===----------------------------------------------------------------------===//
292 // spirv.GroupNonUniformLogicalOr
293 //===----------------------------------------------------------------------===//
294 
295 LogicalResult GroupNonUniformLogicalOrOp::verify() {
296  return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*this);
297 }
298 
299 //===----------------------------------------------------------------------===//
300 // spirv.GroupNonUniformLogicalXor
301 //===----------------------------------------------------------------------===//
302 
303 LogicalResult GroupNonUniformLogicalXorOp::verify() {
304  return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this);
305 }
306 
307 //===----------------------------------------------------------------------===//
308 // Group op verification
309 //===----------------------------------------------------------------------===//
310 
311 template <typename Op>
312 static LogicalResult verifyGroupOp(Op op) {
313  spirv::Scope scope = op.getExecutionScope();
314  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
315  return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
316 
317  return success();
318 }
319 
320 LogicalResult GroupIAddOp::verify() { return verifyGroupOp(*this); }
321 
322 LogicalResult GroupFAddOp::verify() { return verifyGroupOp(*this); }
323 
324 LogicalResult GroupFMinOp::verify() { return verifyGroupOp(*this); }
325 
326 LogicalResult GroupUMinOp::verify() { return verifyGroupOp(*this); }
327 
328 LogicalResult GroupSMinOp::verify() { return verifyGroupOp(*this); }
329 
330 LogicalResult GroupFMaxOp::verify() { return verifyGroupOp(*this); }
331 
332 LogicalResult GroupUMaxOp::verify() { return verifyGroupOp(*this); }
333 
334 LogicalResult GroupSMaxOp::verify() { return verifyGroupOp(*this); }
335 
336 LogicalResult GroupIMulKHROp::verify() { return verifyGroupOp(*this); }
337 
338 LogicalResult GroupFMulKHROp::verify() { return verifyGroupOp(*this); }
339 
340 } // namespace mlir::spirv
static MLIRContext * getContext(OpFoldResult val)
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:832
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
unsigned getNumOperands()
Definition: Operation.h:346
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp)
Definition: GroupOps.cpp:24
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op)
Definition: GroupOps.cpp:143
static LogicalResult verifyGroupOp(Op op)
Definition: GroupOps.cpp:312
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:50
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context)
Returns the default target environment: SPIR-V 1.0 with Shader capability and no extra extensions.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425