MLIR  18.0.0git
AtomicOps.cpp
Go to the documentation of this file.
1 //===- AtomicOps.cpp - MLIR SPIR-V Atomic Ops ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the atomic operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
18 using namespace mlir::spirv::AttrNames;
19 
20 namespace mlir::spirv {
21 
22 // Parses an atomic update op. If the update op does not take a value (like
23 // AtomicIIncrement) `hasValue` must be false.
25  OperationState &state, bool hasValue) {
26  spirv::Scope scope;
27  spirv::MemorySemantics memoryScope;
29  OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
30  Type type;
31  SMLoc loc;
32  if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
34  parseEnumStrAttr<spirv::MemorySemanticsAttr>(memoryScope, parser, state,
36  parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
37  parser.getCurrentLocation(&loc) || parser.parseColonType(type))
38  return failure();
39 
40  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
41  if (!ptrType)
42  return parser.emitError(loc, "expected pointer type");
43 
44  SmallVector<Type, 2> operandTypes;
45  operandTypes.push_back(ptrType);
46  if (hasValue)
47  operandTypes.push_back(ptrType.getPointeeType());
48  if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
49  state.operands))
50  return failure();
51  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
52 }
53 
54 // Prints an atomic update op.
55 static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
56  printer << " \"";
57  auto scopeAttr = op->getAttrOfType<spirv::ScopeAttr>(kMemoryScopeAttrName);
58  printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \"";
59  auto memorySemanticsAttr =
60  op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName);
61  printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue())
62  << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
63 }
64 
65 template <typename T>
66 static StringRef stringifyTypeName();
67 
68 template <>
70  return "integer";
71 }
72 
73 template <>
75  return "float";
76 }
77 
78 // Verifies an atomic update op.
79 template <typename ExpectedElementType>
81  auto ptrType = llvm::cast<spirv::PointerType>(op->getOperand(0).getType());
82  auto elementType = ptrType.getPointeeType();
83  if (!llvm::isa<ExpectedElementType>(elementType))
84  return op->emitOpError() << "pointer operand must point to an "
85  << stringifyTypeName<ExpectedElementType>()
86  << " value, found " << elementType;
87 
88  if (op->getNumOperands() > 1) {
89  auto valueType = op->getOperand(1).getType();
90  if (valueType != elementType)
91  return op->emitOpError("expected value to have the same type as the "
92  "pointer operand's pointee type ")
93  << elementType << ", but found " << valueType;
94  }
95  auto memorySemantics =
96  op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
97  .getValue();
98  if (failed(verifyMemorySemantics(op, memorySemantics))) {
99  return failure();
100  }
101  return success();
102 }
103 
104 template <typename T>
105 static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
106  printer << " \"" << stringifyScope(atomOp.getMemoryScope()) << "\" \""
107  << stringifyMemorySemantics(atomOp.getEqualSemantics()) << "\" \""
108  << stringifyMemorySemantics(atomOp.getUnequalSemantics()) << "\" "
109  << atomOp.getOperands() << " : " << atomOp.getPointer().getType();
110 }
111 
113  OperationState &state) {
114  spirv::Scope memoryScope;
115  spirv::MemorySemantics equalSemantics, unequalSemantics;
117  Type type;
118  if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
120  parseEnumStrAttr<spirv::MemorySemanticsAttr>(
121  equalSemantics, parser, state, kEqualSemanticsAttrName) ||
122  parseEnumStrAttr<spirv::MemorySemanticsAttr>(
123  unequalSemantics, parser, state, kUnequalSemanticsAttrName) ||
124  parser.parseOperandList(operandInfo, 3))
125  return failure();
126 
127  auto loc = parser.getCurrentLocation();
128  if (parser.parseColonType(type))
129  return failure();
130 
131  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
132  if (!ptrType)
133  return parser.emitError(loc, "expected pointer type");
134 
135  if (parser.resolveOperands(
136  operandInfo,
137  {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
138  parser.getNameLoc(), state.operands))
139  return failure();
140 
141  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
142 }
143 
144 template <typename T>
146  // According to the spec:
147  // "The type of Value must be the same as Result Type. The type of the value
148  // pointed to by Pointer must be the same as Result Type. This type must also
149  // match the type of Comparator."
150  if (atomOp.getType() != atomOp.getValue().getType())
151  return atomOp.emitOpError("value operand must have the same type as the op "
152  "result, but found ")
153  << atomOp.getValue().getType() << " vs " << atomOp.getType();
154 
155  if (atomOp.getType() != atomOp.getComparator().getType())
156  return atomOp.emitOpError(
157  "comparator operand must have the same type as the op "
158  "result, but found ")
159  << atomOp.getComparator().getType() << " vs " << atomOp.getType();
160 
161  Type pointeeType =
162  llvm::cast<spirv::PointerType>(atomOp.getPointer().getType())
163  .getPointeeType();
164  if (atomOp.getType() != pointeeType)
165  return atomOp.emitOpError(
166  "pointer operand's pointee type must have the same "
167  "as the op result type, but found ")
168  << pointeeType << " vs " << atomOp.getType();
169 
170  // TODO: Unequal cannot be set to Release or Acquire and Release.
171  // In addition, Unequal cannot be set to a stronger memory-order then Equal.
172 
173  return success();
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // spirv.AtomicAndOp
178 //===----------------------------------------------------------------------===//
179 
181  return verifyAtomicUpdateOp<IntegerType>(getOperation());
182 }
183 
184 ParseResult AtomicAndOp::parse(OpAsmParser &parser, OperationState &result) {
185  return parseAtomicUpdateOp(parser, result, true);
186 }
187 
188 void AtomicAndOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
189 
190 //===----------------------------------------------------------------------===//
191 // spirv.AtomicCompareExchangeOp
192 //===----------------------------------------------------------------------===//
193 
194 LogicalResult AtomicCompareExchangeOp::verify() {
195  return verifyAtomicCompareExchangeImpl(*this);
196 }
197 
198 ParseResult AtomicCompareExchangeOp::parse(OpAsmParser &parser,
199  OperationState &result) {
200  return parseAtomicCompareExchangeImpl(parser, result);
201 }
202 
203 void AtomicCompareExchangeOp::print(OpAsmPrinter &p) {
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // spirv.AtomicCompareExchangeWeakOp
209 //===----------------------------------------------------------------------===//
210 
211 LogicalResult AtomicCompareExchangeWeakOp::verify() {
212  return verifyAtomicCompareExchangeImpl(*this);
213 }
214 
215 ParseResult AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser,
216  OperationState &result) {
217  return parseAtomicCompareExchangeImpl(parser, result);
218 }
219 
220 void AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) {
222 }
223 
224 //===----------------------------------------------------------------------===//
225 // spirv.AtomicExchange
226 //===----------------------------------------------------------------------===//
227 
228 void AtomicExchangeOp::print(OpAsmPrinter &printer) {
229  printer << " \"" << stringifyScope(getMemoryScope()) << "\" \""
230  << stringifyMemorySemantics(getSemantics()) << "\" " << getOperands()
231  << " : " << getPointer().getType();
232 }
233 
234 ParseResult AtomicExchangeOp::parse(OpAsmParser &parser,
235  OperationState &result) {
236  spirv::Scope memoryScope;
237  spirv::MemorySemantics semantics;
238  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
239  Type type;
240  if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, result,
242  parseEnumStrAttr<spirv::MemorySemanticsAttr>(semantics, parser, result,
244  parser.parseOperandList(operandInfo, 2))
245  return failure();
246 
247  auto loc = parser.getCurrentLocation();
248  if (parser.parseColonType(type))
249  return failure();
250 
251  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
252  if (!ptrType)
253  return parser.emitError(loc, "expected pointer type");
254 
255  if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()},
256  parser.getNameLoc(), result.operands))
257  return failure();
258 
259  return parser.addTypeToList(ptrType.getPointeeType(), result.types);
260 }
261 
262 LogicalResult AtomicExchangeOp::verify() {
263  if (getType() != getValue().getType())
264  return emitOpError("value operand must have the same type as the op "
265  "result, but found ")
266  << getValue().getType() << " vs " << getType();
267 
268  Type pointeeType =
269  llvm::cast<spirv::PointerType>(getPointer().getType()).getPointeeType();
270  if (getType() != pointeeType)
271  return emitOpError("pointer operand's pointee type must have the same "
272  "as the op result type, but found ")
273  << pointeeType << " vs " << getType();
274 
275  return success();
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // spirv.AtomicIAddOp
280 //===----------------------------------------------------------------------===//
281 
282 LogicalResult AtomicIAddOp::verify() {
283  return verifyAtomicUpdateOp<IntegerType>(getOperation());
284 }
285 
286 ParseResult AtomicIAddOp::parse(OpAsmParser &parser, OperationState &result) {
287  return parseAtomicUpdateOp(parser, result, true);
288 }
289 
290 void AtomicIAddOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
291 
292 //===----------------------------------------------------------------------===//
293 // spirv.EXT.AtomicFAddOp
294 //===----------------------------------------------------------------------===//
295 
296 LogicalResult EXTAtomicFAddOp::verify() {
297  return verifyAtomicUpdateOp<FloatType>(getOperation());
298 }
299 
300 ParseResult EXTAtomicFAddOp::parse(OpAsmParser &parser,
301  OperationState &result) {
302  return parseAtomicUpdateOp(parser, result, true);
303 }
304 
305 void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) {
306  printAtomicUpdateOp(*this, p);
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // spirv.AtomicIDecrementOp
311 //===----------------------------------------------------------------------===//
312 
313 LogicalResult AtomicIDecrementOp::verify() {
314  return verifyAtomicUpdateOp<IntegerType>(getOperation());
315 }
316 
317 ParseResult AtomicIDecrementOp::parse(OpAsmParser &parser,
318  OperationState &result) {
319  return parseAtomicUpdateOp(parser, result, false);
320 }
321 
322 void AtomicIDecrementOp::print(OpAsmPrinter &p) {
323  printAtomicUpdateOp(*this, p);
324 }
325 
326 //===----------------------------------------------------------------------===//
327 // spirv.AtomicIIncrementOp
328 //===----------------------------------------------------------------------===//
329 
330 LogicalResult AtomicIIncrementOp::verify() {
331  return verifyAtomicUpdateOp<IntegerType>(getOperation());
332 }
333 
334 ParseResult AtomicIIncrementOp::parse(OpAsmParser &parser,
335  OperationState &result) {
336  return parseAtomicUpdateOp(parser, result, false);
337 }
338 
339 void AtomicIIncrementOp::print(OpAsmPrinter &p) {
340  printAtomicUpdateOp(*this, p);
341 }
342 
343 //===----------------------------------------------------------------------===//
344 // spirv.AtomicISubOp
345 //===----------------------------------------------------------------------===//
346 
347 LogicalResult AtomicISubOp::verify() {
348  return verifyAtomicUpdateOp<IntegerType>(getOperation());
349 }
350 
351 ParseResult AtomicISubOp::parse(OpAsmParser &parser, OperationState &result) {
352  return parseAtomicUpdateOp(parser, result, true);
353 }
354 
355 void AtomicISubOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
356 
357 //===----------------------------------------------------------------------===//
358 // spirv.AtomicOrOp
359 //===----------------------------------------------------------------------===//
360 
361 LogicalResult AtomicOrOp::verify() {
362  return verifyAtomicUpdateOp<IntegerType>(getOperation());
363 }
364 
365 ParseResult AtomicOrOp::parse(OpAsmParser &parser, OperationState &result) {
366  return parseAtomicUpdateOp(parser, result, true);
367 }
368 
369 void AtomicOrOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
370 
371 //===----------------------------------------------------------------------===//
372 // spirv.AtomicSMaxOp
373 //===----------------------------------------------------------------------===//
374 
375 LogicalResult AtomicSMaxOp::verify() {
376  return verifyAtomicUpdateOp<IntegerType>(getOperation());
377 }
378 
379 ParseResult AtomicSMaxOp::parse(OpAsmParser &parser, OperationState &result) {
380  return parseAtomicUpdateOp(parser, result, true);
381 }
382 
383 void AtomicSMaxOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
384 
385 //===----------------------------------------------------------------------===//
386 // spirv.AtomicSMinOp
387 //===----------------------------------------------------------------------===//
388 
389 LogicalResult AtomicSMinOp::verify() {
390  return verifyAtomicUpdateOp<IntegerType>(getOperation());
391 }
392 
393 ParseResult AtomicSMinOp::parse(OpAsmParser &parser, OperationState &result) {
394  return parseAtomicUpdateOp(parser, result, true);
395 }
396 
397 void AtomicSMinOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
398 
399 //===----------------------------------------------------------------------===//
400 // spirv.AtomicUMaxOp
401 //===----------------------------------------------------------------------===//
402 
403 LogicalResult AtomicUMaxOp::verify() {
404  return verifyAtomicUpdateOp<IntegerType>(getOperation());
405 }
406 
407 ParseResult AtomicUMaxOp::parse(OpAsmParser &parser, OperationState &result) {
408  return parseAtomicUpdateOp(parser, result, true);
409 }
410 
411 void AtomicUMaxOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
412 
413 //===----------------------------------------------------------------------===//
414 // spirv.AtomicUMinOp
415 //===----------------------------------------------------------------------===//
416 
417 LogicalResult AtomicUMinOp::verify() {
418  return verifyAtomicUpdateOp<IntegerType>(getOperation());
419 }
420 
421 ParseResult AtomicUMinOp::parse(OpAsmParser &parser, OperationState &result) {
422  return parseAtomicUpdateOp(parser, result, true);
423 }
424 
425 void AtomicUMinOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
426 
427 //===----------------------------------------------------------------------===//
428 // spirv.AtomicXorOp
429 //===----------------------------------------------------------------------===//
430 
431 LogicalResult AtomicXorOp::verify() {
432  return verifyAtomicUpdateOp<IntegerType>(getOperation());
433 }
434 
435 ParseResult AtomicXorOp::parse(OpAsmParser &parser, OperationState &result) {
436  return parseAtomicUpdateOp(parser, result, true);
437 }
438 
439 void AtomicXorOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
440 
441 } // namespace mlir::spirv
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents success/failure for parsing-like operations that find it important to chain tog...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
@ Type
An inlay hint that for a type annotation.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:19
constexpr char kEqualSemanticsAttrName[]
constexpr char kMemoryScopeAttrName[]
constexpr char kSemanticsAttrName[]
constexpr char kUnequalSemanticsAttrName[]
static StringRef stringifyTypeName()
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition: SPIRVOps.cpp:71
static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp)
Definition: AtomicOps.cpp:145
static LogicalResult verifyAtomicUpdateOp(Operation *op)
Definition: AtomicOps.cpp:80
StringRef stringifyTypeName< FloatType >()
Definition: AtomicOps.cpp:74
static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, OperationState &state, bool hasValue)
Definition: AtomicOps.cpp:24
StringRef stringifyTypeName< IntegerType >()
Definition: AtomicOps.cpp:69
static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer)
Definition: AtomicOps.cpp:105
static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser, OperationState &state)
Definition: AtomicOps.cpp:112
static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer)
Definition: AtomicOps.cpp:55
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.