MLIR  18.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/IR/Attributes.h"
22 
23 #include "llvm/ADT/BitVector.h"
24 #include "llvm/ADT/STLForwardCompat.h"
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Frontend/OpenMP/OMPConstants.h"
30 #include <cstddef>
31 #include <optional>
32 
33 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
34 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
35 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
36 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
37 
38 using namespace mlir;
39 using namespace mlir::omp;
40 
41 namespace {
42 struct MemRefPointerLikeModel
43  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
44  MemRefType> {
45  Type getElementType(Type pointer) const {
46  return llvm::cast<MemRefType>(pointer).getElementType();
47  }
48 };
49 
50 struct LLVMPointerPointerLikeModel
51  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
52  LLVM::LLVMPointerType> {
53  Type getElementType(Type pointer) const { return Type(); }
54 };
55 
/// Fold interface for the OpenMP dialect. It only customises the decision of
/// whether a region is one that folded constants should be materialized into.
// NOTE(review): this listing jumps from inner line 56 to 58, so one source
// line of the struct is not visible here - check it against the real file
// before editing this definition.
56 struct OpenMPDialectFoldInterface : public DialectFoldInterface {
58 
   /// Returns true exactly when the operation that owns `region` is an
   /// `omp` TargetOp.
59  bool shouldMaterializeInto(Region *region) const final {
60  // Avoid folding constants across target regions
61  return isa<TargetOp>(region->getParentOp());
62  }
63 };
64 } // namespace
65 
66 void OpenMPDialect::initialize() {
67  addOperations<
68 #define GET_OP_LIST
69 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
70  >();
71  addAttributes<
72 #define GET_ATTRDEF_LIST
73 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
74  >();
75  addTypes<
76 #define GET_TYPEDEF_LIST
77 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
78  >();
79 
80  addInterface<OpenMPDialectFoldInterface>();
81  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
82  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
83  *getContext());
84 
85  // Attach the default offload module interface to the module op so that
86  // offload functionality can be accessed through it.
87  mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
88  *getContext());
89 
90  // Attach default declare target interfaces to operations which can be marked
91  // as declare target (Global Operations and Functions/Subroutines in dialects
92  // that Fortran (or other languages that lower to MLIR) translates to).
93  mlir::LLVM::GlobalOp::attachInterface<
95  *getContext());
96  mlir::LLVM::LLVMFuncOp::attachInterface<
98  *getContext());
99  mlir::func::FuncOp::attachInterface<
101 
102  // Attach default early outlining interface to func ops.
103  mlir::func::FuncOp::attachInterface<
105  mlir::LLVM::LLVMFuncOp::attachInterface<
107  *getContext());
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // Parser and printer for Allocate Clause
112 //===----------------------------------------------------------------------===//
113 
114 /// Parse an allocate clause with allocators and a list of operands with types.
115 ///
116 /// allocate-operand-list ::= allocate-operand |
117 /// allocate-operand `,` allocate-operand-list
118 /// allocate-operand ::= ssa-id-and-type `->` ssa-id-and-type
119 /// ssa-id-and-type ::= ssa-id `:` type
121  OpAsmParser &parser,
123  SmallVectorImpl<Type> &typesAllocate,
125  SmallVectorImpl<Type> &typesAllocator) {
126 
127  return parser.parseCommaSeparatedList([&]() {
129  Type type;
130  if (parser.parseOperand(operand) || parser.parseColonType(type))
131  return failure();
132  operandsAllocator.push_back(operand);
133  typesAllocator.push_back(type);
134  if (parser.parseArrow())
135  return failure();
136  if (parser.parseOperand(operand) || parser.parseColonType(type))
137  return failure();
138 
139  operandsAllocate.push_back(operand);
140  typesAllocate.push_back(type);
141  return success();
142  });
143 }
144 
145 /// Print allocate clause
147  OperandRange varsAllocate,
148  TypeRange typesAllocate,
149  OperandRange varsAllocator,
150  TypeRange typesAllocator) {
151  for (unsigned i = 0; i < varsAllocate.size(); ++i) {
152  std::string separator = i == varsAllocate.size() - 1 ? "" : ", ";
153  p << varsAllocator[i] << " : " << typesAllocator[i] << " -> ";
154  p << varsAllocate[i] << " : " << typesAllocate[i] << separator;
155  }
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // Parser and printer for a clause attribute (StringEnumAttr)
160 //===----------------------------------------------------------------------===//
161 
162 template <typename ClauseAttr>
163 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
164  using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
165  StringRef enumStr;
166  SMLoc loc = parser.getCurrentLocation();
167  if (parser.parseKeyword(&enumStr))
168  return failure();
169  if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
170  attr = ClauseAttr::get(parser.getContext(), *enumValue);
171  return success();
172  }
173  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
174 }
175 
176 template <typename ClauseAttr>
177 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
178  p << stringifyEnum(attr.getValue());
179 }
180 
181 //===----------------------------------------------------------------------===//
182 // Parser and printer for Linear Clause
183 //===----------------------------------------------------------------------===//
184 
185 /// linear ::= `linear` `(` linear-list `)`
186 /// linear-list := linear-val | linear-val `,` linear-list
187 /// linear-val := ssa-id `=` ssa-id `:` type
188 static ParseResult
191  SmallVectorImpl<Type> &types,
193  return parser.parseCommaSeparatedList([&]() {
195  Type type;
197  if (parser.parseOperand(var) || parser.parseEqual() ||
198  parser.parseOperand(stepVar) || parser.parseColonType(type))
199  return failure();
200 
201  vars.push_back(var);
202  types.push_back(type);
203  stepVars.push_back(stepVar);
204  return success();
205  });
206 }
207 
208 /// Print Linear Clause
/// Writes one entry per linear variable in the form `var = step : type`,
/// separated by ", ". The ` = step` part is written only while the index is
/// within `linearStepVars`. The type is read from the variable itself;
/// `linearVarTypes` is not read by this body.
// NOTE(review): the opening line of the signature (inner line 209) is not
// visible in this listing.
210  ValueRange linearVars, TypeRange linearVarTypes,
211  ValueRange linearStepVars) {
212  size_t linearVarsSize = linearVars.size();
213  for (unsigned i = 0; i < linearVarsSize; ++i) {
     // No separator after the final entry.
214  std::string separator = i == linearVarsSize - 1 ? "" : ", ";
215  p << linearVars[i];
216  if (linearStepVars.size() > i)
217  p << " = " << linearStepVars[i];
218  p << " : " << linearVars[i].getType() << separator;
219  }
220 }
221 
222 //===----------------------------------------------------------------------===//
223 // Verifier for Nontemporal Clause
224 //===----------------------------------------------------------------------===//
225 
226 static LogicalResult
227 verifyNontemporalClause(Operation *op, OperandRange nontemporalVariables) {
228 
229  // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
230  DenseSet<Value> nontemporalItems;
231  for (const auto &it : nontemporalVariables)
232  if (!nontemporalItems.insert(it).second)
233  return op->emitOpError() << "nontemporal variable used more than once";
234 
235  return success();
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // Parser, verifier and printer for Aligned Clause
240 //===----------------------------------------------------------------------===//
241 static LogicalResult
242 verifyAlignedClause(Operation *op, std::optional<ArrayAttr> alignmentValues,
243  OperandRange alignedVariables) {
244  // Check if number of alignment values equals to number of aligned variables
245  if (!alignedVariables.empty()) {
246  if (!alignmentValues || alignmentValues->size() != alignedVariables.size())
247  return op->emitOpError()
248  << "expected as many alignment values as aligned variables";
249  } else {
250  if (alignmentValues)
251  return op->emitOpError() << "unexpected alignment values attribute";
252  return success();
253  }
254 
255  // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
256  DenseSet<Value> alignedItems;
257  for (auto it : alignedVariables)
258  if (!alignedItems.insert(it).second)
259  return op->emitOpError() << "aligned variable used more than once";
260 
261  if (!alignmentValues)
262  return success();
263 
264  // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
265  for (unsigned i = 0; i < (*alignmentValues).size(); ++i) {
266  if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
267  if (intAttr.getValue().sle(0))
268  return op->emitOpError() << "alignment should be greater than 0";
269  } else {
270  return op->emitOpError() << "expected integer alignment";
271  }
272  }
273 
274  return success();
275 }
276 
277 /// aligned ::= `aligned` `(` aligned-list `)`
278 /// aligned-list := aligned-val | aligned-val `,` aligned-list
279 /// aligned-val := ssa-id-and-type `->` alignment
281  OpAsmParser &parser,
283  SmallVectorImpl<Type> &types, ArrayAttr &alignmentValues) {
284  SmallVector<Attribute> alignmentVec;
285  if (failed(parser.parseCommaSeparatedList([&]() {
286  if (parser.parseOperand(alignedItems.emplace_back()) ||
287  parser.parseColonType(types.emplace_back()) ||
288  parser.parseArrow() ||
289  parser.parseAttribute(alignmentVec.emplace_back())) {
290  return failure();
291  }
292  return success();
293  })))
294  return failure();
295  SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
296  alignmentValues = ArrayAttr::get(parser.getContext(), alignments);
297  return success();
298 }
299 
300 /// Print Aligned Clause
/// Writes one `var : type -> alignment` entry per aligned variable, separated
/// by ", ". The type is read from the variable itself; `alignedVarTypes` is
/// not read by this body.
// NOTE(review): the opening line of the signature (inner line 301) is not
// visible in this listing.
302  ValueRange alignedVars,
303  TypeRange alignedVarTypes,
304  std::optional<ArrayAttr> alignmentValues) {
305  for (unsigned i = 0; i < alignedVars.size(); ++i) {
306  if (i != 0)
307  p << ", ";
308  p << alignedVars[i] << " : " << alignedVars[i].getType();
     // NOTE(review): `alignmentValues` is dereferenced and indexed without a
     // presence or bounds check. This presumably relies on the clause
     // verifier having matched it to `alignedVars` - confirm.
309  p << " -> " << (*alignmentValues)[i];
310  }
311 }
312 
313 //===----------------------------------------------------------------------===//
314 // Parser, printer and verifier for Schedule Clause
315 //===----------------------------------------------------------------------===//
316 
317 static ParseResult
319  SmallVectorImpl<SmallString<12>> &modifiers) {
320  if (modifiers.size() > 2)
321  return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
322  for (const auto &mod : modifiers) {
323  // Translate the string. If it has no value, then it was not a valid
324  // modifier!
325  auto symbol = symbolizeScheduleModifier(mod);
326  if (!symbol)
327  return parser.emitError(parser.getNameLoc())
328  << " unknown modifier type: " << mod;
329  }
330 
331  // If we have one modifier that is "simd", then stick a "none" modifier in
332  // index 0.
333  if (modifiers.size() == 1) {
334  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
335  modifiers.push_back(modifiers[0]);
336  modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
337  }
338  } else if (modifiers.size() == 2) {
339  // If there are two modifiers:
340  // the first modifier must not be simd and the second one must be simd.
341  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
342  symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
343  return parser.emitError(parser.getNameLoc())
344  << " incorrect modifier order";
345  }
346  return success();
347 }
348 
349 /// schedule ::= `schedule` `(` sched-list `)`
350 /// sched-list ::= sched-val | sched-val sched-list |
351 /// sched-val `,` sched-modifier
352 /// sched-val ::= sched-with-chunk | sched-wo-chunk
353 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
354 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
355 /// sched-wo-chunk ::= `auto` | `runtime`
356 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
357 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
359  OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
360  ScheduleModifierAttr &scheduleModifier, UnitAttr &simdModifier,
361  std::optional<OpAsmParser::UnresolvedOperand> &chunkSize, Type &chunkType) {
362  StringRef keyword;
363  if (parser.parseKeyword(&keyword))
364  return failure();
365  std::optional<mlir::omp::ClauseScheduleKind> schedule =
366  symbolizeClauseScheduleKind(keyword);
367  if (!schedule)
368  return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
369 
370  scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
371  switch (*schedule) {
372  case ClauseScheduleKind::Static:
373  case ClauseScheduleKind::Dynamic:
374  case ClauseScheduleKind::Guided:
375  if (succeeded(parser.parseOptionalEqual())) {
376  chunkSize = OpAsmParser::UnresolvedOperand{};
377  if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
378  return failure();
379  } else {
380  chunkSize = std::nullopt;
381  }
382  break;
383  case ClauseScheduleKind::Auto:
385  chunkSize = std::nullopt;
386  }
387 
388  // If there is a comma, we have one or more modifiers..
389  SmallVector<SmallString<12>> modifiers;
390  while (succeeded(parser.parseOptionalComma())) {
391  StringRef mod;
392  if (parser.parseKeyword(&mod))
393  return failure();
394  modifiers.push_back(mod);
395  }
396 
397  if (verifyScheduleModifiers(parser, modifiers))
398  return failure();
399 
400  if (!modifiers.empty()) {
401  SMLoc loc = parser.getCurrentLocation();
402  if (std::optional<ScheduleModifier> mod =
403  symbolizeScheduleModifier(modifiers[0])) {
404  scheduleModifier = ScheduleModifierAttr::get(parser.getContext(), *mod);
405  } else {
406  return parser.emitError(loc, "invalid schedule modifier");
407  }
408  // Only SIMD attribute is allowed here!
409  if (modifiers.size() > 1) {
410  assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
411  simdModifier = UnitAttr::get(parser.getBuilder().getContext());
412  }
413  }
414 
415  return success();
416 }
417 
418 /// Print schedule clause
420  ClauseScheduleKindAttr schedAttr,
421  ScheduleModifierAttr modifier, UnitAttr simd,
422  Value scheduleChunkVar,
423  Type scheduleChunkType) {
424  p << stringifyClauseScheduleKind(schedAttr.getValue());
425  if (scheduleChunkVar)
426  p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType();
427  if (modifier)
428  p << ", " << stringifyScheduleModifier(modifier.getValue());
429  if (simd)
430  p << ", simd";
431 }
432 
433 //===----------------------------------------------------------------------===//
434 // Parser, printer and verifier for ReductionVarList
435 //===----------------------------------------------------------------------===//
436 
437 /// reduction-entry-list ::= reduction-entry
438 /// | reduction-entry-list `,` reduction-entry
439 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
440 static ParseResult
443  SmallVectorImpl<Type> &types,
444  ArrayAttr &redcuctionSymbols) {
445  SmallVector<SymbolRefAttr> reductionVec;
446  if (failed(parser.parseCommaSeparatedList([&]() {
447  if (parser.parseAttribute(reductionVec.emplace_back()) ||
448  parser.parseArrow() ||
449  parser.parseOperand(operands.emplace_back()) ||
450  parser.parseColonType(types.emplace_back()))
451  return failure();
452  return success();
453  })))
454  return failure();
455  SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
456  redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
457  return success();
458 }
459 
460 /// Print Reduction clause
/// Writes one `symbol -> var : type` entry per element of `reductions`,
/// separated by ", ". The type is read from the variable itself;
/// `reductionTypes` is not read by this body.
// NOTE(review): the opening line of the signature (inner line 461) is not
// visible in this listing.
462  OperandRange reductionVars,
463  TypeRange reductionTypes,
464  std::optional<ArrayAttr> reductions) {
     // NOTE(review): `reductions` is dereferenced without a presence check, and
     // its size (not the size of `reductionVars`) bounds the loop. This
     // presumably relies on the verifier keeping both in sync - confirm.
465  for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
466  if (i != 0)
467  p << ", ";
468  p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
469  << reductionVars[i].getType();
470  }
471 }
472 
473 /// Verifies Reduction Clause
475  std::optional<ArrayAttr> reductions,
476  OperandRange reductionVars) {
477  if (!reductionVars.empty()) {
478  if (!reductions || reductions->size() != reductionVars.size())
479  return op->emitOpError()
480  << "expected as many reduction symbol references "
481  "as reduction variables";
482  } else {
483  if (reductions)
484  return op->emitOpError() << "unexpected reduction symbol references";
485  return success();
486  }
487 
488  // TODO: The followings should be done in
489  // SymbolUserOpInterface::verifySymbolUses.
490  DenseSet<Value> accumulators;
491  for (auto args : llvm::zip(reductionVars, *reductions)) {
492  Value accum = std::get<0>(args);
493 
494  if (!accumulators.insert(accum).second)
495  return op->emitOpError() << "accumulator variable used more than once";
496 
497  Type varType = accum.getType();
498  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
499  auto decl =
500  SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
501  if (!decl)
502  return op->emitOpError() << "expected symbol reference " << symbolRef
503  << " to point to a reduction declaration";
504 
505  if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
506  return op->emitOpError()
507  << "expected accumulator (" << varType
508  << ") to be the same type as reduction declaration ("
509  << decl.getAccumulatorType() << ")";
510  }
511 
512  return success();
513 }
514 
515 //===----------------------------------------------------------------------===//
516 // Parser, printer and verifier for DependVarList
517 //===----------------------------------------------------------------------===//
518 
519 /// depend-entry-list ::= depend-entry
520 /// | depend-entry-list `,` depend-entry
521 /// depend-entry ::= depend-kind `->` ssa-id `:` type
// Parses a comma-separated list of `depend-kind -> ssa-id : type` entries.
// For each entry the operand and its type are appended to `operands` and
// `types`, and the depend kind is collected as a ClauseTaskDependAttr. On
// success `dependsArray` is set to an ArrayAttr holding the collected kinds
// in order.
// NOTE(review): several signature/declaration lines (inner lines 523, 524
// and 526) are not visible in this listing, including the declaration of
// `dependVec`.
522 static ParseResult
525  SmallVectorImpl<Type> &types, ArrayAttr &dependsArray) {
527  if (failed(parser.parseCommaSeparatedList([&]() {
528  StringRef keyword;
529  if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
530  parser.parseOperand(operands.emplace_back()) ||
531  parser.parseColonType(types.emplace_back()))
532  return failure();
     // Translate the keyword into the depend-kind enum, if it names one.
533  if (std::optional<ClauseTaskDepend> keywordDepend =
534  (symbolizeClauseTaskDepend(keyword)))
535  dependVec.emplace_back(
536  ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
537  else
     // NOTE(review): this failure path emits no diagnostic, so an unknown
     // depend kind fails without a message from this function.
538  return failure();
539  return success();
540  })))
541  return failure();
542  SmallVector<Attribute> depends(dependVec.begin(), dependVec.end());
543  dependsArray = ArrayAttr::get(parser.getContext(), depends);
544  return success();
545 }
546 
547 /// Print Depend clause
549  OperandRange dependVars, TypeRange dependTypes,
550  std::optional<ArrayAttr> depends) {
551 
552  for (unsigned i = 0, e = depends->size(); i < e; ++i) {
553  if (i != 0)
554  p << ", ";
555  p << stringifyClauseTaskDepend(
556  llvm::cast<mlir::omp::ClauseTaskDependAttr>((*depends)[i])
557  .getValue())
558  << " -> " << dependVars[i] << " : " << dependTypes[i];
559  }
560 }
561 
562 /// Verifies Depend clause
564  std::optional<ArrayAttr> depends,
565  OperandRange dependVars) {
566  if (!dependVars.empty()) {
567  if (!depends || depends->size() != dependVars.size())
568  return op->emitOpError() << "expected as many depend values"
569  " as depend variables";
570  } else {
571  if (depends)
572  return op->emitOpError() << "unexpected depend values";
573  return success();
574  }
575 
576  return success();
577 }
578 
579 //===----------------------------------------------------------------------===//
580 // Parser, printer and verifier for Synchronization Hint (2.17.12)
581 //===----------------------------------------------------------------------===//
582 
583 /// Parses a Synchronization Hint clause. The value of hint is an integer
584 /// which is a combination of different hints from `omp_sync_hint_t`.
585 ///
586 /// hint-clause = `hint` `(` hint-value `)`
588  IntegerAttr &hintAttr) {
589  StringRef hintKeyword;
590  int64_t hint = 0;
591  if (succeeded(parser.parseOptionalKeyword("none"))) {
592  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
593  return success();
594  }
595  auto parseKeyword = [&]() -> ParseResult {
596  if (failed(parser.parseKeyword(&hintKeyword)))
597  return failure();
598  if (hintKeyword == "uncontended")
599  hint |= 1;
600  else if (hintKeyword == "contended")
601  hint |= 2;
602  else if (hintKeyword == "nonspeculative")
603  hint |= 4;
604  else if (hintKeyword == "speculative")
605  hint |= 8;
606  else
607  return parser.emitError(parser.getCurrentLocation())
608  << hintKeyword << " is not a valid hint";
609  return success();
610  };
611  if (parser.parseCommaSeparatedList(parseKeyword))
612  return failure();
613  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
614  return success();
615 }
616 
617 /// Prints a Synchronization Hint clause
619  IntegerAttr hintAttr) {
620  int64_t hint = hintAttr.getInt();
621 
622  if (hint == 0) {
623  p << "none";
624  return;
625  }
626 
627  // Helper function to get n-th bit from the right end of `value`
628  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
629 
630  bool uncontended = bitn(hint, 0);
631  bool contended = bitn(hint, 1);
632  bool nonspeculative = bitn(hint, 2);
633  bool speculative = bitn(hint, 3);
634 
636  if (uncontended)
637  hints.push_back("uncontended");
638  if (contended)
639  hints.push_back("contended");
640  if (nonspeculative)
641  hints.push_back("nonspeculative");
642  if (speculative)
643  hints.push_back("speculative");
644 
645  llvm::interleaveComma(hints, p);
646 }
647 
648 /// Verifies a synchronization hint clause
/// Reads bits 0-3 of `hint` as uncontended, contended, nonspeculative and
/// speculative, and rejects the combinations uncontended+contended and
/// nonspeculative+speculative with an op error.
// NOTE(review): the signature line (inner line 649) is not visible in this
// listing, so the declared type of `hint` cannot be checked from here.
650 
651  // Helper function to get n-th bit from the right end of `value`
     // NOTE(review): `bitn` takes `int`; if `hint` is a wider integer it is
     // narrowed at each call. Only bits 0-3 are read here - confirm this is
     // acceptable.
