MLIR 23.0.0git
GroupOps.cpp
Go to the documentation of this file.
1//===- GroupOps.cpp - MLIR SPIR-V Group Ops ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Defines the group operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
15
16#include "SPIRVOpUtils.h"
17#include "SPIRVParsingUtils.h"
18
19using namespace mlir::spirv::AttrNames;
20
21namespace mlir::spirv {
22
23template <typename OpTy>
24static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
25 GroupOperation operation =
26 groupOp
27 ->getAttrOfType<GroupOperationAttr>(
28 OpTy::getGroupOperationAttrName(groupOp->getName()))
29 .getValue();
30 if (operation == GroupOperation::ClusteredReduce &&
31 groupOp->getNumOperands() == 1)
32 return groupOp->emitOpError("cluster size operand must be provided for "
33 "'ClusteredReduce' group operation");
34 if (groupOp->getNumOperands() > 1) {
35 Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
36 int32_t clusterSize = 0;
37
38 // TODO: support specialization constant here.
39 if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
40 return groupOp->emitOpError(
41 "cluster size operand must come from a constant op");
42
43 if (!llvm::isPowerOf2_32(clusterSize))
44 return groupOp->emitOpError(
45 "cluster size operand must be a power of two");
46 }
47 return success();
48}
49
50//===----------------------------------------------------------------------===//
51// spirv.GroupBroadcast
52//===----------------------------------------------------------------------===//
53
54LogicalResult GroupBroadcastOp::verify() {
55 if (auto localIdTy = dyn_cast<VectorType>(getLocalid().getType()))
56 if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
57 return emitOpError("localid is a vector and can be with only "
58 " 2 or 3 components, actual number is ")
59 << localIdTy.getNumElements();
60
61 return success();
62}
63
64//===----------------------------------------------------------------------===//
65// spirv.GroupNonUniformBroadcast
66//===----------------------------------------------------------------------===//
67
68LogicalResult GroupNonUniformBroadcastOp::verify() {
69 // SPIR-V spec: "Before version 1.5, Id must come from a
70 // constant instruction.
71 auto targetEnv = spirv::getDefaultTargetEnv(getContext());
72 if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
73 targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
74
75 if (targetEnv.getVersion() < spirv::Version::V_1_5) {
76 auto *idOp = getId().getDefiningOp();
77 if (!idOp || !isa<spirv::ConstantOp, // for normal constant
78 spirv::ReferenceOfOp>(idOp)) // for spec constant
79 return emitOpError("id must be the result of a constant op");
80 }
81
82 return success();
83}
84
85//===----------------------------------------------------------------------===//
86// spirv.GroupNonUniformShuffle*
87//===----------------------------------------------------------------------===//
88
89template <typename OpTy>
90static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) {
91 if (op.getOperands().back().getType().isSignedInteger())
92 return op.emitOpError("second operand must be a singless/unsigned integer");
93
94 return success();
95}
96
97LogicalResult GroupNonUniformShuffleOp::verify() {
99}
100LogicalResult GroupNonUniformShuffleDownOp::verify() {
101 return verifyGroupNonUniformShuffleOp(*this);
102}
103LogicalResult GroupNonUniformShuffleUpOp::verify() {
104 return verifyGroupNonUniformShuffleOp(*this);
105}
106LogicalResult GroupNonUniformShuffleXorOp::verify() {
107 return verifyGroupNonUniformShuffleOp(*this);
108}
109
110//===----------------------------------------------------------------------===//
111// spirv.GroupNonUniformFAddOp
112//===----------------------------------------------------------------------===//
113
114LogicalResult GroupNonUniformFAddOp::verify() {
116}
117
118//===----------------------------------------------------------------------===//
119// spirv.GroupNonUniformFMaxOp
120//===----------------------------------------------------------------------===//
121
122LogicalResult GroupNonUniformFMaxOp::verify() {
124}
125
126//===----------------------------------------------------------------------===//
127// spirv.GroupNonUniformFMinOp
128//===----------------------------------------------------------------------===//
129
130LogicalResult GroupNonUniformFMinOp::verify() {
132}
133
134//===----------------------------------------------------------------------===//
135// spirv.GroupNonUniformFMulOp
136//===----------------------------------------------------------------------===//
137
138LogicalResult GroupNonUniformFMulOp::verify() {
140}
141
142//===----------------------------------------------------------------------===//
143// spirv.GroupNonUniformIAddOp
144//===----------------------------------------------------------------------===//
145
146LogicalResult GroupNonUniformIAddOp::verify() {
148}
149
150//===----------------------------------------------------------------------===//
151// spirv.GroupNonUniformIMulOp
152//===----------------------------------------------------------------------===//
153
154LogicalResult GroupNonUniformIMulOp::verify() {
156}
157
158//===----------------------------------------------------------------------===//
159// spirv.GroupNonUniformSMaxOp
160//===----------------------------------------------------------------------===//
161
162LogicalResult GroupNonUniformSMaxOp::verify() {
164}
165
166//===----------------------------------------------------------------------===//
167// spirv.GroupNonUniformSMinOp
168//===----------------------------------------------------------------------===//
169
170LogicalResult GroupNonUniformSMinOp::verify() {
172}
173
174//===----------------------------------------------------------------------===//
175// spirv.GroupNonUniformUMaxOp
176//===----------------------------------------------------------------------===//
177
178LogicalResult GroupNonUniformUMaxOp::verify() {
180}
181
182//===----------------------------------------------------------------------===//
183// spirv.GroupNonUniformUMinOp
184//===----------------------------------------------------------------------===//
185
186LogicalResult GroupNonUniformUMinOp::verify() {
188}
189
190//===----------------------------------------------------------------------===//
191// spirv.GroupNonUniformBitwiseAnd
192//===----------------------------------------------------------------------===//
193
194LogicalResult GroupNonUniformBitwiseAndOp::verify() {
196}
197
198//===----------------------------------------------------------------------===//
199// spirv.GroupNonUniformBitwiseOr
200//===----------------------------------------------------------------------===//
201
202LogicalResult GroupNonUniformBitwiseOrOp::verify() {
204}
205
206//===----------------------------------------------------------------------===//
207// spirv.GroupNonUniformBitwiseXor
208//===----------------------------------------------------------------------===//
209
210LogicalResult GroupNonUniformBitwiseXorOp::verify() {
212}
213
214//===----------------------------------------------------------------------===//
215// spirv.GroupNonUniformLogicalAnd
216//===----------------------------------------------------------------------===//
217
218LogicalResult GroupNonUniformLogicalAndOp::verify() {
220}
221
222//===----------------------------------------------------------------------===//
223// spirv.GroupNonUniformLogicalOr
224//===----------------------------------------------------------------------===//
225
226LogicalResult GroupNonUniformLogicalOrOp::verify() {
228}
229
230//===----------------------------------------------------------------------===//
231// spirv.GroupNonUniformLogicalXor
232//===----------------------------------------------------------------------===//
233
234LogicalResult GroupNonUniformLogicalXorOp::verify() {
236}
237
238//===----------------------------------------------------------------------===//
239// spirv.GroupNonUniformRotateKHR
240//===----------------------------------------------------------------------===//
241
242LogicalResult GroupNonUniformRotateKHROp::verify() {
243 if (Value clusterSizeVal = getClusterSize()) {
244 mlir::Operation *defOp = clusterSizeVal.getDefiningOp();
245 int32_t clusterSize = 0;
246
247 if (failed(extractValueFromConstOp(defOp, clusterSize)))
248 return emitOpError("cluster size operand must come from a constant op");
249
250 if (!llvm::isPowerOf2_32(clusterSize))
251 return emitOpError("cluster size operand must be a power of two");
252 }
253
254 return success();
255}
256
257} // namespace mlir::spirv
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
b getContext())
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:376
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:576
unsigned getNumOperands()
Definition Operation.h:372
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp)
Definition GroupOps.cpp:24
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op)
Definition GroupOps.cpp:90
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition SPIRVOps.cpp:49
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context)
Returns the default target environment: SPIR-V 1.0 with Shader capability and no extra extensions.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307