MLIR  22.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Attributes.h"
23 #include "mlir/IR/SymbolTable.h"
25 
26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/PostOrderIterator.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/STLForwardCompat.h"
30 #include "llvm/ADT/SmallString.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/StringRef.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/ADT/bit.h"
35 #include "llvm/Frontend/OpenMP/OMPConstants.h"
36 #include "llvm/Support/InterleavedRange.h"
37 #include <cstddef>
38 #include <iterator>
39 #include <optional>
40 #include <variant>
41 
42 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
43 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
44 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
45 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
46 
47 using namespace mlir;
48 using namespace mlir::omp;
49 
50 static ArrayAttr makeArrayAttr(MLIRContext *context,
52  return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
53 }
54 
55 static DenseBoolArrayAttr
57  return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
58 }
59 
60 static DenseI64ArrayAttr
62  return intArray.empty() ? nullptr : DenseI64ArrayAttr::get(ctx, intArray);
63 }
64 
65 namespace {
66 struct MemRefPointerLikeModel
67  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
68  MemRefType> {
69  Type getElementType(Type pointer) const {
70  return llvm::cast<MemRefType>(pointer).getElementType();
71  }
72 };
73 
74 struct LLVMPointerPointerLikeModel
75  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
76  LLVM::LLVMPointerType> {
77  Type getElementType(Type pointer) const { return Type(); }
78 };
79 } // namespace
80 
81 /// Generate a name of a canonical loop nest of the format
82 /// `<prefix>(_r<idx>_s<idx>)*`. Hereby, `_r<idx>` identifies the region
83 /// argument index of an operation that has multiple regions, if the operation
84 /// has multiple regions.
85 /// `_s<idx>` identifies the position of an operation within a region, where
86 /// only operations that may potentially contain loops ("container operations"
87 /// i.e. have region arguments) are counted. Again, it is omitted if there is
88 /// only one such operation in a region. If there are canonical loops nested
89 /// inside each other, also may also use the format `_d<num>` where <num> is the
90 /// nesting depth of the loop.
91 ///
92 /// The generated name is a best-effort to make canonical loop unique within an
93 /// SSA namespace. This also means that regions with IsolatedFromAbove property
94 /// do not consider any parents or siblings.
95 static std::string generateLoopNestingName(StringRef prefix,
96  CanonicalLoopOp op) {
97  struct Component {
98  /// If true, this component describes a region operand of an operation (the
99  /// operand's owner) If false, this component describes an operation located
100  /// in a parent region
101  bool isRegionArgOfOp;
102  bool skip = false;
103  bool isUnique = false;
104 
105  size_t idx;
106  Operation *op;
107  Region *parentRegion;
108  size_t loopDepth;
109 
110  Operation *&getOwnerOp() {
111  assert(isRegionArgOfOp && "Must describe a region operand");
112  return op;
113  }
114  size_t &getArgIdx() {
115  assert(isRegionArgOfOp && "Must describe a region operand");
116  return idx;
117  }
118 
119  Operation *&getContainerOp() {
120  assert(!isRegionArgOfOp && "Must describe a operation of a region");
121  return op;
122  }
123  size_t &getOpPos() {
124  assert(!isRegionArgOfOp && "Must describe a operation of a region");
125  return idx;
126  }
127  bool isLoopOp() const {
128  assert(!isRegionArgOfOp && "Must describe a operation of a region");
129  return isa<CanonicalLoopOp>(op);
130  }
131  Region *&getParentRegion() {
132  assert(!isRegionArgOfOp && "Must describe a operation of a region");
133  return parentRegion;
134  }
135  size_t &getLoopDepth() {
136  assert(!isRegionArgOfOp && "Must describe a operation of a region");
137  return loopDepth;
138  }
139 
140  void skipIf(bool v = true) { skip = skip || v; }
141  };
142 
143  // List of ancestors, from inner to outer.
144  // Alternates between
145  // * region argument of an operation
146  // * operation within a region
147  SmallVector<Component> components;
148 
149  // Gather a list of parent regions and operations, and the position within
150  // their parent
151  Operation *o = op.getOperation();
152  while (o) {
153  // Operation within a region
154  Region *r = o->getParentRegion();
155  if (!r)
156  break;
157 
158  llvm::ReversePostOrderTraversal<Block *> traversal(&r->getBlocks().front());
159  size_t idx = 0;
160  bool found = false;
161  size_t sequentialIdx = -1;
162  bool isOnlyContainerOp = true;
163  for (Block *b : traversal) {
164  for (Operation &op : *b) {
165  if (&op == o && !found) {
166  sequentialIdx = idx;
167  found = true;
168  }
169  if (op.getNumRegions()) {
170  idx += 1;
171  if (idx > 1)
172  isOnlyContainerOp = false;
173  }
174  if (found && !isOnlyContainerOp)
175  break;
176  }
177  }
178 
179  Component &containerOpInRegion = components.emplace_back();
180  containerOpInRegion.isRegionArgOfOp = false;
181  containerOpInRegion.isUnique = isOnlyContainerOp;
182  containerOpInRegion.getContainerOp() = o;
183  containerOpInRegion.getOpPos() = sequentialIdx;
184  containerOpInRegion.getParentRegion() = r;
185 
186  Operation *parent = r->getParentOp();
187 
188  // Region argument of an operation
189  Component &regionArgOfOperation = components.emplace_back();
190  regionArgOfOperation.isRegionArgOfOp = true;
191  regionArgOfOperation.isUnique = true;
192  regionArgOfOperation.getArgIdx() = 0;
193  regionArgOfOperation.getOwnerOp() = parent;
194 
195  // The IsolatedFromAbove trait of the parent operation implies that each
196  // individual region argument has its own separate namespace, so no
197  // ambiguity.
198  if (!parent || parent->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
199  break;
200 
201  // Component only needed if operation has multiple region operands. Region
202  // arguments may be optional, but we currently do not consider this.
203  if (parent->getRegions().size() > 1) {
204  auto getRegionIndex = [](Operation *o, Region *r) {
205  for (auto [idx, region] : llvm::enumerate(o->getRegions())) {
206  if (&region == r)
207  return idx;
208  }
209  llvm_unreachable("Region not child of its parent operation");
210  };
211  regionArgOfOperation.isUnique = false;
212  regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
213  }
214 
215  // next parent
216  o = parent;
217  }
218 
219  // Determine whether a region-argument component is not needed
220  for (Component &c : components)
221  c.skipIf(c.isRegionArgOfOp && c.isUnique);
222 
223  // Find runs of nested loops and determine each loop's depth in the loop nest
224  size_t numSurroundingLoops = 0;
225  for (Component &c : llvm::reverse(components)) {
226  if (c.skip)
227  continue;
228 
229  // non-skipped multi-argument operands interrupt the loop nest
230  if (c.isRegionArgOfOp) {
231  numSurroundingLoops = 0;
232  continue;
233  }
234 
235  // Multiple loops in a region means each of them is the outermost loop of a
236  // new loop nest
237  if (!c.isUnique)
238  numSurroundingLoops = 0;
239 
240  c.getLoopDepth() = numSurroundingLoops;
241 
242  // Next loop is surrounded by one more loop
243  if (isa<CanonicalLoopOp>(c.getContainerOp()))
244  numSurroundingLoops += 1;
245  }
246 
247  // In loop nests, skip all but the innermost loop that contains the depth
248  // number
249  bool isLoopNest = false;
250  for (Component &c : components) {
251  if (c.skip || c.isRegionArgOfOp)
252  continue;
253 
254  if (!isLoopNest && c.getLoopDepth() >= 1) {
255  // Innermost loop of a loop nest of at least two loops
256  isLoopNest = true;
257  } else if (isLoopNest) {
258  // Non-innermost loop of a loop nest
259  c.skipIf(c.isUnique);
260 
261  // If there is no surrounding loop left, this must have been the outermost
262  // loop; leave loop-nest mode for the next iteration
263  if (c.getLoopDepth() == 0)
264  isLoopNest = false;
265  }
266  }
267 
268  // Skip non-loop unambiguous regions (but they should interrupt loop nests, so
269  // we mark them as skipped only after computing loop nests)
270  for (Component &c : components)
271  c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
272  !isa<CanonicalLoopOp>(c.getContainerOp()));
273 
274  // Components can be skipped if they are already disambiguated by their parent
275  // (or does not have a parent)
276  bool newRegion = true;
277  for (Component &c : llvm::reverse(components)) {
278  c.skipIf(newRegion && c.isUnique);
279 
280  // non-skipped components disambiguate unique children
281  if (!c.skip)
282  newRegion = true;
283 
284  // ...except canonical loops that need a suffix for each nest
285  if (!c.isRegionArgOfOp && c.getContainerOp())
286  newRegion = false;
287  }
288 
289  // Compile the nesting name string
290  SmallString<64> Name{prefix};
291  llvm::raw_svector_ostream NameOS(Name);
292  for (auto &c : llvm::reverse(components)) {
293  if (c.skip)
294  continue;
295 
296  if (c.isRegionArgOfOp)
297  NameOS << "_r" << c.getArgIdx();
298  else if (c.getLoopDepth() >= 1)
299  NameOS << "_d" << c.getLoopDepth();
300  else
301  NameOS << "_s" << c.getOpPos();
302  }
303 
304  return NameOS.str().str();
305 }
306 
307 void OpenMPDialect::initialize() {
308  addOperations<
309 #define GET_OP_LIST
310 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
311  >();
312  addAttributes<
313 #define GET_ATTRDEF_LIST
314 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
315  >();
316  addTypes<
317 #define GET_TYPEDEF_LIST
318 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
319  >();
320 
321  declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
322 
323  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
324  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
325  *getContext());
326 
327  // Attach default offload module interface to module op to access
328  // offload functionality through it.
329  mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
330  *getContext());
331 
332  // Attach default declare target interfaces to operations which can be marked
333  // as declare target (Global Operations and Functions/Subroutines in dialects
334  // that Fortran (or other languages that lower to MLIR) translates to).
335  mlir::LLVM::GlobalOp::attachInterface<
337  *getContext());
338  mlir::LLVM::LLVMFuncOp::attachInterface<
340  *getContext());
341  mlir::func::FuncOp::attachInterface<
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // Parser and printer for Allocate Clause
347 //===----------------------------------------------------------------------===//
348 
349 /// Parse an allocate clause with allocators and a list of operands with types.
350 ///
351 /// allocate-operand-list :: = allocate-operand |
352 /// allocator-operand `,` allocate-operand-list
353 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
354 /// ssa-id-and-type ::= ssa-id `:` type
355 static ParseResult parseAllocateAndAllocator(
356  OpAsmParser &parser,
358  SmallVectorImpl<Type> &allocateTypes,
360  SmallVectorImpl<Type> &allocatorTypes) {
361 
362  return parser.parseCommaSeparatedList([&]() {
364  Type type;
365  if (parser.parseOperand(operand) || parser.parseColonType(type))
366  return failure();
367  allocatorVars.push_back(operand);
368  allocatorTypes.push_back(type);
369  if (parser.parseArrow())
370  return failure();
371  if (parser.parseOperand(operand) || parser.parseColonType(type))
372  return failure();
373 
374  allocateVars.push_back(operand);
375  allocateTypes.push_back(type);
376  return success();
377  });
378 }
379 
380 /// Print allocate clause
382  OperandRange allocateVars,
383  TypeRange allocateTypes,
384  OperandRange allocatorVars,
385  TypeRange allocatorTypes) {
386  for (unsigned i = 0; i < allocateVars.size(); ++i) {
387  std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
388  p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
389  p << allocateVars[i] << " : " << allocateTypes[i] << separator;
390  }
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // Parser and printer for a clause attribute (StringEnumAttr)
395 //===----------------------------------------------------------------------===//
396 
397 template <typename ClauseAttr>
398 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
399  using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
400  StringRef enumStr;
401  SMLoc loc = parser.getCurrentLocation();
402  if (parser.parseKeyword(&enumStr))
403  return failure();
404  if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
405  attr = ClauseAttr::get(parser.getContext(), *enumValue);
406  return success();
407  }
408  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
409 }
410 
411 template <typename ClauseAttr>
412 static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
413  p << stringifyEnum(attr.getValue());
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // Parser and printer for Linear Clause
418 //===----------------------------------------------------------------------===//
419 
420 /// linear ::= `linear` `(` linear-list `)`
421 /// linear-list := linear-val | linear-val linear-list
422 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
423 static ParseResult parseLinearClause(
424  OpAsmParser &parser,
426  SmallVectorImpl<Type> &linearTypes,
428  return parser.parseCommaSeparatedList([&]() {
430  Type type;
432  if (parser.parseOperand(var) || parser.parseEqual() ||
433  parser.parseOperand(stepVar) || parser.parseColonType(type))
434  return failure();
435 
436  linearVars.push_back(var);
437  linearTypes.push_back(type);
438  linearStepVars.push_back(stepVar);
439  return success();
440  });
441 }
442 
443 /// Print Linear Clause
445  ValueRange linearVars, TypeRange linearTypes,
446  ValueRange linearStepVars) {
447  size_t linearVarsSize = linearVars.size();
448  for (unsigned i = 0; i < linearVarsSize; ++i) {
449  std::string separator = i == linearVarsSize - 1 ? "" : ", ";
450  p << linearVars[i];
451  if (linearStepVars.size() > i)
452  p << " = " << linearStepVars[i];
453  p << " : " << linearVars[i].getType() << separator;
454  }
455 }
456 
457 //===----------------------------------------------------------------------===//
458 // Verifier for Nontemporal Clause
459 //===----------------------------------------------------------------------===//
460 
461 static LogicalResult verifyNontemporalClause(Operation *op,
462  OperandRange nontemporalVars) {
463 
464  // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
465  DenseSet<Value> nontemporalItems;
466  for (const auto &it : nontemporalVars)
467  if (!nontemporalItems.insert(it).second)
468  return op->emitOpError() << "nontemporal variable used more than once";
469 
470  return success();
471 }
472 
473 //===----------------------------------------------------------------------===//
474 // Parser, verifier and printer for Aligned Clause
475 //===----------------------------------------------------------------------===//
476 static LogicalResult verifyAlignedClause(Operation *op,
477  std::optional<ArrayAttr> alignments,
478  OperandRange alignedVars) {
479  // Check if number of alignment values equals to number of aligned variables
480  if (!alignedVars.empty()) {
481  if (!alignments || alignments->size() != alignedVars.size())
482  return op->emitOpError()
483  << "expected as many alignment values as aligned variables";
484  } else {
485  if (alignments)
486  return op->emitOpError() << "unexpected alignment values attribute";
487  return success();
488  }
489 
490  // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
491  DenseSet<Value> alignedItems;
492  for (auto it : alignedVars)
493  if (!alignedItems.insert(it).second)
494  return op->emitOpError() << "aligned variable used more than once";
495 
496  if (!alignments)
497  return success();
498 
499  // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
500  for (unsigned i = 0; i < (*alignments).size(); ++i) {
501  if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
502  if (intAttr.getValue().sle(0))
503  return op->emitOpError() << "alignment should be greater than 0";
504  } else {
505  return op->emitOpError() << "expected integer alignment";
506  }
507  }
508 
509  return success();
510 }
511 
512 /// aligned ::= `aligned` `(` aligned-list `)`
513 /// aligned-list := aligned-val | aligned-val aligned-list
514 /// aligned-val := ssa-id-and-type `->` alignment
515 static ParseResult
518  SmallVectorImpl<Type> &alignedTypes,
519  ArrayAttr &alignmentsAttr) {
520  SmallVector<Attribute> alignmentVec;
521  if (failed(parser.parseCommaSeparatedList([&]() {
522  if (parser.parseOperand(alignedVars.emplace_back()) ||
523  parser.parseColonType(alignedTypes.emplace_back()) ||
524  parser.parseArrow() ||
525  parser.parseAttribute(alignmentVec.emplace_back())) {
526  return failure();
527  }
528  return success();
529  })))
530  return failure();
531  SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
532  alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
533  return success();
534 }
535 
536 /// Print Aligned Clause
538  ValueRange alignedVars, TypeRange alignedTypes,
539  std::optional<ArrayAttr> alignments) {
540  for (unsigned i = 0; i < alignedVars.size(); ++i) {
541  if (i != 0)
542  p << ", ";
543  p << alignedVars[i] << " : " << alignedVars[i].getType();
544  p << " -> " << (*alignments)[i];
545  }
546 }
547 
548 //===----------------------------------------------------------------------===//
549 // Parser, printer and verifier for Schedule Clause
550 //===----------------------------------------------------------------------===//
551 
552 static ParseResult
554  SmallVectorImpl<SmallString<12>> &modifiers) {
555  if (modifiers.size() > 2)
556  return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
557  for (const auto &mod : modifiers) {
558  // Translate the string. If it has no value, then it was not a valid
559  // modifier!
560  auto symbol = symbolizeScheduleModifier(mod);
561  if (!symbol)
562  return parser.emitError(parser.getNameLoc())
563  << " unknown modifier type: " << mod;
564  }
565 
566  // If we have one modifier that is "simd", then stick a "none" modifier in
567  // index 0.
568  if (modifiers.size() == 1) {
569  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
570  modifiers.push_back(modifiers[0]);
571  modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
572  }
573  } else if (modifiers.size() == 2) {
574  // If there are two modifiers:
575  // First modifier should not be simd, second one should be simd
576  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
577  symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
578  return parser.emitError(parser.getNameLoc())
579  << " incorrect modifier order";
580  }
581  return success();
582 }
583 
584 /// schedule ::= `schedule` `(` sched-list `)`
585 /// sched-list ::= sched-val | sched-val sched-list |
586 /// sched-val `,` sched-modifier
587 /// sched-val ::= sched-with-chunk | sched-wo-chunk
588 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
589 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
590 /// sched-wo-chunk ::= `auto` | `runtime`
591 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
592 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
593 static ParseResult
594 parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
595  ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
596  std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
597  Type &chunkType) {
598  StringRef keyword;
599  if (parser.parseKeyword(&keyword))
600  return failure();
601  std::optional<mlir::omp::ClauseScheduleKind> schedule =
602  symbolizeClauseScheduleKind(keyword);
603  if (!schedule)
604  return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
605 
606  scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
607  switch (*schedule) {
608  case ClauseScheduleKind::Static:
609  case ClauseScheduleKind::Dynamic:
610  case ClauseScheduleKind::Guided:
611  if (succeeded(parser.parseOptionalEqual())) {
612  chunkSize = OpAsmParser::UnresolvedOperand{};
613  if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
614  return failure();
615  } else {
616  chunkSize = std::nullopt;
617  }
618  break;
619  case ClauseScheduleKind::Auto:
621  chunkSize = std::nullopt;
622  }
623 
624  // If there is a comma, we have one or more modifiers..
625  SmallVector<SmallString<12>> modifiers;
626  while (succeeded(parser.parseOptionalComma())) {
627  StringRef mod;
628  if (parser.parseKeyword(&mod))
629  return failure();
630  modifiers.push_back(mod);
631  }
632 
633  if (verifyScheduleModifiers(parser, modifiers))
634  return failure();
635 
636  if (!modifiers.empty()) {
637  SMLoc loc = parser.getCurrentLocation();
638  if (std::optional<ScheduleModifier> mod =
639  symbolizeScheduleModifier(modifiers[0])) {
640  scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
641  } else {
642  return parser.emitError(loc, "invalid schedule modifier");
643  }
644  // Only SIMD attribute is allowed here!
645  if (modifiers.size() > 1) {
646  assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
647  scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
648  }
649  }
650 
651  return success();
652 }
653 
654 /// Print schedule clause
656  ClauseScheduleKindAttr scheduleKind,
657  ScheduleModifierAttr scheduleMod,
658  UnitAttr scheduleSimd, Value scheduleChunk,
659  Type scheduleChunkType) {
660  p << stringifyClauseScheduleKind(scheduleKind.getValue());
661  if (scheduleChunk)
662  p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
663  if (scheduleMod)
664  p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
665  if (scheduleSimd)
666  p << ", simd";
667 }
668 
669 //===----------------------------------------------------------------------===//
670 // Parser and printer for Order Clause
671 //===----------------------------------------------------------------------===//
672 
673 // order ::= `order` `(` [order-modifier ':'] concurrent `)`
674 // order-modifier ::= reproducible | unconstrained
675 static ParseResult parseOrderClause(OpAsmParser &parser,
676  ClauseOrderKindAttr &order,
677  OrderModifierAttr &orderMod) {
678  StringRef enumStr;
679  SMLoc loc = parser.getCurrentLocation();
680  if (parser.parseKeyword(&enumStr))
681  return failure();
682  if (std::optional<OrderModifier> enumValue =
683  symbolizeOrderModifier(enumStr)) {
684  orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
685  if (parser.parseOptionalColon())
686  return failure();
687  loc = parser.getCurrentLocation();
688  if (parser.parseKeyword(&enumStr))
689  return failure();
690  }
691  if (std::optional<ClauseOrderKind> enumValue =
692  symbolizeClauseOrderKind(enumStr)) {
693  order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
694  return success();
695  }
696  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
697 }
698 
700  ClauseOrderKindAttr order,
701  OrderModifierAttr orderMod) {
702  if (orderMod)
703  p << stringifyOrderModifier(orderMod.getValue()) << ":";
704  if (order)
705  p << stringifyClauseOrderKind(order.getValue());
706 }
707 
708 template <typename ClauseTypeAttr, typename ClauseType>
709 static ParseResult
710 parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
711  std::optional<OpAsmParser::UnresolvedOperand> &operand,
712  Type &operandType,
713  std::optional<ClauseType> (*symbolizeClause)(StringRef),
714  StringRef clauseName) {
715  StringRef enumStr;
716  if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
717  if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
718  prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
719  if (parser.parseComma())
720  return failure();
721  } else {
722  return parser.emitError(parser.getCurrentLocation())
723  << "invalid " << clauseName << " modifier : '" << enumStr << "'";
724  ;
725  }
726  }
727 
729  if (succeeded(parser.parseOperand(var))) {
730  operand = var;
731  } else {
732  return parser.emitError(parser.getCurrentLocation())
733  << "expected " << clauseName << " operand";
734  }
735 
736  if (operand.has_value()) {
737  if (parser.parseColonType(operandType))
738  return failure();
739  }
740 
741  return success();
742 }
743 
744 template <typename ClauseTypeAttr, typename ClauseType>
745 static void
747  ClauseTypeAttr prescriptiveness, Value operand,
748  mlir::Type operandType,
749  StringRef (*stringifyClauseType)(ClauseType)) {
750 
751  if (prescriptiveness)
752  p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
753 
754  if (operand)
755  p << operand << ": " << operandType;
756 }
757 
758 //===----------------------------------------------------------------------===//
759 // Parser and printer for grainsize Clause
760 //===----------------------------------------------------------------------===//
761 
762 // grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
763 static ParseResult
764 parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
765  std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
766  Type &grainsizeType) {
767  return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
768  parser, grainsizeMod, grainsize, grainsizeType,
769  &symbolizeClauseGrainsizeType, "grainsize");
770 }
771 
773  ClauseGrainsizeTypeAttr grainsizeMod,
774  Value grainsize, mlir::Type grainsizeType) {
775  printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
776  p, op, grainsizeMod, grainsize, grainsizeType,
777  &stringifyClauseGrainsizeType);
778 }
779 
780 //===----------------------------------------------------------------------===//
781 // Parser and printer for num_tasks Clause
782 //===----------------------------------------------------------------------===//
783 
784 // numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
785 static ParseResult
786 parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
787  std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
788  Type &numTasksType) {
789  return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
790  parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
791  "num_tasks");
792 }
793 
795  ClauseNumTasksTypeAttr numTasksMod,
796  Value numTasks, mlir::Type numTasksType) {
797  printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
798  p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
799 }
800 
801 //===----------------------------------------------------------------------===//
802 // Parsers for operations including clauses that define entry block arguments.
803 //===----------------------------------------------------------------------===//
804 
805 namespace {
806 struct MapParseArgs {
808  SmallVectorImpl<Type> &types;
810  SmallVectorImpl<Type> &types)
811  : vars(vars), types(types) {}
812 };
813 struct PrivateParseArgs {
816  ArrayAttr &syms;
817  UnitAttr &needsBarrier;
818  DenseI64ArrayAttr *mapIndices;
820  SmallVectorImpl<Type> &types, ArrayAttr &syms,
821  UnitAttr &needsBarrier,
822  DenseI64ArrayAttr *mapIndices = nullptr)
823  : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
824  mapIndices(mapIndices) {}
825 };
826 
827 struct ReductionParseArgs {
829  SmallVectorImpl<Type> &types;
830  DenseBoolArrayAttr &byref;
831  ArrayAttr &syms;
832  ReductionModifierAttr *modifier;
833  ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
835  ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
836  : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
837 };
838 
839 struct AllRegionParseArgs {
840  std::optional<MapParseArgs> hasDeviceAddrArgs;
841  std::optional<MapParseArgs> hostEvalArgs;
842  std::optional<ReductionParseArgs> inReductionArgs;
843  std::optional<MapParseArgs> mapArgs;
844  std::optional<PrivateParseArgs> privateArgs;
845  std::optional<ReductionParseArgs> reductionArgs;
846  std::optional<ReductionParseArgs> taskReductionArgs;
847  std::optional<MapParseArgs> useDeviceAddrArgs;
848  std::optional<MapParseArgs> useDevicePtrArgs;
849 };
850 } // namespace
851 
852 static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
853  return "private_barrier";
854 }
855 
856 static ParseResult parseClauseWithRegionArgs(
857  OpAsmParser &parser,
859  SmallVectorImpl<Type> &types,
860  SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
861  ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
862  DenseBoolArrayAttr *byref = nullptr,
863  ReductionModifierAttr *modifier = nullptr,
864  UnitAttr *needsBarrier = nullptr) {
865  SmallVector<SymbolRefAttr> symbolVec;
866  SmallVector<int64_t> mapIndicesVec;
867  SmallVector<bool> isByRefVec;
868  unsigned regionArgOffset = regionPrivateArgs.size();
869 
870  if (parser.parseLParen())
871  return failure();
872 
873  if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
874  StringRef enumStr;
875  if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
876  parser.parseComma())
877  return failure();
878  std::optional<ReductionModifier> enumValue =
879  symbolizeReductionModifier(enumStr);
880  if (!enumValue.has_value())
881  return failure();
882  *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
883  if (!*modifier)
884  return failure();
885  }
886 
887  if (parser.parseCommaSeparatedList([&]() {
888  if (byref)
889  isByRefVec.push_back(
890  parser.parseOptionalKeyword("byref").succeeded());
891 
892  if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
893  return failure();
894 
895  if (parser.parseOperand(operands.emplace_back()) ||
896  parser.parseArrow() ||
897  parser.parseArgument(regionPrivateArgs.emplace_back()))
898  return failure();
899 
900  if (mapIndices) {
901  if (parser.parseOptionalLSquare().succeeded()) {
902  if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
903  parser.parseInteger(mapIndicesVec.emplace_back()) ||
904  parser.parseRSquare())
905  return failure();
906  } else {
907  mapIndicesVec.push_back(-1);
908  }
909  }
910 
911  return success();
912  }))
913  return failure();
914 
915  if (parser.parseColon())
916  return failure();
917 
918  if (parser.parseCommaSeparatedList([&]() {
919  if (parser.parseType(types.emplace_back()))
920  return failure();
921 
922  return success();
923  }))
924  return failure();
925 
926  if (operands.size() != types.size())
927  return failure();
928 
929  if (parser.parseRParen())
930  return failure();
931 
932  if (needsBarrier) {
934  .succeeded())
935  *needsBarrier = mlir::UnitAttr::get(parser.getContext());
936  }
937 
938  auto *argsBegin = regionPrivateArgs.begin();
939  MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
940  argsBegin + regionArgOffset + types.size());
941  for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
942  prv.type = type;
943  }
944 
945  if (symbols) {
946  SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
947  *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
948  }
949 
950  if (!mapIndicesVec.empty())
951  *mapIndices =
952  mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
953 
954  if (byref)
955  *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
956 
957  return success();
958 }
959 
960 static ParseResult parseBlockArgClause(
961  OpAsmParser &parser,
963  StringRef keyword, std::optional<MapParseArgs> mapArgs) {
964  if (succeeded(parser.parseOptionalKeyword(keyword))) {
965  if (!mapArgs)
966  return failure();
967 
968  if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
969  entryBlockArgs)))
970  return failure();
971  }
972  return success();
973 }
974 
975 static ParseResult parseBlockArgClause(
976  OpAsmParser &parser,
978  StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
979  if (succeeded(parser.parseOptionalKeyword(keyword))) {
980  if (!privateArgs)
981  return failure();
982 
984  parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
985  &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
986  /*modifier=*/nullptr, &privateArgs->needsBarrier)))
987  return failure();
988  }
989  return success();
990 }
991 
992 static ParseResult parseBlockArgClause(
993  OpAsmParser &parser,
995  StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
996  if (succeeded(parser.parseOptionalKeyword(keyword))) {
997  if (!reductionArgs)
998  return failure();
1000  parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1001  &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
1002  reductionArgs->modifier)))
1003  return failure();
1004  }
1005  return success();
1006 }
1007 
1008 static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
1009  AllRegionParseArgs args) {
1011 
1012  if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",
1013  args.hasDeviceAddrArgs)))
1014  return parser.emitError(parser.getCurrentLocation())
1015  << "invalid `has_device_addr` format";
1016 
1017  if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
1018  args.hostEvalArgs)))
1019  return parser.emitError(parser.getCurrentLocation())
1020  << "invalid `host_eval` format";
1021 
1022  if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
1023  args.inReductionArgs)))
1024  return parser.emitError(parser.getCurrentLocation())
1025  << "invalid `in_reduction` format";
1026 
1027  if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
1028  args.mapArgs)))
1029  return parser.emitError(parser.getCurrentLocation())
1030  << "invalid `map_entries` format";
1031 
1032  if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
1033  args.privateArgs)))
1034  return parser.emitError(parser.getCurrentLocation())
1035  << "invalid `private` format";
1036 
1037  if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
1038  args.reductionArgs)))
1039  return parser.emitError(parser.getCurrentLocation())
1040  << "invalid `reduction` format";
1041 
1042  if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
1043  args.taskReductionArgs)))
1044  return parser.emitError(parser.getCurrentLocation())
1045  << "invalid `task_reduction` format";
1046 
1047  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
1048  args.useDeviceAddrArgs)))
1049  return parser.emitError(parser.getCurrentLocation())
1050  << "invalid `use_device_addr` format";
1051 
1052  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
1053  args.useDevicePtrArgs)))
1054  return parser.emitError(parser.getCurrentLocation())
1055  << "invalid `use_device_addr` format";
1056 
1057  return parser.parseRegion(region, entryBlockArgs);
1058 }
1059 
1060 // These parseXyz functions correspond to the custom<Xyz> definitions
1061 // in the .td file(s).
1062 static ParseResult parseTargetOpRegion(
1063  OpAsmParser &parser, Region &region,
1065  SmallVectorImpl<Type> &hasDeviceAddrTypes,
1067  SmallVectorImpl<Type> &hostEvalTypes,
1069  SmallVectorImpl<Type> &inReductionTypes,
1070  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1072  SmallVectorImpl<Type> &mapTypes,
1074  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1075  UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
1076  AllRegionParseArgs args;
1077  args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1078  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1079  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1080  inReductionByref, inReductionSyms);
1081  args.mapArgs.emplace(mapVars, mapTypes);
1082  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1083  privateNeedsBarrier, &privateMaps);
1084  return parseBlockArgRegion(parser, region, args);
1085 }
1086 
1088  OpAsmParser &parser, Region &region,
1090  SmallVectorImpl<Type> &inReductionTypes,
1091  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1093  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1094  UnitAttr &privateNeedsBarrier) {
1095  AllRegionParseArgs args;
1096  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1097  inReductionByref, inReductionSyms);
1098  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1099  privateNeedsBarrier);
1100  return parseBlockArgRegion(parser, region, args);
1101 }
1102 
1104  OpAsmParser &parser, Region &region,
1106  SmallVectorImpl<Type> &inReductionTypes,
1107  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1109  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1110  UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1112  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1113  ArrayAttr &reductionSyms) {
1114  AllRegionParseArgs args;
1115  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1116  inReductionByref, inReductionSyms);
1117  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1118  privateNeedsBarrier);
1119  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1120  reductionSyms, &reductionMod);
1121  return parseBlockArgRegion(parser, region, args);
1122 }
1123 
1124 static ParseResult parsePrivateRegion(
1125  OpAsmParser &parser, Region &region,
1127  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1128  UnitAttr &privateNeedsBarrier) {
1129  AllRegionParseArgs args;
1130  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1131  privateNeedsBarrier);
1132  return parseBlockArgRegion(parser, region, args);
1133 }
1134 
1135 static ParseResult parsePrivateReductionRegion(
1136  OpAsmParser &parser, Region &region,
1138  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1139  UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1141  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1142  ArrayAttr &reductionSyms) {
1143  AllRegionParseArgs args;
1144  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1145  privateNeedsBarrier);
1146  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1147  reductionSyms, &reductionMod);
1148  return parseBlockArgRegion(parser, region, args);
1149 }
1150 
1151 static ParseResult parseTaskReductionRegion(
1152  OpAsmParser &parser, Region &region,
1154  SmallVectorImpl<Type> &taskReductionTypes,
1155  DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
1156  AllRegionParseArgs args;
1157  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1158  taskReductionByref, taskReductionSyms);
1159  return parseBlockArgRegion(parser, region, args);
1160 }
1161 
1163  OpAsmParser &parser, Region &region,
1165  SmallVectorImpl<Type> &useDeviceAddrTypes,
1167  SmallVectorImpl<Type> &useDevicePtrTypes) {
1168  AllRegionParseArgs args;
1169  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1170  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1171  return parseBlockArgRegion(parser, region, args);
1172 }
1173 
1174 //===----------------------------------------------------------------------===//
1175 // Printers for operations including clauses that define entry block arguments.
1176 //===----------------------------------------------------------------------===//
1177 
1178 namespace {
1179 struct MapPrintArgs {
1180  ValueRange vars;
1181  TypeRange types;
1182  MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
1183 };
1184 struct PrivatePrintArgs {
1185  ValueRange vars;
1186  TypeRange types;
1187  ArrayAttr syms;
1188  UnitAttr needsBarrier;
1189  DenseI64ArrayAttr mapIndices;
1190  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
1191  UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
1192  : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1193  mapIndices(mapIndices) {}
1194 };
1195 struct ReductionPrintArgs {
1196  ValueRange vars;
1197  TypeRange types;
1198  DenseBoolArrayAttr byref;
1199  ArrayAttr syms;
1200  ReductionModifierAttr modifier;
1201  ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
1202  ArrayAttr syms, ReductionModifierAttr mod = nullptr)
1203  : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1204 };
1205 struct AllRegionPrintArgs {
1206  std::optional<MapPrintArgs> hasDeviceAddrArgs;
1207  std::optional<MapPrintArgs> hostEvalArgs;
1208  std::optional<ReductionPrintArgs> inReductionArgs;
1209  std::optional<MapPrintArgs> mapArgs;
1210  std::optional<PrivatePrintArgs> privateArgs;
1211  std::optional<ReductionPrintArgs> reductionArgs;
1212  std::optional<ReductionPrintArgs> taskReductionArgs;
1213  std::optional<MapPrintArgs> useDeviceAddrArgs;
1214  std::optional<MapPrintArgs> useDevicePtrArgs;
1215 };
1216 } // namespace
1217 
1219  OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1220  ValueRange argsSubrange, ValueRange operands, TypeRange types,
1221  ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
1222  DenseBoolArrayAttr byref = nullptr,
1223  ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
1224  if (argsSubrange.empty())
1225  return;
1226 
1227  p << clauseName << "(";
1228 
1229  if (modifier)
1230  p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
1231 
1232  if (!symbols) {
1233  llvm::SmallVector<Attribute> values(operands.size(), nullptr);
1234  symbols = ArrayAttr::get(ctx, values);
1235  }
1236 
1237  if (!mapIndices) {
1238  llvm::SmallVector<int64_t> values(operands.size(), -1);
1239  mapIndices = DenseI64ArrayAttr::get(ctx, values);
1240  }
1241 
1242  if (!byref) {
1243  mlir::SmallVector<bool> values(operands.size(), false);
1244  byref = DenseBoolArrayAttr::get(ctx, values);
1245  }
1246 
1247  llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1248  mapIndices.asArrayRef(),
1249  byref.asArrayRef()),
1250  p, [&p](auto t) {
1251  auto [op, arg, sym, map, isByRef] = t;
1252  if (isByRef)
1253  p << "byref ";
1254  if (sym)
1255  p << sym << " ";
1256 
1257  p << op << " -> " << arg;
1258 
1259  if (map != -1)
1260  p << " [map_idx=" << map << "]";
1261  });
1262  p << " : ";
1263  llvm::interleaveComma(types, p);
1264  p << ") ";
1265 
1266  if (needsBarrier)
1267  p << getPrivateNeedsBarrierSpelling() << " ";
1268 }
1269 
1271  StringRef clauseName, ValueRange argsSubrange,
1272  std::optional<MapPrintArgs> mapArgs) {
1273  if (mapArgs)
1274  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
1275  mapArgs->types);
1276 }
1277 
1279  StringRef clauseName, ValueRange argsSubrange,
1280  std::optional<PrivatePrintArgs> privateArgs) {
1281  if (privateArgs)
1283  p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1284  privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1285  /*modifier=*/nullptr, privateArgs->needsBarrier);
1286 }
1287 
1288 static void
1289 printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1290  ValueRange argsSubrange,
1291  std::optional<ReductionPrintArgs> reductionArgs) {
1292  if (reductionArgs)
1293  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1294  reductionArgs->vars, reductionArgs->types,
1295  reductionArgs->syms, /*mapIndices=*/nullptr,
1296  reductionArgs->byref, reductionArgs->modifier);
1297 }
1298 
1299 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
1300  const AllRegionPrintArgs &args) {
1301  auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1302  MLIRContext *ctx = op->getContext();
1303 
1304  printBlockArgClause(p, ctx, "has_device_addr",
1305  iface.getHasDeviceAddrBlockArgs(),
1306  args.hasDeviceAddrArgs);
1307  printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
1308  args.hostEvalArgs);
1309  printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
1310  args.inReductionArgs);
1311  printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
1312  args.mapArgs);
1313  printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
1314  args.privateArgs);
1315  printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
1316  args.reductionArgs);
1317  printBlockArgClause(p, ctx, "task_reduction",
1318  iface.getTaskReductionBlockArgs(),
1319  args.taskReductionArgs);
1320  printBlockArgClause(p, ctx, "use_device_addr",
1321  iface.getUseDeviceAddrBlockArgs(),
1322  args.useDeviceAddrArgs);
1323  printBlockArgClause(p, ctx, "use_device_ptr",
1324  iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1325 
1326  p.printRegion(region, /*printEntryBlockArgs=*/false);
1327 }
1328 
// These printXyz functions correspond to the custom<Xyz> definitions
// in the .td file(s).
1332  OpAsmPrinter &p, Operation *op, Region &region,
1333  ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
1334  ValueRange hostEvalVars, TypeRange hostEvalTypes,
1335  ValueRange inReductionVars, TypeRange inReductionTypes,
1336  DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
1337  ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
1338  TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1339  DenseI64ArrayAttr privateMaps) {
1340  AllRegionPrintArgs args;
1341  args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1342  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1343  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1344  inReductionByref, inReductionSyms);
1345  args.mapArgs.emplace(mapVars, mapTypes);
1346  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1347  privateNeedsBarrier, privateMaps);
1348  printBlockArgRegion(p, op, region, args);
1349 }
1350 
1352  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1353  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1354  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1355  ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1356  AllRegionPrintArgs args;
1357  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1358  inReductionByref, inReductionSyms);
1359  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1360  privateNeedsBarrier,
1361  /*mapIndices=*/nullptr);
1362  printBlockArgRegion(p, op, region, args);
1363 }
1364 
1366  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1367  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1368  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1369  ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1370  ReductionModifierAttr reductionMod, ValueRange reductionVars,
1371  TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1372  ArrayAttr reductionSyms) {
1373  AllRegionPrintArgs args;
1374  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1375  inReductionByref, inReductionSyms);
1376  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1377  privateNeedsBarrier,
1378  /*mapIndices=*/nullptr);
1379  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1380  reductionSyms, reductionMod);
1381  printBlockArgRegion(p, op, region, args);
1382 }
1383 
1384 static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
1385  ValueRange privateVars, TypeRange privateTypes,
1386  ArrayAttr privateSyms,
1387  UnitAttr privateNeedsBarrier) {
1388  AllRegionPrintArgs args;
1389  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1390  privateNeedsBarrier,
1391  /*mapIndices=*/nullptr);
1392  printBlockArgRegion(p, op, region, args);
1393 }
1394 
1396  OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1397  TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1398  ReductionModifierAttr reductionMod, ValueRange reductionVars,
1399  TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1400  ArrayAttr reductionSyms) {
1401  AllRegionPrintArgs args;
1402  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1403  privateNeedsBarrier,
1404  /*mapIndices=*/nullptr);
1405  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1406  reductionSyms, reductionMod);
1407  printBlockArgRegion(p, op, region, args);
1408 }
1409 
1411  Region &region,
1412  ValueRange taskReductionVars,
1413  TypeRange taskReductionTypes,
1414  DenseBoolArrayAttr taskReductionByref,
1415  ArrayAttr taskReductionSyms) {
1416  AllRegionPrintArgs args;
1417  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1418  taskReductionByref, taskReductionSyms);
1419  printBlockArgRegion(p, op, region, args);
1420 }
1421 
1423  Region &region,
1424  ValueRange useDeviceAddrVars,
1425  TypeRange useDeviceAddrTypes,
1426  ValueRange useDevicePtrVars,
1427  TypeRange useDevicePtrTypes) {
1428  AllRegionPrintArgs args;
1429  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1430  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1431  printBlockArgRegion(p, op, region, args);
1432 }
1433 
1434 /// Verifies Reduction Clause
1435 static LogicalResult
1436 verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1437  OperandRange reductionVars,
1438  std::optional<ArrayRef<bool>> reductionByref) {
1439  if (!reductionVars.empty()) {
1440  if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1441  return op->emitOpError()
1442  << "expected as many reduction symbol references "
1443  "as reduction variables";
1444  if (reductionByref && reductionByref->size() != reductionVars.size())
1445  return op->emitError() << "expected as many reduction variable by "
1446  "reference attributes as reduction variables";
1447  } else {
1448  if (reductionSyms)
1449  return op->emitOpError() << "unexpected reduction symbol references";
1450  return success();
1451  }
1452 
1453  // TODO: The followings should be done in
1454  // SymbolUserOpInterface::verifySymbolUses.
1455  DenseSet<Value> accumulators;
1456  for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1457  Value accum = std::get<0>(args);
1458 
1459  if (!accumulators.insert(accum).second)
1460  return op->emitOpError() << "accumulator variable used more than once";
1461 
1462  Type varType = accum.getType();
1463  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1464  auto decl =
1465  SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1466  if (!decl)
1467  return op->emitOpError() << "expected symbol reference " << symbolRef
1468  << " to point to a reduction declaration";
1469 
1470  if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1471  return op->emitOpError()
1472  << "expected accumulator (" << varType
1473  << ") to be the same type as reduction declaration ("
1474  << decl.getAccumulatorType() << ")";
1475  }
1476 
1477  return success();
1478 }
1479 
1480 //===----------------------------------------------------------------------===//
1481 // Parser, printer and verifier for Copyprivate
1482 //===----------------------------------------------------------------------===//
1483 
1484 /// copyprivate-entry-list ::= copyprivate-entry
1485 /// | copyprivate-entry-list `,` copyprivate-entry
1486 /// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1487 static ParseResult parseCopyprivate(
1488  OpAsmParser &parser,
1490  SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1492  if (failed(parser.parseCommaSeparatedList([&]() {
1493  if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1494  parser.parseArrow() ||
1495  parser.parseAttribute(symsVec.emplace_back()) ||
1496  parser.parseColonType(copyprivateTypes.emplace_back()))
1497  return failure();
1498  return success();
1499  })))
1500  return failure();
1501  SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1502  copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1503  return success();
1504 }
1505 
1506 /// Print Copyprivate clause
1508  OperandRange copyprivateVars,
1509  TypeRange copyprivateTypes,
1510  std::optional<ArrayAttr> copyprivateSyms) {
1511  if (!copyprivateSyms.has_value())
1512  return;
1513  llvm::interleaveComma(
1514  llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1515  [&](const auto &args) {
1516  p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1517  << std::get<2>(args);
1518  });
1519 }
1520 
1521 /// Verifies CopyPrivate Clause
1522 static LogicalResult
1524  std::optional<ArrayAttr> copyprivateSyms) {
1525  size_t copyprivateSymsSize =
1526  copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1527  if (copyprivateSymsSize != copyprivateVars.size())
1528  return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1529  << copyprivateVars.size()
1530  << ") and functions (= " << copyprivateSymsSize
1531  << "), both must be equal";
1532  if (!copyprivateSyms.has_value())
1533  return success();
1534 
1535  for (auto copyprivateVarAndSym :
1536  llvm::zip(copyprivateVars, *copyprivateSyms)) {
1537  auto symbolRef =
1538  llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1539  std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1540  funcOp;
1541  if (mlir::func::FuncOp mlirFuncOp =
1542  SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1543  symbolRef))
1544  funcOp = mlirFuncOp;
1545  else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1546  SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1547  op, symbolRef))
1548  funcOp = llvmFuncOp;
1549 
1550  auto getNumArguments = [&] {
1551  return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1552  };
1553 
1554  auto getArgumentType = [&](unsigned i) {
1555  return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1556  *funcOp);
1557  };
1558 
1559  if (!funcOp)
1560  return op->emitOpError() << "expected symbol reference " << symbolRef
1561  << " to point to a copy function";
1562 
1563  if (getNumArguments() != 2)
1564  return op->emitOpError()
1565  << "expected copy function " << symbolRef << " to have 2 operands";
1566 
1567  Type argTy = getArgumentType(0);
1568  if (argTy != getArgumentType(1))
1569  return op->emitOpError() << "expected copy function " << symbolRef
1570  << " arguments to have the same type";
1571 
1572  Type varType = std::get<0>(copyprivateVarAndSym).getType();
1573  if (argTy != varType)
1574  return op->emitOpError()
1575  << "expected copy function arguments' type (" << argTy
1576  << ") to be the same as copyprivate variable's type (" << varType
1577  << ")";
1578  }
1579 
1580  return success();
1581 }
1582 
1583 //===----------------------------------------------------------------------===//
1584 // Parser, printer and verifier for DependVarList
1585 //===----------------------------------------------------------------------===//
1586 
1587 /// depend-entry-list ::= depend-entry
1588 /// | depend-entry-list `,` depend-entry
1589 /// depend-entry ::= depend-kind `->` ssa-id `:` type
1590 static ParseResult
1593  SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
1595  if (failed(parser.parseCommaSeparatedList([&]() {
1596  StringRef keyword;
1597  if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1598  parser.parseOperand(dependVars.emplace_back()) ||
1599  parser.parseColonType(dependTypes.emplace_back()))
1600  return failure();
1601  if (std::optional<ClauseTaskDepend> keywordDepend =
1602  (symbolizeClauseTaskDepend(keyword)))
1603  kindsVec.emplace_back(
1604  ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1605  else
1606  return failure();
1607  return success();
1608  })))
1609  return failure();
1610  SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1611  dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1612  return success();
1613 }
1614 
1615 /// Print Depend clause
1617  OperandRange dependVars, TypeRange dependTypes,
1618  std::optional<ArrayAttr> dependKinds) {
1619 
1620  for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1621  if (i != 0)
1622  p << ", ";
1623  p << stringifyClauseTaskDepend(
1624  llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1625  .getValue())
1626  << " -> " << dependVars[i] << " : " << dependTypes[i];
1627  }
1628 }
1629 
1630 /// Verifies Depend clause
1631 static LogicalResult verifyDependVarList(Operation *op,
1632  std::optional<ArrayAttr> dependKinds,
1633  OperandRange dependVars) {
1634  if (!dependVars.empty()) {
1635  if (!dependKinds || dependKinds->size() != dependVars.size())
1636  return op->emitOpError() << "expected as many depend values"
1637  " as depend variables";
1638  } else {
1639  if (dependKinds && !dependKinds->empty())
1640  return op->emitOpError() << "unexpected depend values";
1641  return success();
1642  }
1643 
1644  return success();
1645 }
1646 
1647 //===----------------------------------------------------------------------===//
1648 // Parser, printer and verifier for Synchronization Hint (2.17.12)
1649 //===----------------------------------------------------------------------===//
1650 
1651 /// Parses a Synchronization Hint clause. The value of hint is an integer
1652 /// which is a combination of different hints from `omp_sync_hint_t`.
1653 ///
1654 /// hint-clause = `hint` `(` hint-value `)`
1655 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1656  IntegerAttr &hintAttr) {
1657  StringRef hintKeyword;
1658  int64_t hint = 0;
1659  if (succeeded(parser.parseOptionalKeyword("none"))) {
1660  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1661  return success();
1662  }
1663  auto parseKeyword = [&]() -> ParseResult {
1664  if (failed(parser.parseKeyword(&hintKeyword)))
1665  return failure();
1666  if (hintKeyword == "uncontended")
1667  hint |= 1;
1668  else if (hintKeyword == "contended")
1669  hint |= 2;
1670  else if (hintKeyword == "nonspeculative")
1671  hint |= 4;
1672  else if (hintKeyword == "speculative")
1673  hint |= 8;
1674  else
1675  return parser.emitError(parser.getCurrentLocation())
1676  << hintKeyword << " is not a valid hint";
1677  return success();
1678  };
1679  if (parser.parseCommaSeparatedList(parseKeyword))
1680  return failure();
1681  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1682  return success();
1683 }
1684 
1685 /// Prints a Synchronization Hint clause
1687  IntegerAttr hintAttr) {
1688  int64_t hint = hintAttr.getInt();
1689 
1690  if (hint == 0) {
1691  p << "none";
1692  return;
1693  }
1694 
1695  // Helper function to get n-th bit from the right end of `value`
1696  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1697 
1698  bool uncontended = bitn(hint, 0);
1699  bool contended = bitn(hint, 1);
1700  bool nonspeculative = bitn(hint, 2);
1701  bool speculative = bitn(hint, 3);
1702 
1703  SmallVector<StringRef> hints;
1704  if (uncontended)
1705  hints.push_back("uncontended");
1706  if (contended)
1707  hints.push_back("contended");
1708  if (nonspeculative)
1709  hints.push_back("nonspeculative");
1710  if (speculative)
1711  hints.push_back("speculative");
1712 
1713  llvm::interleaveComma(hints, p);
1714 }
1715 
1716 /// Verifies a synchronization hint clause
1717 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1718 
1719  // Helper function to get n-th bit from the right end of `value`
1720  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1721 
1722  bool uncontended = bitn(hint, 0);
1723  bool contended = bitn(hint, 1);
1724  bool nonspeculative = bitn(hint, 2);
1725  bool speculative = bitn(hint, 3);
1726 
1727  if (uncontended && contended)
1728  return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1729  "omp_sync_hint_contended cannot be combined";
1730  if (nonspeculative && speculative)
1731  return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1732  "omp_sync_hint_speculative cannot be combined.";
1733  return success();
1734 }
1735 
1736 //===----------------------------------------------------------------------===//
1737 // Parser, printer and verifier for Target
1738 //===----------------------------------------------------------------------===//
1739 
1740 // Helper function to get bitwise AND of `value` and 'flag'
1741 static uint64_t mapTypeToBitFlag(uint64_t value,
1742  llvm::omp::OpenMPOffloadMappingFlags flag) {
1743  return value & llvm::to_underlying(flag);
1744 }
1745 
1746 /// Parses a map_entries map type from a string format back into its numeric
1747 /// value.
1748 ///
1749 /// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
1750 /// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
1751 static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
1752  llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1753  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1754 
1755  // This simply verifies the correct keyword is read in, the
1756  // keyword itself is stored inside of the operation
1757  auto parseTypeAndMod = [&]() -> ParseResult {
1758  StringRef mapTypeMod;
1759  if (parser.parseKeyword(&mapTypeMod))
1760  return failure();
1761 
1762  if (mapTypeMod == "always")
1763  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1764 
1765  if (mapTypeMod == "implicit")
1766  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1767 
1768  if (mapTypeMod == "ompx_hold")
1769  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
1770 
1771  if (mapTypeMod == "close")
1772  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1773 
1774  if (mapTypeMod == "present")
1775  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1776 
1777  if (mapTypeMod == "to")
1778  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1779 
1780  if (mapTypeMod == "from")
1781  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1782 
1783  if (mapTypeMod == "tofrom")
1784  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1785  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1786 
1787  if (mapTypeMod == "delete")
1788  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1789 
1790  if (mapTypeMod == "return_param")
1791  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
1792 
1793  return success();
1794  };
1795 
1796  if (parser.parseCommaSeparatedList(parseTypeAndMod))
1797  return failure();
1798 
1799  mapType = parser.getBuilder().getIntegerAttr(
1800  parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
1801  llvm::to_underlying(mapTypeBits));
1802 
1803  return success();
1804 }
1805 
1806 /// Prints a map_entries map type from its numeric value out into its string
1807 /// format.
1809  IntegerAttr mapType) {
1810  uint64_t mapTypeBits = mapType.getUInt();
1811 
1812  bool emitAllocRelease = true;
1814 
1815  // handling of always, close, present placed at the beginning of the string
1816  // to aid readability
1817  if (mapTypeToBitFlag(mapTypeBits,
1818  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1819  mapTypeStrs.push_back("always");
1820  if (mapTypeToBitFlag(mapTypeBits,
1821  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1822  mapTypeStrs.push_back("implicit");
1823  if (mapTypeToBitFlag(mapTypeBits,
1824  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD))
1825  mapTypeStrs.push_back("ompx_hold");
1826  if (mapTypeToBitFlag(mapTypeBits,
1827  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1828  mapTypeStrs.push_back("close");
1829  if (mapTypeToBitFlag(mapTypeBits,
1830  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1831  mapTypeStrs.push_back("present");
1832 
1833  // special handling of to/from/tofrom/delete and release/alloc, release +
1834  // alloc are the abscense of one of the other flags, whereas tofrom requires
1835  // both the to and from flag to be set.
1836  bool to = mapTypeToBitFlag(mapTypeBits,
1837  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1838  bool from = mapTypeToBitFlag(
1839  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1840  if (to && from) {
1841  emitAllocRelease = false;
1842  mapTypeStrs.push_back("tofrom");
1843  } else if (from) {
1844  emitAllocRelease = false;
1845  mapTypeStrs.push_back("from");
1846  } else if (to) {
1847  emitAllocRelease = false;
1848  mapTypeStrs.push_back("to");
1849  }
1850  if (mapTypeToBitFlag(mapTypeBits,
1851  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1852  emitAllocRelease = false;
1853  mapTypeStrs.push_back("delete");
1854  }
1855  if (mapTypeToBitFlag(
1856  mapTypeBits,
1857  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
1858  emitAllocRelease = false;
1859  mapTypeStrs.push_back("return_param");
1860  }
1861  if (emitAllocRelease)
1862  mapTypeStrs.push_back("exit_release_or_enter_alloc");
1863 
1864  for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1865  p << mapTypeStrs[i];
1866  if (i + 1 < mapTypeStrs.size()) {
1867  p << ", ";
1868  }
1869  }
1870 }
1871 
1872 static ParseResult parseMembersIndex(OpAsmParser &parser,
1873  ArrayAttr &membersIdx) {
1874  SmallVector<Attribute> values, memberIdxs;
1875 
1876  auto parseIndices = [&]() -> ParseResult {
1877  int64_t value;
1878  if (parser.parseInteger(value))
1879  return failure();
1880  values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1881  APInt(64, value, /*isSigned=*/false)));
1882  return success();
1883  };
1884 
1885  do {
1886  if (failed(parser.parseLSquare()))
1887  return failure();
1888 
1889  if (parser.parseCommaSeparatedList(parseIndices))
1890  return failure();
1891 
1892  if (failed(parser.parseRSquare()))
1893  return failure();
1894 
1895  memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1896  values.clear();
1897  } while (succeeded(parser.parseOptionalComma()));
1898 
1899  if (!memberIdxs.empty())
1900  membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
1901 
1902  return success();
1903 }
1904 
1905 static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1906  ArrayAttr membersIdx) {
1907  if (!membersIdx)
1908  return;
1909 
1910  llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
1911  p << "[";
1912  auto memberIdx = cast<ArrayAttr>(v);
1913  llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
1914  p << cast<IntegerAttr>(v2).getInt();
1915  });
1916  p << "]";
1917  });
1918 }
1919 
1921  VariableCaptureKindAttr mapCaptureType) {
1922  std::string typeCapStr;
1923  llvm::raw_string_ostream typeCap(typeCapStr);
1924  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1925  typeCap << "ByRef";
1926  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1927  typeCap << "ByCopy";
1928  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1929  typeCap << "VLAType";
1930  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1931  typeCap << "This";
1932  p << typeCapStr;
1933 }
1934 
1935 static ParseResult parseCaptureType(OpAsmParser &parser,
1936  VariableCaptureKindAttr &mapCaptureType) {
1937  StringRef mapCaptureKey;
1938  if (parser.parseKeyword(&mapCaptureKey))
1939  return failure();
1940 
1941  if (mapCaptureKey == "This")
1942  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1943  parser.getContext(), mlir::omp::VariableCaptureKind::This);
1944  if (mapCaptureKey == "ByRef")
1945  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1946  parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1947  if (mapCaptureKey == "ByCopy")
1948  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1949  parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1950  if (mapCaptureKey == "VLAType")
1951  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1952  parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1953 
1954  return success();
1955 }
1956 
1957 static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
1960 
1961  for (auto mapOp : mapVars) {
1962  if (!mapOp.getDefiningOp())
1963  return emitError(op->getLoc(), "missing map operation");
1964 
1965  if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
1966  uint64_t mapTypeBits = mapInfoOp.getMapType();
1967 
1968  bool to = mapTypeToBitFlag(
1969  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1970  bool from = mapTypeToBitFlag(
1971  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1972  bool del = mapTypeToBitFlag(
1973  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1974 
1975  bool always = mapTypeToBitFlag(
1976  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1977  bool close = mapTypeToBitFlag(
1978  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1979  bool implicit = mapTypeToBitFlag(
1980  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1981 
1982  if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1983  return emitError(op->getLoc(),
1984  "to, from, tofrom and alloc map types are permitted");
1985 
1986  if (isa<TargetEnterDataOp>(op) && (from || del))
1987  return emitError(op->getLoc(), "to and alloc map types are permitted");
1988 
1989  if (isa<TargetExitDataOp>(op) && to)
1990  return emitError(op->getLoc(),
1991  "from, release and delete map types are permitted");
1992 
1993  if (isa<TargetUpdateOp>(op)) {
1994  if (del) {
1995  return emitError(op->getLoc(),
1996  "at least one of to or from map types must be "
1997  "specified, other map types are not permitted");
1998  }
1999 
2000  if (!to && !from) {
2001  return emitError(op->getLoc(),
2002  "at least one of to or from map types must be "
2003  "specified, other map types are not permitted");
2004  }
2005 
2006  auto updateVar = mapInfoOp.getVarPtr();
2007 
2008  if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2009  (from && updateToVars.contains(updateVar))) {
2010  return emitError(
2011  op->getLoc(),
2012  "either to or from map types can be specified, not both");
2013  }
2014 
2015  if (always || close || implicit) {
2016  return emitError(
2017  op->getLoc(),
2018  "present, mapper and iterator map type modifiers are permitted");
2019  }
2020 
2021  to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2022  }
2023  } else if (!isa<DeclareMapperInfoOp>(op)) {
2024  return emitError(op->getLoc(),
2025  "map argument is not a map entry operation");
2026  }
2027  }
2028 
2029  return success();
2030 }
2031 
2032 static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
2033  std::optional<DenseI64ArrayAttr> privateMapIndices =
2034  targetOp.getPrivateMapsAttr();
2035 
2036  // None of the private operands are mapped.
2037  if (!privateMapIndices.has_value() || !privateMapIndices.value())
2038  return success();
2039 
2040  OperandRange privateVars = targetOp.getPrivateVars();
2041 
2042  if (privateMapIndices.value().size() !=
2043  static_cast<int64_t>(privateVars.size()))
2044  return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
2045  "`private_maps` attribute mismatch");
2046 
2047  return success();
2048 }
2049 
2050 //===----------------------------------------------------------------------===//
2051 // MapInfoOp
2052 //===----------------------------------------------------------------------===//
2053 
2054 static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
2055  StringRef clauseName,
2056  OperandRange vars) {
2057  for (Value var : vars)
2058  if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2059  return op->emitOpError()
2060  << "'" << clauseName
2061  << "' arguments must be defined by 'omp.map.info' ops";
2062  return success();
2063 }
2064 
2065 LogicalResult MapInfoOp::verify() {
2066  if (getMapperId() &&
2067  !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
2068  *this, getMapperIdAttr())) {
2069  return emitError("invalid mapper id");
2070  }
2071 
2072  if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
2073  return failure();
2074 
2075  return success();
2076 }
2077 
2078 //===----------------------------------------------------------------------===//
2079 // TargetDataOp
2080 //===----------------------------------------------------------------------===//
2081 
2082 void TargetDataOp::build(OpBuilder &builder, OperationState &state,
2083  const TargetDataOperands &clauses) {
2084  TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2085  clauses.mapVars, clauses.useDeviceAddrVars,
2086  clauses.useDevicePtrVars);
2087 }
2088 
2089 LogicalResult TargetDataOp::verify() {
2090  if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2091  getUseDeviceAddrVars().empty()) {
2092  return ::emitError(this->getLoc(),
2093  "At least one of map, use_device_ptr_vars, or "
2094  "use_device_addr_vars operand must be present");
2095  }
2096 
2097  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
2098  getUseDevicePtrVars())))
2099  return failure();
2100 
2101  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
2102  getUseDeviceAddrVars())))
2103  return failure();
2104 
2105  return verifyMapClause(*this, getMapVars());
2106 }
2107 
2108 //===----------------------------------------------------------------------===//
2109 // TargetEnterDataOp
2110 //===----------------------------------------------------------------------===//
2111 
2112 void TargetEnterDataOp::build(
2113  OpBuilder &builder, OperationState &state,
2114  const TargetEnterExitUpdateDataOperands &clauses) {
2115  MLIRContext *ctx = builder.getContext();
2116  TargetEnterDataOp::build(builder, state,
2117  makeArrayAttr(ctx, clauses.dependKinds),
2118  clauses.dependVars, clauses.device, clauses.ifExpr,
2119  clauses.mapVars, clauses.nowait);
2120 }
2121 
2122 LogicalResult TargetEnterDataOp::verify() {
2123  LogicalResult verifyDependVars =
2124  verifyDependVarList(*this, getDependKinds(), getDependVars());
2125  return failed(verifyDependVars) ? verifyDependVars
2126  : verifyMapClause(*this, getMapVars());
2127 }
2128 
2129 //===----------------------------------------------------------------------===//
2130 // TargetExitDataOp
2131 //===----------------------------------------------------------------------===//
2132 
2133 void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
2134  const TargetEnterExitUpdateDataOperands &clauses) {
2135  MLIRContext *ctx = builder.getContext();
2136  TargetExitDataOp::build(builder, state,
2137  makeArrayAttr(ctx, clauses.dependKinds),
2138  clauses.dependVars, clauses.device, clauses.ifExpr,
2139  clauses.mapVars, clauses.nowait);
2140 }
2141 
2142 LogicalResult TargetExitDataOp::verify() {
2143  LogicalResult verifyDependVars =
2144  verifyDependVarList(*this, getDependKinds(), getDependVars());
2145  return failed(verifyDependVars) ? verifyDependVars
2146  : verifyMapClause(*this, getMapVars());
2147 }
2148 
2149 //===----------------------------------------------------------------------===//
2150 // TargetUpdateOp
2151 //===----------------------------------------------------------------------===//
2152 
2153 void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
2154  const TargetEnterExitUpdateDataOperands &clauses) {
2155  MLIRContext *ctx = builder.getContext();
2156  TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2157  clauses.dependVars, clauses.device, clauses.ifExpr,
2158  clauses.mapVars, clauses.nowait);
2159 }
2160 
2161 LogicalResult TargetUpdateOp::verify() {
2162  LogicalResult verifyDependVars =
2163  verifyDependVarList(*this, getDependKinds(), getDependVars());
2164  return failed(verifyDependVars) ? verifyDependVars
2165  : verifyMapClause(*this, getMapVars());
2166 }
2167 
2168 //===----------------------------------------------------------------------===//
2169 // TargetOp
2170 //===----------------------------------------------------------------------===//
2171 
2172 void TargetOp::build(OpBuilder &builder, OperationState &state,
2173  const TargetOperands &clauses) {
2174  MLIRContext *ctx = builder.getContext();
2175  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
2176  // inReductionByref, inReductionSyms.
2177  TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2178  clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
2179  clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
2180  clauses.hostEvalVars, clauses.ifExpr,
2181  /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
2182  /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
2183  clauses.mapVars, clauses.nowait, clauses.privateVars,
2184  makeArrayAttr(ctx, clauses.privateSyms),
2185  clauses.privateNeedsBarrier, clauses.threadLimit,
2186  /*private_maps=*/nullptr);
2187 }
2188 
2189 LogicalResult TargetOp::verify() {
2190  if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
2191  return failure();
2192 
2193  if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
2194  getHasDeviceAddrVars())))
2195  return failure();
2196 
2197  if (failed(verifyMapClause(*this, getMapVars())))
2198  return failure();
2199 
2200  return verifyPrivateVarsMapping(*this);
2201 }
2202 
2203 LogicalResult TargetOp::verifyRegions() {
2204  auto teamsOps = getOps<TeamsOp>();
2205  if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2206  return emitError("target containing multiple 'omp.teams' nested ops");
2207 
2208  // Check that host_eval values are only used in legal ways.
2209  Operation *capturedOp = getInnermostCapturedOmpOp();
2210  TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
2211  for (Value hostEvalArg :
2212  cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
2213  for (Operation *user : hostEvalArg.getUsers()) {
2214  if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
2215  if (llvm::is_contained({teamsOp.getNumTeamsLower(),
2216  teamsOp.getNumTeamsUpper(),
2217  teamsOp.getThreadLimit()},
2218  hostEvalArg))
2219  continue;
2220 
2221  return emitOpError() << "host_eval argument only legal as 'num_teams' "
2222  "and 'thread_limit' in 'omp.teams'";
2223  }
2224  if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
2225  if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
2226  parallelOp->isAncestor(capturedOp) &&
2227  hostEvalArg == parallelOp.getNumThreads())
2228  continue;
2229 
2230  return emitOpError()
2231  << "host_eval argument only legal as 'num_threads' in "
2232  "'omp.parallel' when representing target SPMD";
2233  }
2234  if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2235  if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2236  loopNestOp.getOperation() == capturedOp &&
2237  (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2238  llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2239  llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2240  continue;
2241 
2242  return emitOpError() << "host_eval argument only legal as loop bounds "
2243  "and steps in 'omp.loop_nest' when trip count "
2244  "must be evaluated in the host";
2245  }
2246 
2247  return emitOpError() << "host_eval argument illegal use in '"
2248  << user->getName() << "' operation";
2249  }
2250  }
2251  return success();
2252 }
2253 
2254 static Operation *
2255 findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
2256  llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
2257  assert(rootOp && "expected valid operation");
2258 
2259  Dialect *ompDialect = rootOp->getDialect();
2260  Operation *capturedOp = nullptr;
2261  DominanceInfo domInfo;
2262 
2263  // Process in pre-order to check operations from outermost to innermost,
2264  // ensuring we only enter the region of an operation if it meets the criteria
2265  // for being captured. We stop the exploration of nested operations as soon as
2266  // we process a region holding no operations to be captured.
2267  rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
2268  if (op == rootOp)
2269  return WalkResult::advance();
2270 
2271  // Ignore operations of other dialects or omp operations with no regions,
2272  // because these will only be checked if they are siblings of an omp
2273  // operation that can potentially be captured.
2274  bool isOmpDialect = op->getDialect() == ompDialect;
2275  bool hasRegions = op->getNumRegions() > 0;
2276  if (!isOmpDialect || !hasRegions)
2277  return WalkResult::skip();
2278 
2279  // This operation cannot be captured if it can be executed more than once
2280  // (i.e. its block's successors can reach it) or if it's not guaranteed to
2281  // be executed before all exits of the region (i.e. it doesn't dominate all
2282  // blocks with no successors reachable from the entry block).
2283  if (checkSingleMandatoryExec) {
2284  Region *parentRegion = op->getParentRegion();
2285  Block *parentBlock = op->getBlock();
2286 
2287  for (Block *successor : parentBlock->getSuccessors())
2288  if (successor->isReachable(parentBlock))
2289  return WalkResult::interrupt();
2290 
2291  for (Block &block : *parentRegion)
2292  if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2293  !domInfo.dominates(parentBlock, &block))
2294  return WalkResult::interrupt();
2295  }
2296 
2297  // Don't capture this op if it has a not-allowed sibling, and stop recursing
2298  // into nested operations.
2299  for (Operation &sibling : op->getParentRegion()->getOps())
2300  if (&sibling != op && !siblingAllowedFn(&sibling))
2301  return WalkResult::interrupt();
2302 
2303  // Don't continue capturing nested operations if we reach an omp.loop_nest.
2304  // Otherwise, process the contents of this operation.
2305  capturedOp = op;
2306  return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2307  : WalkResult::advance();
2308  });
2309 
2310  return capturedOp;
2311 }
2312 
2313 Operation *TargetOp::getInnermostCapturedOmpOp() {
2314  auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
2315 
2316  // Only allow OpenMP terminators and non-OpenMP ops that have known memory
2317  // effects, but don't include a memory write effect.
2318  return findCapturedOmpOp(
2319  *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2320  if (!sibling)
2321  return false;
2322 
2323  if (ompDialect == sibling->getDialect())
2324  return sibling->hasTrait<OpTrait::IsTerminator>();
2325 
2326  if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2328  effects;
2329  memOp.getEffects(effects);
2330  return !llvm::any_of(
2331  effects, [&](MemoryEffects::EffectInstance &effect) {
2332  return isa<MemoryEffects::Write>(effect.getEffect()) &&
2333  isa<SideEffects::AutomaticAllocationScopeResource>(
2334  effect.getResource());
2335  });
2336  }
2337  return true;
2338  });
2339 }
2340 
2341 /// Check if we can promote SPMD kernel to No-Loop kernel.
2342 static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp,
2343  WsloopOp *wsLoopOp) {
2344  // num_teams clause can break no-loop teams/threads assumption.
2345  if (teamsOp.getNumTeamsUpper())
2346  return false;
2347 
2348  // Reduction kernels are slower in no-loop mode.
2349  if (teamsOp.getNumReductionVars())
2350  return false;
2351  if (wsLoopOp->getNumReductionVars())
2352  return false;
2353 
2354  // Check if the user allows the promotion of kernels to no-loop mode.
2355  OffloadModuleInterface offloadMod =
2356  capturedOp->getParentOfType<omp::OffloadModuleInterface>();
2357  if (!offloadMod)
2358  return false;
2359  auto ompFlags = offloadMod.getFlags();
2360  if (!ompFlags)
2361  return false;
2362  return ompFlags.getAssumeTeamsOversubscription() &&
2363  ompFlags.getAssumeThreadsOversubscription();
2364 }
2365 
2366 TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2367  // A non-null captured op is only valid if it resides inside of a TargetOp
2368  // and is the result of calling getInnermostCapturedOmpOp() on it.
2369  TargetOp targetOp =
2370  capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
2371  assert((!capturedOp ||
2372  (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2373  "unexpected captured op");
2374 
2375  // If it's not capturing a loop, it's a default target region.
2376  if (!isa_and_present<LoopNestOp>(capturedOp))
2377  return TargetRegionFlags::generic;
2378 
2379  // Get the innermost non-simd loop wrapper.
2380  SmallVector<LoopWrapperInterface> loopWrappers;
2381  cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2382  assert(!loopWrappers.empty());
2383 
2384  LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2385  if (isa<SimdOp>(innermostWrapper))
2386  innermostWrapper = std::next(innermostWrapper);
2387 
2388  auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2389  if (numWrappers != 1 && numWrappers != 2)
2390  return TargetRegionFlags::generic;
2391 
2392  // Detect target-teams-distribute-parallel-wsloop[-simd].
2393  if (numWrappers == 2) {
2394  WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2395  if (!wsloopOp)
2396  return TargetRegionFlags::generic;
2397 
2398  innermostWrapper = std::next(innermostWrapper);
2399  if (!isa<DistributeOp>(innermostWrapper))
2400  return TargetRegionFlags::generic;
2401 
2402  Operation *parallelOp = (*innermostWrapper)->getParentOp();
2403  if (!isa_and_present<ParallelOp>(parallelOp))
2404  return TargetRegionFlags::generic;
2405 
2406  TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp());
2407  if (!teamsOp)
2408  return TargetRegionFlags::generic;
2409 
2410  if (teamsOp->getParentOp() == targetOp.getOperation()) {
2411  TargetRegionFlags result =
2412  TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2413  if (canPromoteToNoLoop(capturedOp, teamsOp, wsloopOp))
2414  result = result | TargetRegionFlags::no_loop;
2415  return result;
2416  }
2417  }
2418  // Detect target-teams-distribute[-simd] and target-teams-loop.
2419  else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2420  Operation *teamsOp = (*innermostWrapper)->getParentOp();
2421  if (!isa_and_present<TeamsOp>(teamsOp))
2422  return TargetRegionFlags::generic;
2423 
2424  if (teamsOp->getParentOp() != targetOp.getOperation())
2425  return TargetRegionFlags::generic;
2426 
2427  if (isa<LoopOp>(innermostWrapper))
2428  return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2429 
2430  // Find single immediately nested captured omp.parallel and add spmd flag
2431  // (generic-spmd case).
2432  //
2433  // TODO: This shouldn't have to be done here, as it is too easy to break.
2434  // The openmp-opt pass should be updated to be able to promote kernels like
2435  // this from "Generic" to "Generic-SPMD". However, the use of the
2436  // `kmpc_distribute_static_loop` family of functions produced by the
2437  // OMPIRBuilder for these kernels prevents that from working.
2438  Dialect *ompDialect = targetOp->getDialect();
2439  Operation *nestedCapture = findCapturedOmpOp(
2440  capturedOp, /*checkSingleMandatoryExec=*/false,
2441  [&](Operation *sibling) {
2442  return sibling && (ompDialect != sibling->getDialect() ||
2443  sibling->hasTrait<OpTrait::IsTerminator>());
2444  });
2445 
2446  TargetRegionFlags result =
2447  TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2448 
2449  if (!nestedCapture)
2450  return result;
2451 
2452  while (nestedCapture->getParentOp() != capturedOp)
2453  nestedCapture = nestedCapture->getParentOp();
2454 
2455  return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2456  : result;
2457  }
2458  // Detect target-parallel-wsloop[-simd].
2459  else if (isa<WsloopOp>(innermostWrapper)) {
2460  Operation *parallelOp = (*innermostWrapper)->getParentOp();
2461  if (!isa_and_present<ParallelOp>(parallelOp))
2462  return TargetRegionFlags::generic;
2463 
2464  if (parallelOp->getParentOp() == targetOp.getOperation())
2465  return TargetRegionFlags::spmd;
2466  }
2467 
2468  return TargetRegionFlags::generic;
2469 }
2470 
2471 //===----------------------------------------------------------------------===//
2472 // ParallelOp
2473 //===----------------------------------------------------------------------===//
2474 
2475 void ParallelOp::build(OpBuilder &builder, OperationState &state,
2476  ArrayRef<NamedAttribute> attributes) {
2477  ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
2478  /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2479  /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
2480  /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2481  /*proc_bind_kind=*/nullptr,
2482  /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2483  /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2484  state.addAttributes(attributes);
2485 }
2486 
2487 void ParallelOp::build(OpBuilder &builder, OperationState &state,
2488  const ParallelOperands &clauses) {
2489  MLIRContext *ctx = builder.getContext();
2490  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2491  clauses.ifExpr, clauses.numThreads, clauses.privateVars,
2492  makeArrayAttr(ctx, clauses.privateSyms),
2493  clauses.privateNeedsBarrier, clauses.procBindKind,
2494  clauses.reductionMod, clauses.reductionVars,
2495  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2496  makeArrayAttr(ctx, clauses.reductionSyms));
2497 }
2498 
2499 template <typename OpType>
2500 static LogicalResult verifyPrivateVarList(OpType &op) {
2501  auto privateVars = op.getPrivateVars();
2502  auto privateSyms = op.getPrivateSymsAttr();
2503 
2504  if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
2505  return success();
2506 
2507  auto numPrivateVars = privateVars.size();
2508  auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
2509 
2510  if (numPrivateVars != numPrivateSyms)
2511  return op.emitError() << "inconsistent number of private variables and "
2512  "privatizer op symbols, private vars: "
2513  << numPrivateVars
2514  << " vs. privatizer op symbols: " << numPrivateSyms;
2515 
2516  for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2517  Type varType = std::get<0>(privateVarInfo).getType();
2518  SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2519  PrivateClauseOp privatizerOp =
2520  SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
2521 
2522  if (privatizerOp == nullptr)
2523  return op.emitError() << "failed to lookup privatizer op with symbol: '"
2524  << privateSym << "'";
2525 
2526  Type privatizerType = privatizerOp.getArgType();
2527 
2528  if (privatizerType && (varType != privatizerType))
2529  return op.emitError()
2530  << "type mismatch between a "
2531  << (privatizerOp.getDataSharingType() ==
2532  DataSharingClauseType::Private
2533  ? "private"
2534  : "firstprivate")
2535  << " variable and its privatizer op, var type: " << varType
2536  << " vs. privatizer op type: " << privatizerType;
2537  }
2538 
2539  return success();
2540 }
2541 
2542 LogicalResult ParallelOp::verify() {
2543  if (getAllocateVars().size() != getAllocatorVars().size())
2544  return emitError(
2545  "expected equal sizes for allocate and allocator variables");
2546 
2547  if (failed(verifyPrivateVarList(*this)))
2548  return failure();
2549 
2550  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2551  getReductionByref());
2552 }
2553 
2554 LogicalResult ParallelOp::verifyRegions() {
2555  auto distChildOps = getOps<DistributeOp>();
2556  int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2557  if (numDistChildOps > 1)
2558  return emitError()
2559  << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2560 
2561  if (numDistChildOps == 1) {
2562  if (!isComposite())
2563  return emitError()
2564  << "'omp.composite' attribute missing from composite operation";
2565 
2566  auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2567  Operation &distributeOp = **distChildOps.begin();
2568  for (Operation &childOp : getOps()) {
2569  if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2570  continue;
2571 
2572  if (!childOp.hasTrait<OpTrait::IsTerminator>())
2573  return emitError() << "unexpected OpenMP operation inside of composite "
2574  "'omp.parallel': "
2575  << childOp.getName();
2576  }
2577  } else if (isComposite()) {
2578  return emitError()
2579  << "'omp.composite' attribute present in non-composite operation";
2580  }
2581  return success();
2582 }
2583 
2584 //===----------------------------------------------------------------------===//
2585 // TeamsOp
2586 //===----------------------------------------------------------------------===//
2587 
2589  while ((op = op->getParentOp()))
2590  if (isa<OpenMPDialect>(op->getDialect()))
2591  return false;
2592  return true;
2593 }
2594 
2595 void TeamsOp::build(OpBuilder &builder, OperationState &state,
2596  const TeamsOperands &clauses) {
2597  MLIRContext *ctx = builder.getContext();
2598  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2599  TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2600  clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2601  /*private_vars=*/{}, /*private_syms=*/nullptr,
2602  /*private_needs_barrier=*/nullptr, clauses.reductionMod,
2603  clauses.reductionVars,
2604  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2605  makeArrayAttr(ctx, clauses.reductionSyms),
2606  clauses.threadLimit);
2607 }
2608 
2609 LogicalResult TeamsOp::verify() {
2610  // Check parent region
2611  // TODO If nested inside of a target region, also check that it does not
2612  // contain any statements, declarations or directives other than this
2613  // omp.teams construct. The issue is how to support the initialization of
2614  // this operation's own arguments (allow SSA values across omp.target?).
2615  Operation *op = getOperation();
2616  if (!isa<TargetOp>(op->getParentOp()) &&
2618  return emitError("expected to be nested inside of omp.target or not nested "
2619  "in any OpenMP dialect operations");
2620 
2621  // Check for num_teams clause restrictions
2622  if (auto numTeamsLowerBound = getNumTeamsLower()) {
2623  auto numTeamsUpperBound = getNumTeamsUpper();
2624  if (!numTeamsUpperBound)
2625  return emitError("expected num_teams upper bound to be defined if the "
2626  "lower bound is defined");
2627  if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2628  return emitError(
2629  "expected num_teams upper bound and lower bound to be the same type");
2630  }
2631 
2632  // Check for allocate clause restrictions
2633  if (getAllocateVars().size() != getAllocatorVars().size())
2634  return emitError(
2635  "expected equal sizes for allocate and allocator variables");
2636 
2637  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2638  getReductionByref());
2639 }
2640 
2641 //===----------------------------------------------------------------------===//
2642 // SectionOp
2643 //===----------------------------------------------------------------------===//
2644 
2645 OperandRange SectionOp::getPrivateVars() {
2646  return getParentOp().getPrivateVars();
2647 }
2648 
2649 OperandRange SectionOp::getReductionVars() {
2650  return getParentOp().getReductionVars();
2651 }
2652 
2653 //===----------------------------------------------------------------------===//
2654 // SectionsOp
2655 //===----------------------------------------------------------------------===//
2656 
2657 void SectionsOp::build(OpBuilder &builder, OperationState &state,
2658  const SectionsOperands &clauses) {
2659  MLIRContext *ctx = builder.getContext();
2660  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2661  SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2662  clauses.nowait, /*private_vars=*/{},
2663  /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2664  clauses.reductionMod, clauses.reductionVars,
2665  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2666  makeArrayAttr(ctx, clauses.reductionSyms));
2667 }
2668 
2669 LogicalResult SectionsOp::verify() {
2670  if (getAllocateVars().size() != getAllocatorVars().size())
2671  return emitError(
2672  "expected equal sizes for allocate and allocator variables");
2673 
2674  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2675  getReductionByref());
2676 }
2677 
2678 LogicalResult SectionsOp::verifyRegions() {
2679  for (auto &inst : *getRegion().begin()) {
2680  if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2681  return emitOpError()
2682  << "expected omp.section op or terminator op inside region";
2683  }
2684  }
2685 
2686  return success();
2687 }
2688 
2689 //===----------------------------------------------------------------------===//
2690 // SingleOp
2691 //===----------------------------------------------------------------------===//
2692 
2693 void SingleOp::build(OpBuilder &builder, OperationState &state,
2694  const SingleOperands &clauses) {
2695  MLIRContext *ctx = builder.getContext();
2696  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2697  SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2698  clauses.copyprivateVars,
2699  makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2700  /*private_vars=*/{}, /*private_syms=*/nullptr,
2701  /*private_needs_barrier=*/nullptr);
2702 }
2703 
2704 LogicalResult SingleOp::verify() {
2705  // Check for allocate clause restrictions
2706  if (getAllocateVars().size() != getAllocatorVars().size())
2707  return emitError(
2708  "expected equal sizes for allocate and allocator variables");
2709 
2710  return verifyCopyprivateVarList(*this, getCopyprivateVars(),
2711  getCopyprivateSyms());
2712 }
2713 
2714 //===----------------------------------------------------------------------===//
2715 // WorkshareOp
2716 //===----------------------------------------------------------------------===//
2717 
2718 void WorkshareOp::build(OpBuilder &builder, OperationState &state,
2719  const WorkshareOperands &clauses) {
2720  WorkshareOp::build(builder, state, clauses.nowait);
2721 }
2722 
2723 //===----------------------------------------------------------------------===//
2724 // WorkshareLoopWrapperOp
2725 //===----------------------------------------------------------------------===//
2726 
2727 LogicalResult WorkshareLoopWrapperOp::verify() {
2728  if (!(*this)->getParentOfType<WorkshareOp>())
2729  return emitOpError() << "must be nested in an omp.workshare";
2730  return success();
2731 }
2732 
2733 LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2734  if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2735  getNestedWrapper())
2736  return emitOpError() << "expected to be a standalone loop wrapper";
2737 
2738  return success();
2739 }
2740 
2741 //===----------------------------------------------------------------------===//
2742 // LoopWrapperInterface
2743 //===----------------------------------------------------------------------===//
2744 
2745 LogicalResult LoopWrapperInterface::verifyImpl() {
2746  Operation *op = this->getOperation();
2747  if (!op->hasTrait<OpTrait::NoTerminator>() ||
2749  return emitOpError() << "loop wrapper must also have the `NoTerminator` "
2750  "and `SingleBlock` traits";
2751 
2752  if (op->getNumRegions() != 1)
2753  return emitOpError() << "loop wrapper does not contain exactly one region";
2754 
2755  Region &region = op->getRegion(0);
2756  if (range_size(region.getOps()) != 1)
2757  return emitOpError()
2758  << "loop wrapper does not contain exactly one nested op";
2759 
2760  Operation &firstOp = *region.op_begin();
2761  if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2762  return emitOpError() << "nested in loop wrapper is not another loop "
2763  "wrapper or `omp.loop_nest`";
2764 
2765  return success();
2766 }
2767 
2768 //===----------------------------------------------------------------------===//
2769 // LoopOp
2770 //===----------------------------------------------------------------------===//
2771 
2772 void LoopOp::build(OpBuilder &builder, OperationState &state,
2773  const LoopOperands &clauses) {
2774  MLIRContext *ctx = builder.getContext();
2775 
2776  LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2777  makeArrayAttr(ctx, clauses.privateSyms),
2778  clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2779  clauses.reductionMod, clauses.reductionVars,
2780  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2781  makeArrayAttr(ctx, clauses.reductionSyms));
2782 }
2783 
2784 LogicalResult LoopOp::verify() {
2785  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2786  getReductionByref());
2787 }
2788 
2789 LogicalResult LoopOp::verifyRegions() {
2790  if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2791  getNestedWrapper())
2792  return emitOpError() << "expected to be a standalone loop wrapper";
2793 
2794  return success();
2795 }
2796 
2797 //===----------------------------------------------------------------------===//
2798 // WsloopOp
2799 //===----------------------------------------------------------------------===//
2800 
2801 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2802  ArrayRef<NamedAttribute> attributes) {
2803  build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2804  /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
2805  /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
2806  /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
2807  /*private_needs_barrier=*/false,
2808  /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
2809  /*reduction_byref=*/nullptr,
2810  /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
2811  /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
2812  /*schedule_simd=*/false);
2813  state.addAttributes(attributes);
2814 }
2815 
2816 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2817  const WsloopOperands &clauses) {
2818  MLIRContext *ctx = builder.getContext();
2819  // TODO: Store clauses in op: allocateVars, allocatorVars
2820  WsloopOp::build(
2821  builder, state,
2822  /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
2823  clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
2824  clauses.ordered, clauses.privateVars,
2825  makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2826  clauses.reductionMod, clauses.reductionVars,
2827  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2828  makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
2829  clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
2830 }
2831 
2832 LogicalResult WsloopOp::verify() {
2833  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2834  getReductionByref());
2835 }
2836 
2837 LogicalResult WsloopOp::verifyRegions() {
2838  bool isCompositeChildLeaf =
2839  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2840 
2841  if (LoopWrapperInterface nested = getNestedWrapper()) {
2842  if (!isComposite())
2843  return emitError()
2844  << "'omp.composite' attribute missing from composite wrapper";
2845 
2846  // Check for the allowed leaf constructs that may appear in a composite
2847  // construct directly after DO/FOR.
2848  if (!isa<SimdOp>(nested))
2849  return emitError() << "only supported nested wrapper is 'omp.simd'";
2850 
2851  } else if (isComposite() && !isCompositeChildLeaf) {
2852  return emitError()
2853  << "'omp.composite' attribute present in non-composite wrapper";
2854  } else if (!isComposite() && isCompositeChildLeaf) {
2855  return emitError()
2856  << "'omp.composite' attribute missing from composite wrapper";
2857  }
2858 
2859  return success();
2860 }
2861 
2862 //===----------------------------------------------------------------------===//
2863 // Simd construct [2.9.3.1]
2864 //===----------------------------------------------------------------------===//
2865 
2866 void SimdOp::build(OpBuilder &builder, OperationState &state,
2867  const SimdOperands &clauses) {
2868  MLIRContext *ctx = builder.getContext();
2869  // TODO Store clauses in op: linearVars, linearStepVars
2870  SimdOp::build(builder, state, clauses.alignedVars,
2871  makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
2872  /*linear_vars=*/{}, /*linear_step_vars=*/{},
2873  clauses.nontemporalVars, clauses.order, clauses.orderMod,
2874  clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
2875  clauses.privateNeedsBarrier, clauses.reductionMod,
2876  clauses.reductionVars,
2877  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2878  makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
2879  clauses.simdlen);
2880 }
2881 
2882 LogicalResult SimdOp::verify() {
2883  if (getSimdlen().has_value() && getSafelen().has_value() &&
2884  getSimdlen().value() > getSafelen().value())
2885  return emitOpError()
2886  << "simdlen clause and safelen clause are both present, but the "
2887  "simdlen value is not less than or equal to safelen value";
2888 
2889  if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
2890  return failure();
2891 
2892  if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
2893  return failure();
2894 
2895  bool isCompositeChildLeaf =
2896  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2897 
2898  if (!isComposite() && isCompositeChildLeaf)
2899  return emitError()
2900  << "'omp.composite' attribute missing from composite wrapper";
2901 
2902  if (isComposite() && !isCompositeChildLeaf)
2903  return emitError()
2904  << "'omp.composite' attribute present in non-composite wrapper";
2905 
2906  // Firstprivate is not allowed for SIMD in the standard. Check that none of
2907  // the private decls are for firstprivate.
2908  std::optional<ArrayAttr> privateSyms = getPrivateSyms();
2909  if (privateSyms) {
2910  for (const Attribute &sym : *privateSyms) {
2911  auto symRef = cast<SymbolRefAttr>(sym);
2912  omp::PrivateClauseOp privatizer =
2913  SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
2914  getOperation(), symRef);
2915  if (!privatizer)
2916  return emitError() << "Cannot find privatizer '" << symRef << "'";
2917  if (privatizer.getDataSharingType() ==
2918  DataSharingClauseType::FirstPrivate)
2919  return emitError() << "FIRSTPRIVATE cannot be used with SIMD";
2920  }
2921  }
2922 
2923  return success();
2924 }
2925 
2926 LogicalResult SimdOp::verifyRegions() {
2927  if (getNestedWrapper())
2928  return emitOpError() << "must wrap an 'omp.loop_nest' directly";
2929 
2930  return success();
2931 }
2932 
2933 //===----------------------------------------------------------------------===//
2934 // Distribute construct [2.9.4.1]
2935 //===----------------------------------------------------------------------===//
2936 
2937 void DistributeOp::build(OpBuilder &builder, OperationState &state,
2938  const DistributeOperands &clauses) {
2939  DistributeOp::build(builder, state, clauses.allocateVars,
2940  clauses.allocatorVars, clauses.distScheduleStatic,
2941  clauses.distScheduleChunkSize, clauses.order,
2942  clauses.orderMod, clauses.privateVars,
2943  makeArrayAttr(builder.getContext(), clauses.privateSyms),
2944  clauses.privateNeedsBarrier);
2945 }
2946 
2947 LogicalResult DistributeOp::verify() {
2948  if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2949  return emitOpError() << "chunk size set without "
2950  "dist_schedule_static being present";
2951 
2952  if (getAllocateVars().size() != getAllocatorVars().size())
2953  return emitError(
2954  "expected equal sizes for allocate and allocator variables");
2955 
2956  return success();
2957 }
2958 
2959 LogicalResult DistributeOp::verifyRegions() {
2960  if (LoopWrapperInterface nested = getNestedWrapper()) {
2961  if (!isComposite())
2962  return emitError()
2963  << "'omp.composite' attribute missing from composite wrapper";
2964  // Check for the allowed leaf constructs that may appear in a composite
2965  // construct directly after DISTRIBUTE.
2966  if (isa<WsloopOp>(nested)) {
2967  Operation *parentOp = (*this)->getParentOp();
2968  if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
2969  !cast<ComposableOpInterface>(parentOp).isComposite()) {
2970  return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
2971  "when a composite 'omp.parallel' is the direct "
2972  "parent";
2973  }
2974  } else if (!isa<SimdOp>(nested))
2975  return emitError() << "only supported nested wrappers are 'omp.simd' and "
2976  "'omp.wsloop'";
2977  } else if (isComposite()) {
2978  return emitError()
2979  << "'omp.composite' attribute present in non-composite wrapper";
2980  }
2981 
2982  return success();
2983 }
2984 
2985 //===----------------------------------------------------------------------===//
2986 // DeclareMapperOp / DeclareMapperInfoOp
2987 //===----------------------------------------------------------------------===//
2988 
2989 LogicalResult DeclareMapperInfoOp::verify() {
2990  return verifyMapClause(*this, getMapVars());
2991 }
2992 
2993 LogicalResult DeclareMapperOp::verifyRegions() {
2994  if (!llvm::isa_and_present<DeclareMapperInfoOp>(
2995  getRegion().getBlocks().front().getTerminator()))
2996  return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
2997 
2998  return success();
2999 }
3000 
3001 //===----------------------------------------------------------------------===//
3002 // DeclareReductionOp
3003 //===----------------------------------------------------------------------===//
3004 
3005 LogicalResult DeclareReductionOp::verifyRegions() {
3006  if (!getAllocRegion().empty()) {
3007  for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3008  if (yieldOp.getResults().size() != 1 ||
3009  yieldOp.getResults().getTypes()[0] != getType())
3010  return emitOpError() << "expects alloc region to yield a value "
3011  "of the reduction type";
3012  }
3013  }
3014 
3015  if (getInitializerRegion().empty())
3016  return emitOpError() << "expects non-empty initializer region";
3017  Block &initializerEntryBlock = getInitializerRegion().front();
3018 
3019  if (initializerEntryBlock.getNumArguments() == 1) {
3020  if (!getAllocRegion().empty())
3021  return emitOpError() << "expects two arguments to the initializer region "
3022  "when an allocation region is used";
3023  } else if (initializerEntryBlock.getNumArguments() == 2) {
3024  if (getAllocRegion().empty())
3025  return emitOpError() << "expects one argument to the initializer region "
3026  "when no allocation region is used";
3027  } else {
3028  return emitOpError()
3029  << "expects one or two arguments to the initializer region";
3030  }
3031 
3032  for (mlir::Value arg : initializerEntryBlock.getArguments())
3033  if (arg.getType() != getType())
3034  return emitOpError() << "expects initializer region argument to match "
3035  "the reduction type";
3036 
3037  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3038  if (yieldOp.getResults().size() != 1 ||
3039  yieldOp.getResults().getTypes()[0] != getType())
3040  return emitOpError() << "expects initializer region to yield a value "
3041  "of the reduction type";
3042  }
3043 
3044  if (getReductionRegion().empty())
3045  return emitOpError() << "expects non-empty reduction region";
3046  Block &reductionEntryBlock = getReductionRegion().front();
3047  if (reductionEntryBlock.getNumArguments() != 2 ||
3048  reductionEntryBlock.getArgumentTypes()[0] !=
3049  reductionEntryBlock.getArgumentTypes()[1] ||
3050  reductionEntryBlock.getArgumentTypes()[0] != getType())
3051  return emitOpError() << "expects reduction region with two arguments of "
3052  "the reduction type";
3053  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3054  if (yieldOp.getResults().size() != 1 ||
3055  yieldOp.getResults().getTypes()[0] != getType())
3056  return emitOpError() << "expects reduction region to yield a value "
3057  "of the reduction type";
3058  }
3059 
3060  if (!getAtomicReductionRegion().empty()) {
3061  Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
3062  if (atomicReductionEntryBlock.getNumArguments() != 2 ||
3063  atomicReductionEntryBlock.getArgumentTypes()[0] !=
3064  atomicReductionEntryBlock.getArgumentTypes()[1])
3065  return emitOpError() << "expects atomic reduction region with two "
3066  "arguments of the same type";
3067  auto ptrType = llvm::dyn_cast<PointerLikeType>(
3068  atomicReductionEntryBlock.getArgumentTypes()[0]);
3069  if (!ptrType ||
3070  (ptrType.getElementType() && ptrType.getElementType() != getType()))
3071  return emitOpError() << "expects atomic reduction region arguments to "
3072  "be accumulators containing the reduction type";
3073  }
3074 
3075  if (getCleanupRegion().empty())
3076  return success();
3077  Block &cleanupEntryBlock = getCleanupRegion().front();
3078  if (cleanupEntryBlock.getNumArguments() != 1 ||
3079  cleanupEntryBlock.getArgument(0).getType() != getType())
3080  return emitOpError() << "expects cleanup region with one argument "
3081  "of the reduction type";
3082 
3083  return success();
3084 }
3085 
3086 //===----------------------------------------------------------------------===//
3087 // TaskOp
3088 //===----------------------------------------------------------------------===//
3089 
3090 void TaskOp::build(OpBuilder &builder, OperationState &state,
3091  const TaskOperands &clauses) {
3092  MLIRContext *ctx = builder.getContext();
3093  TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3094  makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3095  clauses.final, clauses.ifExpr, clauses.inReductionVars,
3096  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3097  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3098  clauses.priority, /*private_vars=*/clauses.privateVars,
3099  /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3100  clauses.privateNeedsBarrier, clauses.untied,
3101  clauses.eventHandle);
3102 }
3103 
3104 LogicalResult TaskOp::verify() {
3105  LogicalResult verifyDependVars =
3106  verifyDependVarList(*this, getDependKinds(), getDependVars());
3107  return failed(verifyDependVars)
3108  ? verifyDependVars
3109  : verifyReductionVarList(*this, getInReductionSyms(),
3110  getInReductionVars(),
3111  getInReductionByref());
3112 }
3113 
3114 //===----------------------------------------------------------------------===//
3115 // TaskgroupOp
3116 //===----------------------------------------------------------------------===//
3117 
3118 void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
3119  const TaskgroupOperands &clauses) {
3120  MLIRContext *ctx = builder.getContext();
3121  TaskgroupOp::build(builder, state, clauses.allocateVars,
3122  clauses.allocatorVars, clauses.taskReductionVars,
3123  makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
3124  makeArrayAttr(ctx, clauses.taskReductionSyms));
3125 }
3126 
3127 LogicalResult TaskgroupOp::verify() {
3128  return verifyReductionVarList(*this, getTaskReductionSyms(),
3129  getTaskReductionVars(),
3130  getTaskReductionByref());
3131 }
3132 
3133 //===----------------------------------------------------------------------===//
3134 // TaskloopOp
3135 //===----------------------------------------------------------------------===//
3136 
3137 void TaskloopOp::build(OpBuilder &builder, OperationState &state,
3138  const TaskloopOperands &clauses) {
3139  MLIRContext *ctx = builder.getContext();
3140  TaskloopOp::build(
3141  builder, state, clauses.allocateVars, clauses.allocatorVars,
3142  clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3143  clauses.inReductionVars,
3144  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3145  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3146  clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3147  /*private_vars=*/clauses.privateVars,
3148  /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3149  clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3150  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3151  makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
3152 }
3153 
3154 LogicalResult TaskloopOp::verify() {
3155  if (getAllocateVars().size() != getAllocatorVars().size())
3156  return emitError(
3157  "expected equal sizes for allocate and allocator variables");
3158  if (failed(verifyReductionVarList(*this, getReductionSyms(),
3159  getReductionVars(), getReductionByref())) ||
3160  failed(verifyReductionVarList(*this, getInReductionSyms(),
3161  getInReductionVars(),
3162  getInReductionByref())))
3163  return failure();
3164 
3165  if (!getReductionVars().empty() && getNogroup())
3166  return emitError("if a reduction clause is present on the taskloop "
3167  "directive, the nogroup clause must not be specified");
3168  for (auto var : getReductionVars()) {
3169  if (llvm::is_contained(getInReductionVars(), var))
3170  return emitError("the same list item cannot appear in both a reduction "
3171  "and an in_reduction clause");
3172  }
3173 
3174  if (getGrainsize() && getNumTasks()) {
3175  return emitError(
3176  "the grainsize clause and num_tasks clause are mutually exclusive and "
3177  "may not appear on the same taskloop directive");
3178  }
3179 
3180  return success();
3181 }
3182 
3183 LogicalResult TaskloopOp::verifyRegions() {
3184  if (LoopWrapperInterface nested = getNestedWrapper()) {
3185  if (!isComposite())
3186  return emitError()
3187  << "'omp.composite' attribute missing from composite wrapper";
3188 
3189  // Check for the allowed leaf constructs that may appear in a composite
3190  // construct directly after TASKLOOP.
3191  if (!isa<SimdOp>(nested))
3192  return emitError() << "only supported nested wrapper is 'omp.simd'";
3193  } else if (isComposite()) {
3194  return emitError()
3195  << "'omp.composite' attribute present in non-composite wrapper";
3196  }
3197 
3198  return success();
3199 }
3200 
3201 //===----------------------------------------------------------------------===//
3202 // LoopNestOp
3203 //===----------------------------------------------------------------------===//
3204 
3205 ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
3206  // Parse an opening `(` followed by induction variables followed by `)`
3209  Type loopVarType;
3210  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
3211  parser.parseColonType(loopVarType) ||
3212  // Parse loop bounds.
3213  parser.parseEqual() ||
3214  parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
3215  parser.parseKeyword("to") ||
3216  parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
3217  return failure();
3218 
3219  for (auto &iv : ivs)
3220  iv.type = loopVarType;
3221 
3222  auto *ctx = parser.getBuilder().getContext();
3223  // Parse "inclusive" flag.
3224  if (succeeded(parser.parseOptionalKeyword("inclusive")))
3225  result.addAttribute("loop_inclusive", UnitAttr::get(ctx));
3226 
3227  // Parse step values.
3229  if (parser.parseKeyword("step") ||
3230  parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
3231  return failure();
3232 
3233  // Parse collapse
3234  int64_t value = 0;
3235  if (!parser.parseOptionalKeyword("collapse") &&
3236  (parser.parseLParen() || parser.parseInteger(value) ||
3237  parser.parseRParen()))
3238  return failure();
3239  if (value > 1)
3240  result.addAttribute(
3241  "collapse_num_loops",
3242  IntegerAttr::get(parser.getBuilder().getI64Type(), value));
3243 
3244  // Parse tiles
3245  SmallVector<int64_t> tiles;
3246  auto parseTiles = [&]() -> ParseResult {
3247  int64_t tile;
3248  if (parser.parseInteger(tile))
3249  return failure();
3250  tiles.push_back(tile);
3251  return success();
3252  };
3253 
3254  if (!parser.parseOptionalKeyword("tiles") &&
3255  (parser.parseLParen() || parser.parseCommaSeparatedList(parseTiles) ||
3256  parser.parseRParen()))
3257  return failure();
3258 
3259  if (tiles.size() > 0)
3260  result.addAttribute("tile_sizes", DenseI64ArrayAttr::get(ctx, tiles));
3261 
3262  // Parse the body.
3263  Region *region = result.addRegion();
3264  if (parser.parseRegion(*region, ivs))
3265  return failure();
3266 
3267  // Resolve operands.
3268  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
3269  parser.resolveOperands(ubs, loopVarType, result.operands) ||
3270  parser.resolveOperands(steps, loopVarType, result.operands))
3271  return failure();
3272 
3273  // Parse the optional attribute list.
3274  return parser.parseOptionalAttrDict(result.attributes);
3275 }
3276 
3278  Region &region = getRegion();
3279  auto args = region.getArguments();
3280  p << " (" << args << ") : " << args[0].getType() << " = ("
3281  << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
3282  if (getLoopInclusive())
3283  p << "inclusive ";
3284  p << "step (" << getLoopSteps() << ") ";
3285  if (int64_t numCollapse = getCollapseNumLoops())
3286  if (numCollapse > 1)
3287  p << "collapse(" << numCollapse << ") ";
3288 
3289  if (const auto tiles = getTileSizes())
3290  p << "tiles(" << tiles.value() << ") ";
3291 
3292  p.printRegion(region, /*printEntryBlockArgs=*/false);
3293 }
3294 
3295 void LoopNestOp::build(OpBuilder &builder, OperationState &state,
3296  const LoopNestOperands &clauses) {
3297  MLIRContext *ctx = builder.getContext();
3298  LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3299  clauses.loopLowerBounds, clauses.loopUpperBounds,
3300  clauses.loopSteps, clauses.loopInclusive,
3301  makeDenseI64ArrayAttr(ctx, clauses.tileSizes));
3302 }
3303 
3304 LogicalResult LoopNestOp::verify() {
3305  if (getLoopLowerBounds().empty())
3306  return emitOpError() << "must represent at least one loop";
3307 
3308  if (getLoopLowerBounds().size() != getIVs().size())
3309  return emitOpError() << "number of range arguments and IVs do not match";
3310 
3311  for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3312  if (lb.getType() != iv.getType())
3313  return emitOpError()
3314  << "range argument type does not match corresponding IV type";
3315  }
3316 
3317  uint64_t numIVs = getIVs().size();
3318 
3319  if (const auto &numCollapse = getCollapseNumLoops())
3320  if (numCollapse > numIVs)
3321  return emitOpError()
3322  << "collapse value is larger than the number of loops";
3323 
3324  if (const auto &tiles = getTileSizes())
3325  if (tiles.value().size() > numIVs)
3326  return emitOpError() << "too few canonical loops for tile dimensions";
3327 
3328  if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3329  return emitOpError() << "expects parent op to be a loop wrapper";
3330 
3331  return success();
3332 }
3333 
3334 void LoopNestOp::gatherWrappers(
3336  Operation *parent = (*this)->getParentOp();
3337  while (auto wrapper =
3338  llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3339  wrappers.push_back(wrapper);
3340  parent = parent->getParentOp();
3341  }
3342 }
3343 
3344 //===----------------------------------------------------------------------===//
3345 // OpenMP canonical loop handling
3346 //===----------------------------------------------------------------------===//
3347 
3348 std::tuple<NewCliOp, OpOperand *, OpOperand *>
3350 
3351  // Defining a CLI for a generated loop is optional; if there is none then
3352  // there is no followup-tranformation
3353  if (!cli)
3354  return {{}, nullptr, nullptr};
3355 
3356  assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3357  "Unexpected type of cli");
3358 
3359  NewCliOp create = cast<NewCliOp>(cli.getDefiningOp());
3360  OpOperand *gen = nullptr;
3361  OpOperand *cons = nullptr;
3362  for (OpOperand &use : cli.getUses()) {
3363  auto op = cast<LoopTransformationInterface>(use.getOwner());
3364 
3365  unsigned opnum = use.getOperandNumber();
3366  if (op.isGeneratee(opnum)) {
3367  assert(!gen && "Each CLI may have at most one def");
3368  gen = &use;
3369  } else if (op.isApplyee(opnum)) {
3370  assert(!cons && "Each CLI may have at most one consumer");
3371  cons = &use;
3372  } else {
3373  llvm_unreachable("Unexpected operand for a CLI");
3374  }
3375  }
3376 
3377  return {create, gen, cons};
3378 }
3379 
3380 void NewCliOp::build(::mlir::OpBuilder &odsBuilder,
3381  ::mlir::OperationState &odsState) {
3382  odsState.addTypes(CanonicalLoopInfoType::get(odsBuilder.getContext()));
3383 }
3384 
3385 void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3386  Value result = getResult();
3387  auto [newCli, gen, cons] = decodeCli(result);
3388 
3389  // Structured binding `gen` cannot be captured in lambdas before C++20
3390  OpOperand *generator = gen;
3391 
3392  // Derive the CLI variable name from its generator:
3393  // * "canonloop" for omp.canonical_loop
3394  // * custom name for loop transformation generatees
3395  // * "cli" as fallback if no generator
3396  // * "_r<idx>" suffix for nested loops, where <idx> is the sequential order
3397  // at that level
3398  // * "_s<idx>" suffix for operations with multiple regions, where <idx> is
3399  // the index of that region
3400  std::string cliName{"cli"};
3401  if (gen) {
3402  cliName =
3403  TypeSwitch<Operation *, std::string>(gen->getOwner())
3404  .Case([&](CanonicalLoopOp op) {
3405  return generateLoopNestingName("canonloop", op);
3406  })
3407  .Case([&](UnrollHeuristicOp op) -> std::string {
3408  llvm_unreachable("heuristic unrolling does not generate a loop");
3409  })
3410  .Case([&](TileOp op) -> std::string {
3411  auto [generateesFirst, generateesCount] =
3412  op.getGenerateesODSOperandIndexAndLength();
3413  unsigned firstGrid = generateesFirst;
3414  unsigned firstIntratile = generateesFirst + generateesCount / 2;
3415  unsigned end = generateesFirst + generateesCount;
3416  unsigned opnum = generator->getOperandNumber();
3417  // In the OpenMP apply and looprange clauses, indices are 1-based
3418  if (firstGrid <= opnum && opnum < firstIntratile) {
3419  unsigned gridnum = opnum - firstGrid + 1;
3420  return ("grid" + Twine(gridnum)).str();
3421  }
3422  if (firstIntratile <= opnum && opnum < end) {
3423  unsigned intratilenum = opnum - firstIntratile + 1;
3424  return ("intratile" + Twine(intratilenum)).str();
3425  }
3426  llvm_unreachable("Unexpected generatee argument");
3427  })
3428  .DefaultUnreachable("TODO: Custom name for this operation");
3429  }
3430 
3431  setNameFn(result, cliName);
3432 }
3433 
3434 LogicalResult NewCliOp::verify() {
3435  Value cli = getResult();
3436 
3437  assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3438  "Unexpected type of cli");
3439 
3440  // Check that the CLI is used in at most generator and one consumer
3441  OpOperand *gen = nullptr;
3442  OpOperand *cons = nullptr;
3443  for (mlir::OpOperand &use : cli.getUses()) {
3444  auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3445 
3446  unsigned opnum = use.getOperandNumber();
3447  if (op.isGeneratee(opnum)) {
3448  if (gen) {
3449  InFlightDiagnostic error =
3450  emitOpError("CLI must have at most one generator");
3451  error.attachNote(gen->getOwner()->getLoc())
3452  .append("first generator here:");
3453  error.attachNote(use.getOwner()->getLoc())
3454  .append("second generator here:");
3455  return error;
3456  }
3457 
3458  gen = &use;
3459  } else if (op.isApplyee(opnum)) {
3460  if (cons) {
3461  InFlightDiagnostic error =
3462  emitOpError("CLI must have at most one consumer");
3463  error.attachNote(cons->getOwner()->getLoc())
3464  .append("first consumer here:")
3465  .appendOp(*cons->getOwner(),
3466  OpPrintingFlags().printGenericOpForm());
3467  error.attachNote(use.getOwner()->getLoc())
3468  .append("second consumer here:")
3469  .appendOp(*use.getOwner(), OpPrintingFlags().printGenericOpForm());
3470  return error;
3471  }
3472 
3473  cons = &use;
3474  } else {
3475  llvm_unreachable("Unexpected operand for a CLI");
3476  }
3477  }
3478 
3479  // If the CLI is source of a transformation, it must have a generator
3480  if (cons && !gen) {
3481  InFlightDiagnostic error = emitOpError("CLI has no generator");
3482  error.attachNote(cons->getOwner()->getLoc())
3483  .append("see consumer here: ")
3484  .appendOp(*cons->getOwner(), OpPrintingFlags().printGenericOpForm());
3485  return error;
3486  }
3487 
3488  return success();
3489 }
3490 
3491 void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3492  Value tripCount) {
3493  odsState.addOperands(tripCount);
3494  odsState.addOperands(Value());
3495  (void)odsState.addRegion();
3496 }
3497 
3498 void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3499  Value tripCount, ::mlir::Value cli) {
3500  odsState.addOperands(tripCount);
3501  odsState.addOperands(cli);
3502  (void)odsState.addRegion();
3503 }
3504 
3505 void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
3506  setNameFn(&getRegion().front(), "body_entry");
3507 }
3508 
3509 void CanonicalLoopOp::getAsmBlockArgumentNames(Region &region,
3510  OpAsmSetValueNameFn setNameFn) {
3511  std::string ivName = generateLoopNestingName("iv", *this);
3512  setNameFn(region.getArgument(0), ivName);
3513 }
3514 
3516  if (getCli())
3517  p << '(' << getCli() << ')';
3518  p << ' ' << getInductionVar() << " : " << getInductionVar().getType()
3519  << " in range(" << getTripCount() << ") ";
3520 
3521  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3522  /*printBlockTerminators=*/true);
3523 
3524  p.printOptionalAttrDict((*this)->getAttrs());
3525 }
3526 
3527 mlir::ParseResult CanonicalLoopOp::parse(::mlir::OpAsmParser &parser,
3528  ::mlir::OperationState &result) {
3529  CanonicalLoopInfoType cliType =
3531 
3532  // Parse (optional) omp.cli identifier
3534  SmallVector<mlir::Value, 1> cliOperand;
3535  if (!parser.parseOptionalLParen()) {
3536  if (parser.parseOperand(cli) ||
3537  parser.resolveOperand(cli, cliType, cliOperand) || parser.parseRParen())
3538  return failure();
3539  }
3540 
3541  // We derive the type of tripCount from inductionVariable. MLIR requires the
3542  // type of tripCount to be known when calling resolveOperand so we have parse
3543  // the type before processing the inductionVariable.
3544  OpAsmParser::Argument inductionVariable;
3546  if (parser.parseArgument(inductionVariable, /*allowType*/ true) ||
3547  parser.parseKeyword("in") || parser.parseKeyword("range") ||
3548  parser.parseLParen() || parser.parseOperand(tripcount) ||
3549  parser.parseRParen() ||
3550  parser.resolveOperand(tripcount, inductionVariable.type, result.operands))
3551  return failure();
3552 
3553  // Parse the loop body.
3554  Region *region = result.addRegion();
3555  if (parser.parseRegion(*region, {inductionVariable}))
3556  return failure();
3557 
3558  // We parsed the cli operand forst, but because it is optional, it must be
3559  // last in the operand list.
3560  result.operands.append(cliOperand);
3561 
3562  // Parse the optional attribute list.
3563  if (parser.parseOptionalAttrDict(result.attributes))
3564  return failure();
3565 
3566  return mlir::success();
3567 }
3568 
3569 LogicalResult CanonicalLoopOp::verify() {
3570  // The region's entry must accept the induction variable
3571  // It can also be empty if just created
3572  if (!getRegion().empty()) {
3573  Region &region = getRegion();
3574  if (region.getNumArguments() != 1)
3575  return emitOpError(
3576  "Canonical loop region must have exactly one argument");
3577 
3578  if (getInductionVar().getType() != getTripCount().getType())
3579  return emitOpError(
3580  "Region argument must be the same type as the trip count");
3581  }
3582 
3583  return success();
3584 }
3585 
3586 Value CanonicalLoopOp::getInductionVar() { return getRegion().getArgument(0); }
3587 
3588 std::pair<unsigned, unsigned>
3589 CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3590  // No applyees
3591  return {0, 0};
3592 }
3593 
3594 std::pair<unsigned, unsigned>
3595 CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3596  return getODSOperandIndexAndLength(odsIndex_cli);
3597 }
3598 
3599 //===----------------------------------------------------------------------===//
3600 // UnrollHeuristicOp
3601 //===----------------------------------------------------------------------===//
3602 
3603 void UnrollHeuristicOp::build(::mlir::OpBuilder &odsBuilder,
3604  ::mlir::OperationState &odsState,
3605  ::mlir::Value cli) {
3606  odsState.addOperands(cli);
3607 }
3608 
3610  p << '(' << getApplyee() << ')';
3611 
3612  p.printOptionalAttrDict((*this)->getAttrs());
3613 }
3614 
3615 mlir::ParseResult UnrollHeuristicOp::parse(::mlir::OpAsmParser &parser,
3616  ::mlir::OperationState &result) {
3617  auto cliType = CanonicalLoopInfoType::get(parser.getContext());
3618 
3619  if (parser.parseLParen())
3620  return failure();
3621 
3623  if (parser.parseOperand(applyee) ||
3624  parser.resolveOperand(applyee, cliType, result.operands))
3625  return failure();
3626 
3627  if (parser.parseRParen())
3628  return failure();
3629 
3630  // Optional output loop (full unrolling has none)
3631  if (!parser.parseOptionalArrow()) {
3632  if (parser.parseLParen() || parser.parseRParen())
3633  return failure();
3634  }
3635 
3636  // Parse the optional attribute list.
3637  if (parser.parseOptionalAttrDict(result.attributes))
3638  return failure();
3639 
3640  return mlir::success();
3641 }
3642 
3643 std::pair<unsigned, unsigned>
3644 UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3645  return getODSOperandIndexAndLength(odsIndex_applyee);
3646 }
3647 
3648 std::pair<unsigned, unsigned>
3649 UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3650  return {0, 0};
3651 }
3652 
3653 //===----------------------------------------------------------------------===//
3654 // TileOp
3655 //===----------------------------------------------------------------------===//
3656 
3657 static void printLoopTransformClis(OpAsmPrinter &p, TileOp op,
3658  OperandRange generatees,
3659  OperandRange applyees) {
3660  if (!generatees.empty())
3661  p << '(' << llvm::interleaved(generatees) << ')';
3662 
3663  if (!applyees.empty())
3664  p << " <- (" << llvm::interleaved(applyees) << ')';
3665 }
3666 
3667 static ParseResult parseLoopTransformClis(
3668  OpAsmParser &parser,
3671  if (parser.parseOptionalLess()) {
3672  // Syntax 1: generatees present
3673 
3674  if (parser.parseOperandList(generateesOperands,
3676  return failure();
3677 
3678  if (parser.parseLess())
3679  return failure();
3680  } else {
3681  // Syntax 2: generatees omitted
3682  }
3683 
3684  // Parse `<-` (`<` has already been parsed)
3685  if (parser.parseMinus())
3686  return failure();
3687 
3688  if (parser.parseOperandList(applyeesOperands,
3690  return failure();
3691 
3692  return success();
3693 }
3694 
3695 LogicalResult TileOp::verify() {
3696  if (getApplyees().empty())
3697  return emitOpError() << "must apply to at least one loop";
3698 
3699  if (getSizes().size() != getApplyees().size())
3700  return emitOpError() << "there must be one tile size for each applyee";
3701 
3702  if (!getGeneratees().empty() &&
3703  2 * getSizes().size() != getGeneratees().size())
3704  return emitOpError()
3705  << "expecting two times the number of generatees than applyees";
3706 
3707  DenseSet<Value> parentIVs;
3708 
3709  Value parent = getApplyees().front();
3710  for (auto &&applyee : llvm::drop_begin(getApplyees())) {
3711  auto [parentCreate, parentGen, parentCons] = decodeCli(parent);
3712  auto [create, gen, cons] = decodeCli(applyee);
3713 
3714  if (!parentGen)
3715  return emitOpError() << "applyee CLI has no generator";
3716 
3717  auto parentLoop = dyn_cast_or_null<CanonicalLoopOp>(parentGen->getOwner());
3718  if (!parentGen)
3719  return emitOpError()
3720  << "currently only supports omp.canonical_loop as applyee";
3721 
3722  parentIVs.insert(parentLoop.getInductionVar());
3723 
3724  if (!gen)
3725  return emitOpError() << "applyee CLI has no generator";
3726  auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
3727  if (!loop)
3728  return emitOpError()
3729  << "currently only supports omp.canonical_loop as applyee";
3730 
3731  // Canonical loop must be perfectly nested, i.e. the body of the parent must
3732  // only contain the omp.canonical_loop of the nested loops, and
3733  // omp.terminator
3734  bool isPerfectlyNested = [&]() {
3735  auto &parentBody = parentLoop.getRegion();
3736  if (!parentBody.hasOneBlock())
3737  return false;
3738  auto &parentBlock = parentBody.getBlocks().front();
3739 
3740  auto nestedLoopIt = parentBlock.begin();
3741  if (nestedLoopIt == parentBlock.end() ||
3742  (&*nestedLoopIt != loop.getOperation()))
3743  return false;
3744 
3745  auto termIt = std::next(nestedLoopIt);
3746  if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
3747  return false;
3748 
3749  if (std::next(termIt) != parentBlock.end())
3750  return false;
3751 
3752  return true;
3753  }();
3754  if (!isPerfectlyNested)
3755  return emitOpError() << "tiled loop nest must be perfectly nested";
3756 
3757  if (parentIVs.contains(loop.getTripCount()))
3758  return emitOpError() << "tiled loop nest must be rectangular";
3759 
3760  parent = applyee;
3761  }
3762 
3763  // TODO: The tile sizes must be computed before the loop, but checking this
3764  // requires dominance analysis. For instance:
3765  //
3766  // %canonloop = omp.new_cli
3767  // omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
3768  // // write to %x
3769  // omp.terminator
3770  // }
3771  // %ts = llvm.load %x
3772  // omp.tile <- (%canonloop) sizes(%ts : i32)
3773 
3774  return success();
3775 }
3776 
3777 std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
3778  return getODSOperandIndexAndLength(odsIndex_applyees);
3779 }
3780 
3781 std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
3782  return getODSOperandIndexAndLength(odsIndex_generatees);
3783 }
3784 
3785 //===----------------------------------------------------------------------===//
3786 // Critical construct (2.17.1)
3787 //===----------------------------------------------------------------------===//
3788 
3789 void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
3790  const CriticalDeclareOperands &clauses) {
3791  CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
3792 }
3793 
3794 LogicalResult CriticalDeclareOp::verify() {
3795  return verifySynchronizationHint(*this, getHint());
3796 }
3797 
3798 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3799  if (getNameAttr()) {
3800  SymbolRefAttr symbolRef = getNameAttr();
3801  auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
3802  *this, symbolRef);
3803  if (!decl) {
3804  return emitOpError() << "expected symbol reference " << symbolRef
3805  << " to point to a critical declaration";
3806  }
3807  }
3808 
3809  return success();
3810 }
3811 
3812 //===----------------------------------------------------------------------===//
3813 // Ordered construct
3814 //===----------------------------------------------------------------------===//
3815 
3816 static LogicalResult verifyOrderedParent(Operation &op) {
3817  bool hasRegion = op.getNumRegions() > 0;
3818  auto loopOp = op.getParentOfType<LoopNestOp>();
3819  if (!loopOp) {
3820  if (hasRegion)
3821  return success();
3822 
3823  // TODO: Consider if this needs to be the case only for the standalone
3824  // variant of the ordered construct.
3825  return op.emitOpError() << "must be nested inside of a loop";
3826  }
3827 
3828  Operation *wrapper = loopOp->getParentOp();
3829  if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
3830  IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
3831  if (!orderedAttr)
3832  return op.emitOpError() << "the enclosing worksharing-loop region must "
3833  "have an ordered clause";
3834 
3835  if (hasRegion && orderedAttr.getInt() != 0)
3836  return op.emitOpError() << "the enclosing loop's ordered clause must not "
3837  "have a parameter present";
3838 
3839  if (!hasRegion && orderedAttr.getInt() == 0)
3840  return op.emitOpError() << "the enclosing loop's ordered clause must "
3841  "have a parameter present";
3842  } else if (!isa<SimdOp>(wrapper)) {
3843  return op.emitOpError() << "must be nested inside of a worksharing, simd "
3844  "or worksharing simd loop";
3845  }
3846  return success();
3847 }
3848 
3849 void OrderedOp::build(OpBuilder &builder, OperationState &state,
3850  const OrderedOperands &clauses) {
3851  OrderedOp::build(builder, state, clauses.doacrossDependType,
3852  clauses.doacrossNumLoops, clauses.doacrossDependVars);
3853 }
3854 
3855 LogicalResult OrderedOp::verify() {
3856  if (failed(verifyOrderedParent(**this)))
3857  return failure();
3858 
3859  auto wrapper = (*this)->getParentOfType<WsloopOp>();
3860  if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
3861  return emitOpError() << "number of variables in depend clause does not "
3862  << "match number of iteration variables in the "
3863  << "doacross loop";
3864 
3865  return success();
3866 }
3867 
3868 void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
3869  const OrderedRegionOperands &clauses) {
3870  OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
3871 }
3872 
3873 LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
3874 
3875 //===----------------------------------------------------------------------===//
3876 // TaskwaitOp
3877 //===----------------------------------------------------------------------===//
3878 
3879 void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
3880  const TaskwaitOperands &clauses) {
3881  // TODO Store clauses in op: dependKinds, dependVars, nowait.
3882  TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
3883  /*depend_vars=*/{}, /*nowait=*/nullptr);
3884 }
3885 
3886 //===----------------------------------------------------------------------===//
3887 // Verifier for AtomicReadOp
3888 //===----------------------------------------------------------------------===//
3889 
3890 LogicalResult AtomicReadOp::verify() {
3891  if (verifyCommon().failed())
3892  return mlir::failure();
3893 
3894  if (auto mo = getMemoryOrder()) {
3895  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3896  *mo == ClauseMemoryOrderKind::Release) {
3897  return emitError(
3898  "memory-order must not be acq_rel or release for atomic reads");
3899  }
3900  }
3901  return verifySynchronizationHint(*this, getHint());
3902 }
3903 
3904 //===----------------------------------------------------------------------===//
3905 // Verifier for AtomicWriteOp
3906 //===----------------------------------------------------------------------===//
3907 
3908 LogicalResult AtomicWriteOp::verify() {
3909  if (verifyCommon().failed())
3910  return mlir::failure();
3911 
3912  if (auto mo = getMemoryOrder()) {
3913  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3914  *mo == ClauseMemoryOrderKind::Acquire) {
3915  return emitError(
3916  "memory-order must not be acq_rel or acquire for atomic writes");
3917  }
3918  }
3919  return verifySynchronizationHint(*this, getHint());
3920 }
3921 
3922 //===----------------------------------------------------------------------===//
3923 // Verifier for AtomicUpdateOp
3924 //===----------------------------------------------------------------------===//
3925 
3926 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3927  PatternRewriter &rewriter) {
3928  if (op.isNoOp()) {
3929  rewriter.eraseOp(op);
3930  return success();
3931  }
3932  if (Value writeVal = op.getWriteOpVal()) {
3933  rewriter.replaceOpWithNewOp<AtomicWriteOp>(
3934  op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
3935  return success();
3936  }
3937  return failure();
3938 }
3939 
3940 LogicalResult AtomicUpdateOp::verify() {
3941  if (verifyCommon().failed())
3942  return mlir::failure();
3943 
3944  if (auto mo = getMemoryOrder()) {
3945  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3946  *mo == ClauseMemoryOrderKind::Acquire) {
3947  return emitError(
3948  "memory-order must not be acq_rel or acquire for atomic updates");
3949  }
3950  }
3951 
3952  return verifySynchronizationHint(*this, getHint());
3953 }
3954 
3955 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3956 
3957 //===----------------------------------------------------------------------===//
3958 // Verifier for AtomicCaptureOp
3959 //===----------------------------------------------------------------------===//
3960 
3961 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3962  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3963  return op;
3964  return dyn_cast<AtomicReadOp>(getSecondOp());
3965 }
3966 
3967 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3968  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3969  return op;
3970  return dyn_cast<AtomicWriteOp>(getSecondOp());
3971 }
3972 
3973 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3974  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3975  return op;
3976  return dyn_cast<AtomicUpdateOp>(getSecondOp());
3977 }
3978 
3979 LogicalResult AtomicCaptureOp::verify() {
3980  return verifySynchronizationHint(*this, getHint());
3981 }
3982 
3983 LogicalResult AtomicCaptureOp::verifyRegions() {
3984  if (verifyRegionsCommon().failed())
3985  return mlir::failure();
3986 
3987  if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
3988  return emitOpError(
3989  "operations inside capture region must not have hint clause");
3990 
3991  if (getFirstOp()->getAttr("memory_order") ||
3992  getSecondOp()->getAttr("memory_order"))
3993  return emitOpError(
3994  "operations inside capture region must not have memory_order clause");
3995  return success();
3996 }
3997 
3998 //===----------------------------------------------------------------------===//
3999 // CancelOp
4000 //===----------------------------------------------------------------------===//
4001 
4002 void CancelOp::build(OpBuilder &builder, OperationState &state,
4003  const CancelOperands &clauses) {
4004  CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4005 }
4006 
4008  Operation *parent = thisOp->getParentOp();
4009  while (parent) {
4010  if (parent->getDialect() == thisOp->getDialect())
4011  return parent;
4012  parent = parent->getParentOp();
4013  }
4014  return nullptr;
4015 }
4016 
4017 LogicalResult CancelOp::verify() {
4018  ClauseCancellationConstructType cct = getCancelDirective();
4019  // The next OpenMP operation in the chain of parents
4020  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4021  if (!structuralParent)
4022  return emitOpError() << "Orphaned cancel construct";
4023 
4024  if ((cct == ClauseCancellationConstructType::Parallel) &&
4025  !mlir::isa<ParallelOp>(structuralParent)) {
4026  return emitOpError() << "cancel parallel must appear "
4027  << "inside a parallel region";
4028  }
4029  if (cct == ClauseCancellationConstructType::Loop) {
4030  // structural parent will be omp.loop_nest, directly nested inside
4031  // omp.wsloop
4032  auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
4033 
4034  if (!wsloopOp) {
4035  return emitOpError()
4036  << "cancel loop must appear inside a worksharing-loop region";
4037  }
4038  if (wsloopOp.getNowaitAttr()) {
4039  return emitError() << "A worksharing construct that is canceled "
4040  << "must not have a nowait clause";
4041  }
4042  if (wsloopOp.getOrderedAttr()) {
4043  return emitError() << "A worksharing construct that is canceled "
4044  << "must not have an ordered clause";
4045  }
4046 
4047  } else if (cct == ClauseCancellationConstructType::Sections) {
4048  // structural parent will be an omp.section, directly nested inside
4049  // omp.sections
4050  auto sectionsOp =
4051  mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
4052  if (!sectionsOp) {
4053  return emitOpError() << "cancel sections must appear "
4054  << "inside a sections region";
4055  }
4056  if (sectionsOp.getNowait()) {
4057  return emitError() << "A sections construct that is canceled "
4058  << "must not have a nowait clause";
4059  }
4060  }
4061  if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4062  (!mlir::isa<omp::TaskOp>(structuralParent) &&
4063  !mlir::isa<omp::TaskloopOp>(structuralParent->getParentOp()))) {
4064  return emitOpError() << "cancel taskgroup must appear "
4065  << "inside a task region";
4066  }
4067  return success();
4068 }
4069 
4070 //===----------------------------------------------------------------------===//
4071 // CancellationPointOp
4072 //===----------------------------------------------------------------------===//
4073 
4074 void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
4075  const CancellationPointOperands &clauses) {
4076  CancellationPointOp::build(builder, state, clauses.cancelDirective);
4077 }
4078 
4079 LogicalResult CancellationPointOp::verify() {
4080  ClauseCancellationConstructType cct = getCancelDirective();
4081  // The next OpenMP operation in the chain of parents
4082  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4083  if (!structuralParent)
4084  return emitOpError() << "Orphaned cancellation point";
4085 
4086  if ((cct == ClauseCancellationConstructType::Parallel) &&
4087  !mlir::isa<ParallelOp>(structuralParent)) {
4088  return emitOpError() << "cancellation point parallel must appear "
4089  << "inside a parallel region";
4090  }
4091  // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
4092  // find the wsloop
4093  if ((cct == ClauseCancellationConstructType::Loop) &&
4094  !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
4095  return emitOpError() << "cancellation point loop must appear "
4096  << "inside a worksharing-loop region";
4097  }
4098  if ((cct == ClauseCancellationConstructType::Sections) &&
4099  !mlir::isa<omp::SectionOp>(structuralParent)) {
4100  return emitOpError() << "cancellation point sections must appear "
4101  << "inside a sections region";
4102  }
4103  if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4104  !mlir::isa<omp::TaskOp>(structuralParent)) {
4105  return emitOpError() << "cancellation point taskgroup must appear "
4106  << "inside a task region";
4107  }
4108  return success();
4109 }
4110 
4111 //===----------------------------------------------------------------------===//
4112 // MapBoundsOp
4113 //===----------------------------------------------------------------------===//
4114 
4115 LogicalResult MapBoundsOp::verify() {
4116  auto extent = getExtent();
4117  auto upperbound = getUpperBound();
4118  if (!extent && !upperbound)
4119  return emitError("expected extent or upperbound.");
4120  return success();
4121 }
4122 
4123 void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4124  TypeRange /*result_types*/, StringAttr symName,
4125  TypeAttr type) {
4126  PrivateClauseOp::build(
4127  odsBuilder, odsState, symName, type,
4129  DataSharingClauseType::Private));
4130 }
4131 
4132 LogicalResult PrivateClauseOp::verifyRegions() {
4133  Type argType = getArgType();
4134  auto verifyTerminator = [&](Operation *terminator,
4135  bool yieldsValue) -> LogicalResult {
4136  if (!terminator->getBlock()->getSuccessors().empty())
4137  return success();
4138 
4139  if (!llvm::isa<YieldOp>(terminator))
4140  return mlir::emitError(terminator->getLoc())
4141  << "expected exit block terminator to be an `omp.yield` op.";
4142 
4143  YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4144  TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4145 
4146  if (!yieldsValue) {
4147  if (yieldedTypes.empty())
4148  return success();
4149 
4150  return mlir::emitError(terminator->getLoc())
4151  << "Did not expect any values to be yielded.";
4152  }
4153 
4154  if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4155  return success();
4156 
4157  auto error = mlir::emitError(yieldOp.getLoc())
4158  << "Invalid yielded value. Expected type: " << argType
4159  << ", got: ";
4160 
4161  if (yieldedTypes.empty())
4162  error << "None";
4163  else
4164  error << yieldedTypes;
4165 
4166  return error;
4167  };
4168 
4169  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
4170  StringRef regionName,
4171  bool yieldsValue) -> LogicalResult {
4172  assert(!region.empty());
4173 
4174  if (region.getNumArguments() != expectedNumArgs)
4175  return mlir::emitError(region.getLoc())
4176  << "`" << regionName << "`: "
4177  << "expected " << expectedNumArgs
4178  << " region arguments, got: " << region.getNumArguments();
4179 
4180  for (Block &block : region) {
4181  // MLIR will verify the absence of the terminator for us.
4182  if (!block.mightHaveTerminator())
4183  continue;
4184 
4185  if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
4186  return failure();
4187  }
4188 
4189  return success();
4190  };
4191 
4192  // Ensure all of the region arguments have the same type
4193  for (Region *region : getRegions())
4194  for (Type ty : region->getArgumentTypes())
4195  if (ty != argType)
4196  return emitError() << "Region argument type mismatch: got " << ty
4197  << " expected " << argType << ".";
4198 
4199  mlir::Region &initRegion = getInitRegion();
4200  if (!initRegion.empty() &&
4201  failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
4202  /*yieldsValue=*/true)))
4203  return failure();
4204 
4205  DataSharingClauseType dsType = getDataSharingType();
4206 
4207  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4208  return emitError("`private` clauses do not require a `copy` region.");
4209 
4210  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4211  return emitError(
4212  "`firstprivate` clauses require at least a `copy` region.");
4213 
4214  if (dsType == DataSharingClauseType::FirstPrivate &&
4215  failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
4216  /*yieldsValue=*/true)))
4217  return failure();
4218 
4219  if (!getDeallocRegion().empty() &&
4220  failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
4221  /*yieldsValue=*/false)))
4222  return failure();
4223 
4224  return success();
4225 }
4226 
4227 //===----------------------------------------------------------------------===//
4228 // Spec 5.2: Masked construct (10.5)
4229 //===----------------------------------------------------------------------===//
4230 
4231 void MaskedOp::build(OpBuilder &builder, OperationState &state,
4232  const MaskedOperands &clauses) {
4233  MaskedOp::build(builder, state, clauses.filteredThreadId);
4234 }
4235 
4236 //===----------------------------------------------------------------------===//
4237 // Spec 5.2: Scan construct (5.6)
4238 //===----------------------------------------------------------------------===//
4239 
4240 void ScanOp::build(OpBuilder &builder, OperationState &state,
4241  const ScanOperands &clauses) {
4242  ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4243 }
4244 
4245 LogicalResult ScanOp::verify() {
4246  if (hasExclusiveVars() == hasInclusiveVars())
4247  return emitError(
4248  "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4249  if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4250  if (parentWsLoopOp.getReductionModAttr() &&
4251  parentWsLoopOp.getReductionModAttr().getValue() ==
4252  ReductionModifier::inscan)
4253  return success();
4254  }
4255  if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4256  if (parentSimdOp.getReductionModAttr() &&
4257  parentSimdOp.getReductionModAttr().getValue() ==
4258  ReductionModifier::inscan)
4259  return success();
4260  }
4261  return emitError("SCAN directive needs to be enclosed within a parent "
4262  "worksharing loop construct or SIMD construct with INSCAN "
4263  "reduction modifier");
4264 }
4265 
4266 /// Verifies align clause in allocate directive
4267 
4268 LogicalResult AllocateDirOp::verify() {
4269  std::optional<uint64_t> align = this->getAlign();
4270 
4271  if (align.has_value()) {
4272  if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
4273  return emitError() << "ALIGN value : " << align.value()
4274  << " must be power of 2";
4275  }
4276 
4277  return success();
4278 }
4279 
4280 //===----------------------------------------------------------------------===//
4281 // TargetAllocMemOp
4282 //===----------------------------------------------------------------------===//
4283 
4284 mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
4285  return getInTypeAttr().getValue();
4286 }
4287 
4288 /// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype,
4289 /// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
4290 /// attr-dict-without-keyword
4291 static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
4292  mlir::OperationState &result) {
4293  auto &builder = parser.getBuilder();
4294  bool hasOperands = false;
4295  std::int32_t typeparamsSize = 0;
4296 
4297  // Parse device number as a new operand
4299  mlir::Type deviceType;
4300  if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
4301  return mlir::failure();
4302  if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
4303  return mlir::failure();
4304  if (parser.parseComma())
4305  return mlir::failure();
4306 
4307  mlir::Type intype;
4308  if (parser.parseType(intype))
4309  return mlir::failure();
4310  result.addAttribute("in_type", mlir::TypeAttr::get(intype));
4313  if (!parser.parseOptionalLParen()) {
4314  // parse the LEN params of the derived type. (<params> : <types>)
4315  if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
4316  parser.parseColonTypeList(typeVec) || parser.parseRParen())
4317  return mlir::failure();
4318  typeparamsSize = operands.size();
4319  hasOperands = true;
4320  }
4321  std::int32_t shapeSize = 0;
4322  if (!parser.parseOptionalComma()) {
4323  // parse size to scale by, vector of n dimensions of type index
4325  return mlir::failure();
4326  shapeSize = operands.size() - typeparamsSize;
4327  auto idxTy = builder.getIndexType();
4328  for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
4329  typeVec.push_back(idxTy);
4330  hasOperands = true;
4331  }
4332  if (hasOperands &&
4333  parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
4334  result.operands))
4335  return mlir::failure();
4336 
4337  mlir::Type restype = builder.getIntegerType(64);
4338  if (!restype) {
4339  parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
4340  return mlir::failure();
4341  }
4342  llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
4343  result.addAttribute("operandSegmentSizes",
4344  builder.getDenseI32ArrayAttr(segmentSizes));
4345  if (parser.parseOptionalAttrDict(result.attributes) ||
4346  parser.addTypeToList(restype, result.types))
4347  return mlir::failure();
4348  return mlir::success();
4349 }
4350 
4351 mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
4352  mlir::OperationState &result) {
4353  return parseTargetAllocMemOp(parser, result);
4354 }
4355 
4357  p << " ";
4358  p.printOperand(getDevice());
4359  p << " : ";
4360  p << getDevice().getType();
4361  p << ", ";
4362  p << getInType();
4363  if (!getTypeparams().empty()) {
4364  p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
4365  }
4366  for (auto sh : getShape()) {
4367  p << ", ";
4368  p.printOperand(sh);
4369  }
4370  p.printOptionalAttrDict((*this)->getAttrs(),
4371  {"in_type", "operandSegmentSizes"});
4372 }
4373 
4374 llvm::LogicalResult omp::TargetAllocMemOp::verify() {
4375  mlir::Type outType = getType();
4376  if (!mlir::dyn_cast<IntegerType>(outType))
4377  return emitOpError("must be a integer type");
4378  return mlir::success();
4379 }
4380 
4381 //===----------------------------------------------------------------------===//
4382 // WorkdistributeOp
4383 //===----------------------------------------------------------------------===//
4384 
4385 LogicalResult WorkdistributeOp::verify() {
4386  // Check that region exists and is not empty
4387  Region &region = getRegion();
4388  if (region.empty())
4389  return emitOpError("region cannot be empty");
4390  // Verify single entry point.
4391  Block &entryBlock = region.front();
4392  if (entryBlock.empty())
4393  return emitOpError("region must contain a structured block");
4394  // Verify single exit point.
4395  bool hasTerminator = false;
4396  for (Block &block : region) {
4397  if (isa<TerminatorOp>(block.back())) {
4398  if (hasTerminator) {
4399  return emitOpError("region must have exactly one terminator");
4400  }
4401  hasTerminator = true;
4402  }
4403  }
4404  if (!hasTerminator) {
4405  return emitOpError("region must be terminated with omp.terminator");
4406  }
4407  auto walkResult = region.walk([&](Operation *op) -> WalkResult {
4408  // No implicit barrier at end
4409  if (isa<BarrierOp>(op)) {
4410  return emitOpError(
4411  "explicit barriers are not allowed in workdistribute region");
4412  }
4413  // Check for invalid nested constructs
4414  if (isa<ParallelOp>(op)) {
4415  return emitOpError(
4416  "nested parallel constructs not allowed in workdistribute");
4417  }
4418  if (isa<TeamsOp>(op)) {
4419  return emitOpError(
4420  "nested teams constructs not allowed in workdistribute");
4421  }
4422  return WalkResult::advance();
4423  });
4424  if (walkResult.wasInterrupted())
4425  return failure();
4426 
4427  Operation *parentOp = (*this)->getParentOp();
4428  if (!llvm::dyn_cast<TeamsOp>(parentOp))
4429  return emitOpError("workdistribute must be nested under teams");
4430  return success();
4431 }
4432 
4433 #define GET_ATTRDEF_CLASSES
4434 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
4435 
4436 #define GET_OP_CLASSES
4437 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
4438 
4439 #define GET_TYPEDEF_CLASSES
4440 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
Definition: AMXDialect.cpp:70
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:757
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1365
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:62
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
operation ::= res = (omp.target_alloc_mem) $device : devicetype, $in_type ( ( $typeparams ) )?...
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static Operation * getParentInSameDialect(Operation *thisOp)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseMembersIndex(OpAsmParser &parser, ArrayAttr &membersIdx)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
bool isUnique(It begin, It end)
Definition: ShardOps.cpp:161
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
SuccessorRange getSuccessors()
Definition: Block.h:270
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:163
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:228
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
MLIRContext * getContext() const
Definition: Builders.h:56
IndexType getIndexType()
Definition: Builders.cpp:51
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Definition: Diagnostics.h:228
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
Definition: Diagnostics.h:352
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
Definition: Builders.h:207
This class represents an operand of an operation.
Definition: Value.h:257
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:773
This class indicates that the regions associated with this op don't have terminators.
Definition: OpDefinition.h:769
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
BlockArgListType getArguments()
Definition: Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
BlockListType & getBlocks()
Definition: Region.h:45
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:108
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Definition: Dominance.cpp:306
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool LLVM_ATTRIBUTE_UNUSED isPerfectlyNested(ArrayRef< AffineForOp > loops)
Returns true if loops is a perfectly nested loop nest, where loops appear in it from outermost to innermost.
Definition: LoopUtils.cpp:1361
SmallVector< SmallVector< AffineForOp, 8 >, 8 > tile(ArrayRef< AffineForOp > forOps, ArrayRef< uint64_t > sizes, ArrayRef< AffineForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: LoopUtils.cpp:1584
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Runtime
Potential runtimes for AMD GPU kernels.
Definition: Runtimes.h:15
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
Definition: OpDefinition.h:881
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.