652  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
653 
654  bool uncontended = bitn(hint, 0);
655  bool contended = bitn(hint, 1);
656  bool nonspeculative = bitn(hint, 2);
657  bool speculative = bitn(hint, 3);
658 
659  if (uncontended && contended)
660  return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
661  "omp_sync_hint_contended cannot be combined";
662  if (nonspeculative && speculative)
663  return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
664  "omp_sync_hint_speculative cannot be combined.";
665  return success();
666 }
667 
668 //===----------------------------------------------------------------------===//
669 // Parser, printer and verifier for Target
670 //===----------------------------------------------------------------------===//
671 
672 // Helper function to get bitwise AND of `value` and 'flag'
673 uint64_t mapTypeToBitFlag(uint64_t value,
674  llvm::omp::OpenMPOffloadMappingFlags flag) {
675  return value & llvm::to_underlying(flag);
676 }
677 
678 /// Parses a map_entries map type from a string format back into its numeric
679 /// value.
680 ///
681 /// map-clause = `map_clauses ( ( `(` `always, `? `close, `? `present, `? (
682 /// `to` | `from` | `delete` `)` )+ `)` )
683 static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
684  llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
685  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
686 
687  // This simply verifies the correct keyword is read in, the
688  // keyword itself is stored inside of the operation
689  auto parseTypeAndMod = [&]() -> ParseResult {
690  StringRef mapTypeMod;
691  if (parser.parseKeyword(&mapTypeMod))
692  return failure();
693 
694  if (mapTypeMod == "always")
695  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
696 
697  if (mapTypeMod == "implicit")
698  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
699 
700  if (mapTypeMod == "close")
701  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
702 
703  if (mapTypeMod == "present")
704  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
705 
706  if (mapTypeMod == "to")
707  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
708 
709  if (mapTypeMod == "from")
710  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
711 
712  if (mapTypeMod == "tofrom")
713  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
714  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
715 
716  if (mapTypeMod == "delete")
717  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
718 
719  return success();
720  };
721 
722  if (parser.parseCommaSeparatedList(parseTypeAndMod))
723  return failure();
724 
725  mapType = parser.getBuilder().getIntegerAttr(
726  parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
727  llvm::to_underlying(mapTypeBits));
728 
729  return success();
730 }
731 
732 /// Prints a map_entries map type from its numeric value out into its string
733 /// format.
735  IntegerAttr mapType) {
736  uint64_t mapTypeBits = mapType.getUInt();
737 
738  bool emitAllocRelease = true;
740 
741  // handling of always, close, present placed at the beginning of the string
742  // to aid readability
743  if (mapTypeToBitFlag(mapTypeBits,
744  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
745  mapTypeStrs.push_back("always");
746  if (mapTypeToBitFlag(mapTypeBits,
747  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
748  mapTypeStrs.push_back("implicit");
749  if (mapTypeToBitFlag(mapTypeBits,
750  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
751  mapTypeStrs.push_back("close");
752  if (mapTypeToBitFlag(mapTypeBits,
753  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
754  mapTypeStrs.push_back("present");
755 
756  // Special handling of to/from/tofrom/delete and release/alloc: release and
757  // alloc are the absence of all of the other flags, whereas tofrom requires
758  // both the to and from flags to be set.
759  bool to = mapTypeToBitFlag(mapTypeBits,
760  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
761  bool from = mapTypeToBitFlag(
762  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
763  if (to && from) {
764  emitAllocRelease = false;
765  mapTypeStrs.push_back("tofrom");
766  } else if (from) {
767  emitAllocRelease = false;
768  mapTypeStrs.push_back("from");
769  } else if (to) {
770  emitAllocRelease = false;
771  mapTypeStrs.push_back("to");
772  }
773  if (mapTypeToBitFlag(mapTypeBits,
774  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
775  emitAllocRelease = false;
776  mapTypeStrs.push_back("delete");
777  }
778  if (emitAllocRelease)
779  mapTypeStrs.push_back("exit_release_or_enter_alloc");
780 
781  for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
782  p << mapTypeStrs[i];
783  if (i + 1 < mapTypeStrs.size()) {
784  p << ", ";
785  }
786  }
787 }
788 
789 static ParseResult
792  SmallVectorImpl<Type> &mapOperandTypes) {
795  Type argType;
796  auto parseEntries = [&]() -> ParseResult {
797  if (parser.parseOperand(arg) || parser.parseArrow() ||
798  parser.parseOperand(blockArg))
799  return failure();
800  mapOperands.push_back(arg);
801  return success();
802  };
803 
804  auto parseTypes = [&]() -> ParseResult {
805  if (parser.parseType(argType))
806  return failure();
807  mapOperandTypes.push_back(argType);
808  return success();
809  };
810 
811  if (parser.parseCommaSeparatedList(parseEntries))
812  return failure();
813 
814  if (parser.parseColon())
815  return failure();
816 
817  if (parser.parseCommaSeparatedList(parseTypes))
818  return failure();
819 
820  return success();
821 }
822 
824  OperandRange mapOperands,
825  TypeRange mapOperandTypes) {
826  auto &region = op->getRegion(0);
827  unsigned argIndex = 0;
828 
829  for (const auto &mapOp : mapOperands) {
830  const auto &blockArg = region.front().getArgument(argIndex);
831  p << mapOp << " -> " << blockArg;
832  argIndex++;
833  if (argIndex < mapOperands.size())
834  p << ", ";
835  }
836  p << " : ";
837 
838  argIndex = 0;
839  for (const auto &mapType : mapOperandTypes) {
840  p << mapType;
841  argIndex++;
842  if (argIndex < mapOperands.size())
843  p << ", ";
844  }
845 }
846 
848  VariableCaptureKindAttr mapCaptureType) {
849  std::string typeCapStr;
850  llvm::raw_string_ostream typeCap(typeCapStr);
851  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
852  typeCap << "ByRef";
853  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
854  typeCap << "ByCopy";
855  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
856  typeCap << "VLAType";
857  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
858  typeCap << "This";
859  p << typeCap.str();
860 }
861 
863  VariableCaptureKindAttr &mapCapture) {
864  StringRef mapCaptureKey;
865  if (parser.parseKeyword(&mapCaptureKey))
866  return failure();
867 
868  if (mapCaptureKey == "This")
870  parser.getContext(), mlir::omp::VariableCaptureKind::This);
871  if (mapCaptureKey == "ByRef")
873  parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
874  if (mapCaptureKey == "ByCopy")
876  parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
877  if (mapCaptureKey == "VLAType")
879  parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
880 
881  return success();
882 }
883 
// Checks every value in `mapOperands`: each is expected to be produced by an
// omp MapInfoOp carrying a map type and a map capture type, and the to, from
// and delete bits of the map type are checked against the kind of `op`
// (DataOp/TargetOp, EnterDataOp, ExitDataOp).
// NOTE(review): the signature line (inner line 884) is not visible in this
// listing.
885 
886  for (auto mapOp : mapOperands) {
     // NOTE(review): the diagnostic below is emitted but not returned, so
     // execution continues and the null defining op reaches the dyn_cast on
     // inner line 891 - confirm whether an early `return` is intended.
887  if (!mapOp.getDefiningOp())
888  emitError(op->getLoc(), "missing map operation");
889 
890  if (auto MapInfoOp =
891  mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
892 
     // NOTE(review): both checks below emit a diagnostic without returning,
     // and `.value()` on inner line 899 is then called regardless.
893  if (!MapInfoOp.getMapType().has_value())
894  emitError(op->getLoc(), "missing map type for map operand");
895 
896  if (!MapInfoOp.getMapCaptureType().has_value())
897  emitError(op->getLoc(), "missing map capture type for map operand");
898 
899  uint64_t mapTypeBits = MapInfoOp.getMapType().value();
900 
     // Extract the individual direction bits from the numeric map type.
901  bool to = mapTypeToBitFlag(
902  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
903  bool from = mapTypeToBitFlag(
904  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
905  bool del = mapTypeToBitFlag(
906  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
907 
     // DataOp and TargetOp reject the delete bit.
908  if ((isa<DataOp>(op) || isa<TargetOp>(op)) && del)
909  return emitError(op->getLoc(),
910  "to, from, tofrom and alloc map types are permitted");
911 
     // EnterDataOp rejects the from and delete bits.
912  if (isa<EnterDataOp>(op) && (from || del))
913  return emitError(op->getLoc(), "to and alloc map types are permitted");
914 
     // ExitDataOp rejects the to bit.
915  if (isa<ExitDataOp>(op) && to)
916  return emitError(op->getLoc(),
917  "from, release and delete map types are permitted");
918  } else {
     // NOTE(review): this diagnostic is also not returned; the loop goes on
     // and the function can still reach `return success()` below.
919  emitError(op->getLoc(), "map argument is not a map entry operation");
920  }
921  }
922 
923  return success();
924 }
925 
927  if (getMapOperands().empty() && getUseDevicePtr().empty() &&
928  getUseDeviceAddr().empty()) {
929  return ::emitError(this->getLoc(), "At least one of map, useDevicePtr, or "
930  "useDeviceAddr operand must be present");
931  }
932  return verifyMapClause(*this, getMapOperands());
933 }
934 
936  return verifyMapClause(*this, getMapOperands());
937 }
938 
940  return verifyMapClause(*this, getMapOperands());
941 }
942 
944  return verifyMapClause(*this, getMapOperands());
945 }
946 
947 //===----------------------------------------------------------------------===//
948 // ParallelOp
949 //===----------------------------------------------------------------------===//
950 
951 void ParallelOp::build(OpBuilder &builder, OperationState &state,
952  ArrayRef<NamedAttribute> attributes) {
953  ParallelOp::build(
954  builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
955  /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
956  /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
957  /*proc_bind_val=*/nullptr);
958  state.addAttributes(attributes);
959 }
960 
962  if (getAllocateVars().size() != getAllocatorsVars().size())
963  return emitError(
964  "expected equal sizes for allocate and allocator variables");
965  return verifyReductionVarList(*this, getReductions(), getReductionVars());
966 }
967 
968 //===----------------------------------------------------------------------===//
969 // TeamsOp
970 //===----------------------------------------------------------------------===//
971 
973  while ((op = op->getParentOp()))
974  if (isa<OpenMPDialect>(op->getDialect()))
975  return false;
976  return true;
977 }
978 
980  // Check parent region
981  // TODO If nested inside of a target region, also check that it does not
982  // contain any statements, declarations or directives other than this
983  // omp.teams construct. The issue is how to support the initialization of
984  // this operation's own arguments (allow SSA values across omp.target?).
985  Operation *op = getOperation();
986  if (!isa<TargetOp>(op->getParentOp()) &&
988  return emitError("expected to be nested inside of omp.target or not nested "
989  "in any OpenMP dialect operations");
990 
991  // Check for num_teams clause restrictions
992  if (auto numTeamsLowerBound = getNumTeamsLower()) {
993  auto numTeamsUpperBound = getNumTeamsUpper();
994  if (!numTeamsUpperBound)
995  return emitError("expected num_teams upper bound to be defined if the "
996  "lower bound is defined");
997  if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
998  return emitError(
999  "expected num_teams upper bound and lower bound to be the same type");
1000  }
1001 
1002  // Check for allocate clause restrictions
1003  if (getAllocateVars().size() != getAllocatorsVars().size())
1004  return emitError(
1005  "expected equal sizes for allocate and allocator variables");
1006 
1007  return verifyReductionVarList(*this, getReductions(), getReductionVars());
1008 }
1009 
1010 //===----------------------------------------------------------------------===//
1011 // Verifier for SectionsOp
1012 //===----------------------------------------------------------------------===//
1013 
1015  if (getAllocateVars().size() != getAllocatorsVars().size())
1016  return emitError(
1017  "expected equal sizes for allocate and allocator variables");
1018 
1019  return verifyReductionVarList(*this, getReductions(), getReductionVars());
1020 }
1021 
1022 LogicalResult SectionsOp::verifyRegions() {
1023  for (auto &inst : *getRegion().begin()) {
1024  if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
1025  return emitOpError()
1026  << "expected omp.section op or terminator op inside region";
1027  }
1028  }
1029 
1030  return success();
1031 }
1032 
1034  // Check for allocate clause restrictions
1035  if (getAllocateVars().size() != getAllocatorsVars().size())
1036  return emitError(
1037  "expected equal sizes for allocate and allocator variables");
1038 
1039  return success();
1040 }
1041 
1042 //===----------------------------------------------------------------------===//
1043 // WsLoopOp
1044 //===----------------------------------------------------------------------===//
1045 
1046 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds
1047 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
1048 /// steps := `step` `(`ssa-id-list`)`
1054  SmallVectorImpl<Type> &loopVarTypes, UnitAttr &inclusive) {
1055  // Parse an opening `(` followed by induction variables followed by `)`
1057  Type loopVarType;
1059  parser.parseColonType(loopVarType) ||
1060  // Parse loop bounds.
1061  parser.parseEqual() ||
1062  parser.parseOperandList(lowerBound, ivs.size(),
1064  parser.parseKeyword("to") ||
1065  parser.parseOperandList(upperBound, ivs.size(),
1067  return failure();
1068 
1069  if (succeeded(parser.parseOptionalKeyword("inclusive")))
1070  inclusive = UnitAttr::get(parser.getBuilder().getContext());
1071 
1072  // Parse step values.
1073  if (parser.parseKeyword("step") ||
1074  parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
1075  return failure();
1076 
1077  // Now parse the body.
1078  loopVarTypes = SmallVector<Type>(ivs.size(), loopVarType);
1079  for (auto &iv : ivs)
1080  iv.type = loopVarType;
1081  return parser.parseRegion(region, ivs);
1082 }
1083 
1085  ValueRange lowerBound, ValueRange upperBound,
1086  ValueRange steps, TypeRange loopVarTypes,
1087  UnitAttr inclusive) {
1088  auto args = region.front().getArguments();
1089  p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound
1090  << ") to (" << upperBound << ") ";
1091  if (inclusive)
1092  p << "inclusive ";
1093  p << "step (" << steps << ") ";
1094  p.printRegion(region, /*printEntryBlockArgs=*/false);
1095 }
1096 
1097 //===----------------------------------------------------------------------===//
1098 // Verifier for Simd construct [2.9.3.1]
1099 //===----------------------------------------------------------------------===//
1100 
1102  if (this->getLowerBound().empty()) {
1103  return emitOpError() << "empty lowerbound for simd loop operation";
1104  }
1105  if (this->getSimdlen().has_value() && this->getSafelen().has_value() &&
1106  this->getSimdlen().value() > this->getSafelen().value()) {
1107  return emitOpError()
1108  << "simdlen clause and safelen clause are both present, but the "
1109  "simdlen value is not less than or equal to safelen value";
1110  }
1111  if (verifyAlignedClause(*this, this->getAlignmentValues(),
1112  this->getAlignedVars())
1113  .failed())
1114  return failure();
1115  if (verifyNontemporalClause(*this, this->getNontemporalVars()).failed())
1116  return failure();
1117  return success();
1118 }
1119 
1120 //===----------------------------------------------------------------------===//
1121 // ReductionOp
1122 //===----------------------------------------------------------------------===//
1123 
1125  Region &region) {
1126  if (parser.parseOptionalKeyword("atomic"))
1127  return success();
1128  return parser.parseRegion(region);
1129 }
1130 
1132  ReductionDeclareOp op, Region &region) {
1133  if (region.empty())
1134  return;
1135  printer << "atomic ";
1136  printer.printRegion(region);
1137 }
1138 
1139 LogicalResult ReductionDeclareOp::verifyRegions() {
1140  if (getInitializerRegion().empty())
1141  return emitOpError() << "expects non-empty initializer region";
1142  Block &initializerEntryBlock = getInitializerRegion().front();
1143  if (initializerEntryBlock.getNumArguments() != 1 ||
1144  initializerEntryBlock.getArgument(0).getType() != getType()) {
1145  return emitOpError() << "expects initializer region with one argument "
1146  "of the reduction type";
1147  }
1148 
1149  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
1150  if (yieldOp.getResults().size() != 1 ||
1151  yieldOp.getResults().getTypes()[0] != getType())
1152  return emitOpError() << "expects initializer region to yield a value "
1153  "of the reduction type";
1154  }
1155 
1156  if (getReductionRegion().empty())
1157  return emitOpError() << "expects non-empty reduction region";
1158  Block &reductionEntryBlock = getReductionRegion().front();
1159  if (reductionEntryBlock.getNumArguments() != 2 ||
1160  reductionEntryBlock.getArgumentTypes()[0] !=
1161  reductionEntryBlock.getArgumentTypes()[1] ||
1162  reductionEntryBlock.getArgumentTypes()[0] != getType())
1163  return emitOpError() << "expects reduction region with two arguments of "
1164  "the reduction type";
1165  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
1166  if (yieldOp.getResults().size() != 1 ||
1167  yieldOp.getResults().getTypes()[0] != getType())
1168  return emitOpError() << "expects reduction region to yield a value "
1169  "of the reduction type";
1170  }
1171 
1172  if (getAtomicReductionRegion().empty())
1173  return success();
1174 
1175  Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
1176  if (atomicReductionEntryBlock.getNumArguments() != 2 ||
1177  atomicReductionEntryBlock.getArgumentTypes()[0] !=
1178  atomicReductionEntryBlock.getArgumentTypes()[1])
1179  return emitOpError() << "expects atomic reduction region with two "
1180  "arguments of the same type";
1181  auto ptrType = llvm::dyn_cast<PointerLikeType>(
1182  atomicReductionEntryBlock.getArgumentTypes()[0]);
1183  if (!ptrType ||
1184  (ptrType.getElementType() && ptrType.getElementType() != getType()))
1185  return emitOpError() << "expects atomic reduction region arguments to "
1186  "be accumulators containing the reduction type";
1187  return success();
1188 }
1189 
1191  auto *op = (*this)->getParentWithTrait<ReductionClauseInterface::Trait>();
1192  if (!op)
1193  return emitOpError() << "must be used within an operation supporting "
1194  "reduction clause interface";
1195  while (op) {
1196  for (const auto &var :
1197  cast<ReductionClauseInterface>(op).getAllReductionVars())
1198  if (var == getAccumulator())
1199  return success();
1200  op = op->getParentWithTrait<ReductionClauseInterface::Trait>();
1201  }
1202  return emitOpError() << "the accumulator is not used by the parent";
1203 }
1204 
1205 //===----------------------------------------------------------------------===//
1206 // TaskOp
1207 //===----------------------------------------------------------------------===//
1209  LogicalResult verifyDependVars =
1210  verifyDependVarList(*this, getDepends(), getDependVars());
1211  return failed(verifyDependVars)
1212  ? verifyDependVars
1213  : verifyReductionVarList(*this, getInReductions(),
1214  getInReductionVars());
1215 }
1216 
1217 //===----------------------------------------------------------------------===//
1218 // TaskGroupOp
1219 //===----------------------------------------------------------------------===//
1221  return verifyReductionVarList(*this, getTaskReductions(),
1222  getTaskReductionVars());
1223 }
1224 
1225 //===----------------------------------------------------------------------===//
1226 // TaskLoopOp
1227 //===----------------------------------------------------------------------===//
1228 SmallVector<Value> TaskLoopOp::getAllReductionVars() {
1229  SmallVector<Value> allReductionNvars(getInReductionVars().begin(),
1230  getInReductionVars().end());
1231  allReductionNvars.insert(allReductionNvars.end(), getReductionVars().begin(),
1232  getReductionVars().end());
1233  return allReductionNvars;
1234 }
1235 
1237  if (getAllocateVars().size() != getAllocatorsVars().size())
1238  return emitError(
1239  "expected equal sizes for allocate and allocator variables");
1240  if (failed(
1241  verifyReductionVarList(*this, getReductions(), getReductionVars())) ||
1242  failed(verifyReductionVarList(*this, getInReductions(),
1243  getInReductionVars())))
1244  return failure();
1245 
1246  if (!getReductionVars().empty() && getNogroup())
1247  return emitError("if a reduction clause is present on the taskloop "
1248  "directive, the nogroup clause must not be specified");
1249  for (auto var : getReductionVars()) {
1250  if (llvm::is_contained(getInReductionVars(), var))
1251  return emitError("the same list item cannot appear in both a reduction "
1252  "and an in_reduction clause");
1253  }
1254 
1255  if (getGrainSize() && getNumTasks()) {
1256  return emitError(
1257  "the grainsize clause and num_tasks clause are mutually exclusive and "
1258  "may not appear on the same taskloop directive");
1259  }
1260  return success();
1261 }
1262 
1263 //===----------------------------------------------------------------------===//
1264 // WsLoopOp
1265 //===----------------------------------------------------------------------===//
1266 
1267 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
1268  ValueRange lowerBound, ValueRange upperBound,
1269  ValueRange step, ArrayRef<NamedAttribute> attributes) {
1270  build(builder, state, lowerBound, upperBound, step,
1271  /*linear_vars=*/ValueRange(),
1272  /*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(),
1273  /*reductions=*/nullptr, /*schedule_val=*/nullptr,
1274  /*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr,
1275  /*simd_modifier=*/false, /*nowait=*/false, /*ordered_val=*/nullptr,
1276  /*order_val=*/nullptr, /*inclusive=*/false);
1277  state.addAttributes(attributes);
1278 }
1279 
1281  return verifyReductionVarList(*this, getReductions(), getReductionVars());
1282 }
1283 
1284 //===----------------------------------------------------------------------===//
1285 // Verifier for critical construct (2.17.1)
1286 //===----------------------------------------------------------------------===//
1287 
1289  return verifySynchronizationHint(*this, getHintVal());
1290 }
1291 
1292 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1293  if (getNameAttr()) {
1294  SymbolRefAttr symbolRef = getNameAttr();
1295  auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
1296  *this, symbolRef);
1297  if (!decl) {
1298  return emitOpError() << "expected symbol reference " << symbolRef
1299  << " to point to a critical declaration";
1300  }
1301  }
1302 
1303  return success();
1304 }
1305 
1306 //===----------------------------------------------------------------------===//
1307 // Verifier for ordered construct
1308 //===----------------------------------------------------------------------===//
1309 
1311  auto container = (*this)->getParentOfType<WsLoopOp>();
1312  if (!container || !container.getOrderedValAttr() ||
1313  container.getOrderedValAttr().getInt() == 0)
1314  return emitOpError() << "ordered depend directive must be closely "
1315  << "nested inside a worksharing-loop with ordered "
1316  << "clause with parameter present";
1317 
1318  if (container.getOrderedValAttr().getInt() != (int64_t)*getNumLoopsVal())
1319  return emitOpError() << "number of variables in depend clause does not "
1320  << "match number of iteration variables in the "
1321  << "doacross loop";
1322 
1323  return success();
1324 }
1325 
1327  // TODO: The code generation for ordered simd directive is not supported yet.
1328  if (getSimd())
1329  return failure();
1330 
1331  if (auto container = (*this)->getParentOfType<WsLoopOp>()) {
1332  if (!container.getOrderedValAttr() ||
1333  container.getOrderedValAttr().getInt() != 0)
1334  return emitOpError() << "ordered region must be closely nested inside "
1335  << "a worksharing-loop region with an ordered "
1336  << "clause without parameter present";
1337  }
1338 
1339  return success();
1340 }
1341 
1342 //===----------------------------------------------------------------------===//
1343 // Verifier for AtomicReadOp
1344 //===----------------------------------------------------------------------===//
1345 
1347  if (verifyCommon().failed())
1348  return mlir::failure();
1349 
1350  if (auto mo = getMemoryOrderVal()) {
1351  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
1352  *mo == ClauseMemoryOrderKind::Release) {
1353  return emitError(
1354  "memory-order must not be acq_rel or release for atomic reads");
1355  }
1356  }
1357  return verifySynchronizationHint(*this, getHintVal());
1358 }
1359 
1360 //===----------------------------------------------------------------------===//
1361 // Verifier for AtomicWriteOp
1362 //===----------------------------------------------------------------------===//
1363 
1365  if (verifyCommon().failed())
1366  return mlir::failure();
1367 
1368  if (auto mo = getMemoryOrderVal()) {
1369  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
1370  *mo == ClauseMemoryOrderKind::Acquire) {
1371  return emitError(
1372  "memory-order must not be acq_rel or acquire for atomic writes");
1373  }
1374  }
1375  return verifySynchronizationHint(*this, getHintVal());
1376 }
1377 
1378 //===----------------------------------------------------------------------===//
1379 // Verifier for AtomicUpdateOp
1380 //===----------------------------------------------------------------------===//
1381 
1382 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
1383  PatternRewriter &rewriter) {
1384  if (op.isNoOp()) {
1385  rewriter.eraseOp(op);
1386  return success();
1387  }
1388  if (Value writeVal = op.getWriteOpVal()) {
1389  rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
1390  op.getHintValAttr(),
1391  op.getMemoryOrderValAttr());
1392  return success();
1393  }
1394  return failure();
1395 }
1396 
1398  if (verifyCommon().failed())
1399  return mlir::failure();
1400 
1401  if (auto mo = getMemoryOrderVal()) {
1402  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
1403  *mo == ClauseMemoryOrderKind::Acquire) {
1404  return emitError(
1405  "memory-order must not be acq_rel or acquire for atomic updates");
1406  }
1407  }
1408 
1409  return verifySynchronizationHint(*this, getHintVal());
1410 }
1411 
1412 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
1413 
1414 //===----------------------------------------------------------------------===//
1415 // Verifier for AtomicCaptureOp
1416 //===----------------------------------------------------------------------===//
1417 
1418 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
1419  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
1420  return op;
1421  return dyn_cast<AtomicReadOp>(getSecondOp());
1422 }
1423 
1424 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
1425  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
1426  return op;
1427  return dyn_cast<AtomicWriteOp>(getSecondOp());
1428 }
1429 
1430 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
1431  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
1432  return op;
1433  return dyn_cast<AtomicUpdateOp>(getSecondOp());
1434 }
1435 
1437  return verifySynchronizationHint(*this, getHintVal());
1438 }
1439 
1440 LogicalResult AtomicCaptureOp::verifyRegions() {
1441  if (verifyRegionsCommon().failed())
1442  return mlir::failure();
1443 
1444  if (getFirstOp()->getAttr("hint_val") || getSecondOp()->getAttr("hint_val"))
1445  return emitOpError(
1446  "operations inside capture region must not have hint clause");
1447 
1448  if (getFirstOp()->getAttr("memory_order_val") ||
1449  getSecondOp()->getAttr("memory_order_val"))
1450  return emitOpError(
1451  "operations inside capture region must not have memory_order clause");
1452  return success();
1453 }
1454 
1455 //===----------------------------------------------------------------------===//
1456 // Verifier for CancelOp
1457 //===----------------------------------------------------------------------===//
1458 
1460  ClauseCancellationConstructType cct = getCancellationConstructTypeVal();
1461  Operation *parentOp = (*this)->getParentOp();
1462 
1463  if (!parentOp) {
1464  return emitOpError() << "must be used within a region supporting "
1465  "cancel directive";
1466  }
1467 
1468  if ((cct == ClauseCancellationConstructType::Parallel) &&
1469  !isa<ParallelOp>(parentOp)) {
1470  return emitOpError() << "cancel parallel must appear "
1471  << "inside a parallel region";
1472  }
1473  if (cct == ClauseCancellationConstructType::Loop) {
1474  if (!isa<WsLoopOp>(parentOp)) {
1475  return emitOpError() << "cancel loop must appear "
1476  << "inside a worksharing-loop region";
1477  }
1478  if (cast<WsLoopOp>(parentOp).getNowaitAttr()) {
1479  return emitError() << "A worksharing construct that is canceled "
1480  << "must not have a nowait clause";
1481  }
1482  if (cast<WsLoopOp>(parentOp).getOrderedValAttr()) {
1483  return emitError() << "A worksharing construct that is canceled "
1484  << "must not have an ordered clause";
1485  }
1486 
1487  } else if (cct == ClauseCancellationConstructType::Sections) {
1488  if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
1489  return emitOpError() << "cancel sections must appear "
1490  << "inside a sections region";
1491  }
1492  if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
1493  cast<SectionsOp>(parentOp->getParentOp()).getNowaitAttr()) {
1494  return emitError() << "A sections construct that is canceled "
1495  << "must not have a nowait clause";
1496  }
1497  }
1498  // TODO : Add more when we support taskgroup.
1499  return success();
1500 }
1501 //===----------------------------------------------------------------------===//
1502 // Verifier for CancellationPointOp
1503 //===----------------------------------------------------------------------===//
1504 
1506  ClauseCancellationConstructType cct = getCancellationConstructTypeVal();
1507  Operation *parentOp = (*this)->getParentOp();
1508 
1509  if (!parentOp) {
1510  return emitOpError() << "must be used within a region supporting "
1511  "cancellation point directive";
1512  }
1513 
1514  if ((cct == ClauseCancellationConstructType::Parallel) &&
1515  !(isa<ParallelOp>(parentOp))) {
1516  return emitOpError() << "cancellation point parallel must appear "
1517  << "inside a parallel region";
1518  }
1519  if ((cct == ClauseCancellationConstructType::Loop) &&
1520  !isa<WsLoopOp>(parentOp)) {
1521  return emitOpError() << "cancellation point loop must appear "
1522  << "inside a worksharing-loop region";
1523  }
1524  if ((cct == ClauseCancellationConstructType::Sections) &&
1525  !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
1526  return emitOpError() << "cancellation point sections must appear "
1527  << "inside a sections region";
1528  }
1529  // TODO : Add more when we support taskgroup.
1530  return success();
1531 }
1532 
1533 //===----------------------------------------------------------------------===//
1534 // DataBoundsOp
1535 //===----------------------------------------------------------------------===//
1536 
1538  auto extent = getExtent();
1539  auto upperbound = getUpperBound();
1540  if (!extent && !upperbound)
1541  return emitError("expected extent or upperbound.");
1542  return success();
1543 }
1544 
1545 #define GET_ATTRDEF_CLASSES
1546 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
1547 
1548 #define GET_OP_CLASSES
1549 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1550 
1551 #define GET_TYPEDEF_CLASSES
1552 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:702
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:694
static MLIRContext * getContext(OpFoldResult val)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedItems, SmallVectorImpl< Type > &types, ArrayAttr &alignmentValues)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerBound, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperBound, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &loopVarTypes, UnitAttr &inclusive)
loop-control ::= ( ssa-id-list ) : type = loop-bounds loop-bounds := ( ssa-id-list ) to ( ssa-id-list...
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > depends)
Print Depend clause.
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCapture)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &vars, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &stepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange varsAllocate, TypeRange typesAllocate, OperandRange varsAllocator, TypeRange typesAllocator)
Print allocate clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignmentValues, OperandRange alignedVariables)
static void printReductionVarList(OpAsmPrinter &p, Operation *op, OperandRange reductionVars, TypeRange reductionTypes, std::optional< ArrayAttr > reductions)
Print Reduction clause.
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductions, OperandRange reductionVars)
Verifies Reduction Clause.
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operandsAllocate, SmallVectorImpl< Type > &typesAllocate, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operandsAllocator, SmallVectorImpl< Type > &typesAllocator)
Parse an allocate clause with allocators and a list of operands with types.
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedVarTypes, std::optional< ArrayAttr > alignmentValues)
Print Aligned Clause.
static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > depends, OperandRange dependVars)
Verifies Depend clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printAtomicReductionRegion(OpAsmPrinter &printer, ReductionDeclareOp op, Region &region)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearVarTypes, ValueRange linearStepVars)
Print Linear Clause.
static ParseResult parseReductionVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, ArrayAttr &redcuctionSymbols)
reduction-entry-list ::= reduction-entry | reduction-entry-list , reduction-entry reduction-entry ::=...
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleModifier, UnitAttr &simdModifier, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVariables)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerBound, ValueRange upperBound, ValueRange steps, TypeRange loopVarTypes, UnitAttr inclusive)
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr schedAttr, ScheduleModifierAttr modifier, UnitAttr simd, Value scheduleChunkVar, Type scheduleChunkType)
Print schedule clause.
static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, Region &region)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, ArrayAttr &dependsArray)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseMapEntries(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapOperands, SmallVectorImpl< Type > &mapOperandTypes)
static void printMapEntries(OpAsmPrinter &p, Operation *op, OperandRange mapOperands, TypeRange mapOperandTypes)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:68
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Block represents an ordered list of Operations.
Definition: Block.h:30
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:143
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
unsigned getNumArguments()
Definition: Block.h:121
BlockArgListType getArguments()
Definition: Block.h:80
Operation & front()
Definition: Block.h:146
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
IntegerType getI64Type()
Definition: Builders.cpp:85
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
MLIRContext * getContext() const
Definition: Builders.h:55
Define a fold interface to allow for dialects to control specific aspects of the folding behavior for...
DialectFoldInterface(Dialect *dialect)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:206
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getType() const
Definition: ValueRange.cpp:30
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:665
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the results of the given op with a new op that is created without verification.
Definition: PatternMatch.h:539
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Runtime
Potential runtimes for AMD GPU kernels.
Definition: Runtimes.h:15
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
Definition: MathExtras.h:45
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.