MLIR  22.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Attributes.h"
23 #include "mlir/IR/SymbolTable.h"
25 
26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/PostOrderIterator.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/STLForwardCompat.h"
30 #include "llvm/ADT/SmallString.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/StringRef.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/ADT/bit.h"
35 #include "llvm/Support/InterleavedRange.h"
36 #include <cstddef>
37 #include <iterator>
38 #include <optional>
39 #include <variant>
40 
41 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
42 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
43 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
44 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
45 
46 using namespace mlir;
47 using namespace mlir::omp;
48 
49 static ArrayAttr makeArrayAttr(MLIRContext *context,
51  return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
52 }
53 
54 static DenseBoolArrayAttr
56  return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
57 }
58 
59 static DenseI64ArrayAttr
61  return intArray.empty() ? nullptr : DenseI64ArrayAttr::get(ctx, intArray);
62 }
63 
64 namespace {
65 struct MemRefPointerLikeModel
66  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
67  MemRefType> {
68  Type getElementType(Type pointer) const {
69  return llvm::cast<MemRefType>(pointer).getElementType();
70  }
71 };
72 
73 struct LLVMPointerPointerLikeModel
74  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
75  LLVM::LLVMPointerType> {
76  Type getElementType(Type pointer) const { return Type(); }
77 };
78 } // namespace
79 
80 /// Generate a name of a canonical loop nest of the format
81 /// `<prefix>(_r<idx>_s<idx>)*`. Hereby, `_r<idx>` identifies the region
82 /// argument index of an operation that has multiple regions, if the operation
83 /// has multiple regions.
84 /// `_s<idx>` identifies the position of an operation within a region, where
85 /// only operations that may potentially contain loops ("container operations"
86 /// i.e. have region arguments) are counted. Again, it is omitted if there is
87 /// only one such operation in a region. If there are canonical loops nested
88 /// inside each other, the name may also use the format `_d<num>` where <num>
89 /// is the nesting depth of the loop.
90 ///
91 /// The generated name is a best-effort to make canonical loop unique within an
92 /// SSA namespace. This also means that regions with IsolatedFromAbove property
93 /// do not consider any parents or siblings.
///
/// \param prefix Text the generated name starts with.
/// \param op The canonical loop whose chain of parent regions and operations
/// determines the suffixes.
/// \returns `prefix` followed by one `_r<idx>`, `_d<num>` or `_s<idx>` suffix
/// for every component that was not marked as skipped, outermost first.
94 static std::string generateLoopNestingName(StringRef prefix,
95  CanonicalLoopOp op) {
96  struct Component {
97  /// If true, this component describes a region operand of an operation (the
98  /// operand's owner) If false, this component describes an operation located
99  /// in a parent region
100  bool isRegionArgOfOp;
  /// If true, the component contributes nothing to the generated name.
101  bool skip = false;
  /// If true, the component has no sibling it would need to be told apart
  /// from (only container op in its region / only region of its owner).
102  bool isUnique = false;
103 
  // `idx` and `op` are shared storage; the accessors below assert which of
  // the two interpretations is valid for this component.
104  size_t idx;
105  Operation *op;
106  Region *parentRegion;
107  size_t loopDepth;
108 
109  Operation *&getOwnerOp() {
110  assert(isRegionArgOfOp && "Must describe a region operand");
111  return op;
112  }
113  size_t &getArgIdx() {
114  assert(isRegionArgOfOp && "Must describe a region operand");
115  return idx;
116  }
117 
118  Operation *&getContainerOp() {
119  assert(!isRegionArgOfOp && "Must describe a operation of a region");
120  return op;
121  }
122  size_t &getOpPos() {
123  assert(!isRegionArgOfOp && "Must describe a operation of a region");
124  return idx;
125  }
126  bool isLoopOp() const {
127  assert(!isRegionArgOfOp && "Must describe a operation of a region");
128  return isa<CanonicalLoopOp>(op);
129  }
130  Region *&getParentRegion() {
131  assert(!isRegionArgOfOp && "Must describe a operation of a region");
132  return parentRegion;
133  }
134  size_t &getLoopDepth() {
135  assert(!isRegionArgOfOp && "Must describe a operation of a region");
136  return loopDepth;
137  }
138 
  /// Marks the component as skipped if `v` is true; never un-skips it.
139  void skipIf(bool v = true) { skip = skip || v; }
140  };
141 
142  // List of ancestors, from inner to outer.
143  // Alternates between
144  // * region argument of an operation
145  // * operation within a region
146  SmallVector<Component> components;
147 
148  // Gather a list of parent regions and operations, and the position within
149  // their parent
150  Operation *o = op.getOperation();
151  while (o) {
152  // Operation within a region
153  Region *r = o->getParentRegion();
154  if (!r)
155  break;
156 
  // Walk the blocks of the region in reverse post-order starting at its first
  // block, counting operations that have regions until `o` is reached.
157  llvm::ReversePostOrderTraversal<Block *> traversal(&r->getBlocks().front());
158  size_t idx = 0;
159  bool found = false;
  // -1 converts to the largest size_t; it is a placeholder that is overwritten
  // as soon as `o` is encountered in the traversal.
160  size_t sequentialIdx = -1;
161  bool isOnlyContainerOp = true;
162  for (Block *b : traversal) {
  // Note: this `op` shadows the function parameter of the same name.
163  for (Operation &op : *b) {
164  if (&op == o && !found) {
165  sequentialIdx = idx;
166  found = true;
167  }
168  if (op.getNumRegions()) {
169  idx += 1;
170  if (idx > 1)
171  isOnlyContainerOp = false;
172  }
  // Both answers are known; this only leaves the loop over the operations of
  // the current block, the remaining blocks are still visited.
173  if (found && !isOnlyContainerOp)
174  break;
175  }
176  }
177 
178  Component &containerOpInRegion = components.emplace_back();
179  containerOpInRegion.isRegionArgOfOp = false;
180  containerOpInRegion.isUnique = isOnlyContainerOp;
181  containerOpInRegion.getContainerOp() = o;
182  containerOpInRegion.getOpPos() = sequentialIdx;
183  containerOpInRegion.getParentRegion() = r;
184 
185  Operation *parent = r->getParentOp();
186 
187  // Region argument of an operation
188  Component &regionArgOfOperation = components.emplace_back();
189  regionArgOfOperation.isRegionArgOfOp = true;
190  regionArgOfOperation.isUnique = true;
191  regionArgOfOperation.getArgIdx() = 0;
192  regionArgOfOperation.getOwnerOp() = parent;
193 
194  // The IsolatedFromAbove trait of the parent operation implies that each
195  // individual region argument has its own separate namespace, so no
196  // ambiguity.
197  if (!parent || parent->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
198  break;
199 
200  // Component only needed if operation has multiple region operands. Region
201  // arguments may be optional, but we currently do not consider this.
202  if (parent->getRegions().size() > 1) {
203  auto getRegionIndex = [](Operation *o, Region *r) {
204  for (auto [idx, region] : llvm::enumerate(o->getRegions())) {
205  if (&region == r)
206  return idx;
207  }
208  llvm_unreachable("Region not child of its parent operation");
209  };
210  regionArgOfOperation.isUnique = false;
211  regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
212  }
213 
214  // next parent
215  o = parent;
216  }
217 
218  // Determine whether a region-argument component is not needed
219  for (Component &c : components)
220  c.skipIf(c.isRegionArgOfOp && c.isUnique);
221 
222  // Find runs of nested loops and determine each loop's depth in the loop nest
223  size_t numSurroundingLoops = 0;
224  for (Component &c : llvm::reverse(components)) {
225  if (c.skip)
226  continue;
227 
228  // non-skipped multi-argument operands interrupt the loop nest
229  if (c.isRegionArgOfOp) {
230  numSurroundingLoops = 0;
231  continue;
232  }
233 
234  // Multiple loops in a region means each of them is the outermost loop of a
235  // new loop nest
236  if (!c.isUnique)
237  numSurroundingLoops = 0;
238 
239  c.getLoopDepth() = numSurroundingLoops;
240 
241  // Next loop is surrounded by one more loop
242  if (isa<CanonicalLoopOp>(c.getContainerOp()))
243  numSurroundingLoops += 1;
244  }
245 
246  // In loop nests, skip all but the innermost loop that contains the depth
247  // number
248  bool isLoopNest = false;
249  for (Component &c : components) {
250  if (c.skip || c.isRegionArgOfOp)
251  continue;
252 
253  if (!isLoopNest && c.getLoopDepth() >= 1) {
254  // Innermost loop of a loop nest of at least two loops
255  isLoopNest = true;
256  } else if (isLoopNest) {
257  // Non-innermost loop of a loop nest
258  c.skipIf(c.isUnique);
259 
260  // If there is no surrounding loop left, this must have been the outermost
261  // loop; leave loop-nest mode for the next iteration
262  if (c.getLoopDepth() == 0)
263  isLoopNest = false;
264  }
265  }
266 
267  // Skip non-loop unambiguous regions (but they should interrupt loop nests, so
268  // we mark them as skipped only after computing loop nests)
269  for (Component &c : components)
270  c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
271  !isa<CanonicalLoopOp>(c.getContainerOp()));
272 
273  // Components can be skipped if they are already disambiguated by their parent
274  // (or does not have a parent)
275  bool newRegion = true;
276  for (Component &c : llvm::reverse(components)) {
277  c.skipIf(newRegion && c.isUnique);
278 
279  // non-skipped components disambiguate unique children
280  if (!c.skip)
281  newRegion = true;
282 
283  // ...except canonical loops that need a suffix for each nest
  // NOTE(review): getContainerOp() is the operation pointer stored when the
  // component was created, which is non-null there, so this test holds for
  // every operation component and not only for canonical loops -- confirm
  // whether a check for CanonicalLoopOp was intended.
284  if (!c.isRegionArgOfOp && c.getContainerOp())
285  newRegion = false;
286  }
287 
288  // Compile the nesting name string
289  SmallString<64> Name{prefix};
290  llvm::raw_svector_ostream NameOS(Name);
  // Components were collected inner-to-outer, so emit them in reverse to get
  // the outermost suffix first.
291  for (auto &c : llvm::reverse(components)) {
292  if (c.skip)
293  continue;
294 
295  if (c.isRegionArgOfOp)
296  NameOS << "_r" << c.getArgIdx();
297  else if (c.getLoopDepth() >= 1)
298  NameOS << "_d" << c.getLoopDepth();
299  else
300  NameOS << "_s" << c.getOpPos();
301  }
302 
303  return NameOS.str().str();
304 }
305 
306 void OpenMPDialect::initialize() {
307  addOperations<
308 #define GET_OP_LIST
309 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
310  >();
311  addAttributes<
312 #define GET_ATTRDEF_LIST
313 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
314  >();
315  addTypes<
316 #define GET_TYPEDEF_LIST
317 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
318  >();
319 
320  declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
321 
322  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
323  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
324  *getContext());
325 
326  // Attach default offload module interface to module op to access
327  // offload functionality through it.
328  mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
329  *getContext());
330 
331  // Attach default declare target interfaces to operations which can be marked
332  // as declare target (Global Operations and Functions/Subroutines in dialects
333  // that Fortran (or other languages that lower to MLIR) translates to).
334  mlir::LLVM::GlobalOp::attachInterface<
336  *getContext());
337  mlir::LLVM::LLVMFuncOp::attachInterface<
339  *getContext());
340  mlir::func::FuncOp::attachInterface<
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // Parser and printer for Allocate Clause
346 //===----------------------------------------------------------------------===//
347 
348 /// Parse an allocate clause with allocators and a list of operands with types.
349 ///
350 /// allocate-operand-list :: = allocate-operand |
351 /// allocator-operand `,` allocate-operand-list
352 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
353 /// ssa-id-and-type ::= ssa-id `:` type
354 static ParseResult parseAllocateAndAllocator(
355  OpAsmParser &parser,
357  SmallVectorImpl<Type> &allocateTypes,
359  SmallVectorImpl<Type> &allocatorTypes) {
360 
361  return parser.parseCommaSeparatedList([&]() {
363  Type type;
364  if (parser.parseOperand(operand) || parser.parseColonType(type))
365  return failure();
366  allocatorVars.push_back(operand);
367  allocatorTypes.push_back(type);
368  if (parser.parseArrow())
369  return failure();
370  if (parser.parseOperand(operand) || parser.parseColonType(type))
371  return failure();
372 
373  allocateVars.push_back(operand);
374  allocateTypes.push_back(type);
375  return success();
376  });
377 }
378 
379 /// Print allocate clause
381  OperandRange allocateVars,
382  TypeRange allocateTypes,
383  OperandRange allocatorVars,
384  TypeRange allocatorTypes) {
385  for (unsigned i = 0; i < allocateVars.size(); ++i) {
386  std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
387  p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
388  p << allocateVars[i] << " : " << allocateTypes[i] << separator;
389  }
390 }
391 
392 //===----------------------------------------------------------------------===//
393 // Parser and printer for a clause attribute (StringEnumAttr)
394 //===----------------------------------------------------------------------===//
395 
396 template <typename ClauseAttr>
397 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
398  using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
399  StringRef enumStr;
400  SMLoc loc = parser.getCurrentLocation();
401  if (parser.parseKeyword(&enumStr))
402  return failure();
403  if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
404  attr = ClauseAttr::get(parser.getContext(), *enumValue);
405  return success();
406  }
407  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
408 }
409 
410 template <typename ClauseAttr>
411 static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
412  p << stringifyEnum(attr.getValue());
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // Parser and printer for Linear Clause
417 //===----------------------------------------------------------------------===//
418 
419 /// linear ::= `linear` `(` linear-list `)`
420 /// linear-list := linear-val | linear-val linear-list
421 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
422 static ParseResult parseLinearClause(
423  OpAsmParser &parser,
425  SmallVectorImpl<Type> &linearTypes,
427  return parser.parseCommaSeparatedList([&]() {
429  Type type;
431  if (parser.parseOperand(var) || parser.parseEqual() ||
432  parser.parseOperand(stepVar) || parser.parseColonType(type))
433  return failure();
434 
435  linearVars.push_back(var);
436  linearTypes.push_back(type);
437  linearStepVars.push_back(stepVar);
438  return success();
439  });
440 }
441 
442 /// Print Linear Clause
444  ValueRange linearVars, TypeRange linearTypes,
445  ValueRange linearStepVars) {
446  size_t linearVarsSize = linearVars.size();
447  for (unsigned i = 0; i < linearVarsSize; ++i) {
448  std::string separator = i == linearVarsSize - 1 ? "" : ", ";
449  p << linearVars[i];
450  if (linearStepVars.size() > i)
451  p << " = " << linearStepVars[i];
452  p << " : " << linearVars[i].getType() << separator;
453  }
454 }
455 
456 //===----------------------------------------------------------------------===//
457 // Verifier for Nontemporal Clause
458 //===----------------------------------------------------------------------===//
459 
460 static LogicalResult verifyNontemporalClause(Operation *op,
461  OperandRange nontemporalVars) {
462 
463  // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
464  DenseSet<Value> nontemporalItems;
465  for (const auto &it : nontemporalVars)
466  if (!nontemporalItems.insert(it).second)
467  return op->emitOpError() << "nontemporal variable used more than once";
468 
469  return success();
470 }
471 
472 //===----------------------------------------------------------------------===//
473 // Parser, verifier and printer for Aligned Clause
474 //===----------------------------------------------------------------------===//
475 static LogicalResult verifyAlignedClause(Operation *op,
476  std::optional<ArrayAttr> alignments,
477  OperandRange alignedVars) {
478  // Check if number of alignment values equals to number of aligned variables
479  if (!alignedVars.empty()) {
480  if (!alignments || alignments->size() != alignedVars.size())
481  return op->emitOpError()
482  << "expected as many alignment values as aligned variables";
483  } else {
484  if (alignments)
485  return op->emitOpError() << "unexpected alignment values attribute";
486  return success();
487  }
488 
489  // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
490  DenseSet<Value> alignedItems;
491  for (auto it : alignedVars)
492  if (!alignedItems.insert(it).second)
493  return op->emitOpError() << "aligned variable used more than once";
494 
495  if (!alignments)
496  return success();
497 
498  // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
499  for (unsigned i = 0; i < (*alignments).size(); ++i) {
500  if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
501  if (intAttr.getValue().sle(0))
502  return op->emitOpError() << "alignment should be greater than 0";
503  } else {
504  return op->emitOpError() << "expected integer alignment";
505  }
506  }
507 
508  return success();
509 }
510 
511 /// aligned ::= `aligned` `(` aligned-list `)`
512 /// aligned-list := aligned-val | aligned-val aligned-list
513 /// aligned-val := ssa-id-and-type `->` alignment
514 static ParseResult
517  SmallVectorImpl<Type> &alignedTypes,
518  ArrayAttr &alignmentsAttr) {
519  SmallVector<Attribute> alignmentVec;
520  if (failed(parser.parseCommaSeparatedList([&]() {
521  if (parser.parseOperand(alignedVars.emplace_back()) ||
522  parser.parseColonType(alignedTypes.emplace_back()) ||
523  parser.parseArrow() ||
524  parser.parseAttribute(alignmentVec.emplace_back())) {
525  return failure();
526  }
527  return success();
528  })))
529  return failure();
530  SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
531  alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
532  return success();
533 }
534 
535 /// Print Aligned Clause
537  ValueRange alignedVars, TypeRange alignedTypes,
538  std::optional<ArrayAttr> alignments) {
539  for (unsigned i = 0; i < alignedVars.size(); ++i) {
540  if (i != 0)
541  p << ", ";
542  p << alignedVars[i] << " : " << alignedVars[i].getType();
543  p << " -> " << (*alignments)[i];
544  }
545 }
546 
547 //===----------------------------------------------------------------------===//
548 // Parser, printer and verifier for Schedule Clause
549 //===----------------------------------------------------------------------===//
550 
551 static ParseResult
553  SmallVectorImpl<SmallString<12>> &modifiers) {
554  if (modifiers.size() > 2)
555  return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
556  for (const auto &mod : modifiers) {
557  // Translate the string. If it has no value, then it was not a valid
558  // modifier!
559  auto symbol = symbolizeScheduleModifier(mod);
560  if (!symbol)
561  return parser.emitError(parser.getNameLoc())
562  << " unknown modifier type: " << mod;
563  }
564 
565  // If we have one modifier that is "simd", then stick a "none" modifier in
566  // index 0.
567  if (modifiers.size() == 1) {
568  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
569  modifiers.push_back(modifiers[0]);
570  modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
571  }
572  } else if (modifiers.size() == 2) {
573  // If there are two modifiers:
574  // First modifier should not be simd, second one should be simd
575  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
576  symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
577  return parser.emitError(parser.getNameLoc())
578  << " incorrect modifier order";
579  }
580  return success();
581 }
582 
583 /// schedule ::= `schedule` `(` sched-list `)`
584 /// sched-list ::= sched-val | sched-val sched-list |
585 /// sched-val `,` sched-modifier
586 /// sched-val ::= sched-with-chunk | sched-wo-chunk
587 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
588 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
589 /// sched-wo-chunk ::= `auto` | `runtime`
590 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
591 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
592 static ParseResult
593 parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
594  ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
595  std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
596  Type &chunkType) {
597  StringRef keyword;
598  if (parser.parseKeyword(&keyword))
599  return failure();
600  std::optional<mlir::omp::ClauseScheduleKind> schedule =
601  symbolizeClauseScheduleKind(keyword);
602  if (!schedule)
603  return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
604 
605  scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
606  switch (*schedule) {
607  case ClauseScheduleKind::Static:
608  case ClauseScheduleKind::Dynamic:
609  case ClauseScheduleKind::Guided:
610  if (succeeded(parser.parseOptionalEqual())) {
611  chunkSize = OpAsmParser::UnresolvedOperand{};
612  if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
613  return failure();
614  } else {
615  chunkSize = std::nullopt;
616  }
617  break;
618  case ClauseScheduleKind::Auto:
620  chunkSize = std::nullopt;
621  }
622 
623  // If there is a comma, we have one or more modifiers..
624  SmallVector<SmallString<12>> modifiers;
625  while (succeeded(parser.parseOptionalComma())) {
626  StringRef mod;
627  if (parser.parseKeyword(&mod))
628  return failure();
629  modifiers.push_back(mod);
630  }
631 
632  if (verifyScheduleModifiers(parser, modifiers))
633  return failure();
634 
635  if (!modifiers.empty()) {
636  SMLoc loc = parser.getCurrentLocation();
637  if (std::optional<ScheduleModifier> mod =
638  symbolizeScheduleModifier(modifiers[0])) {
639  scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
640  } else {
641  return parser.emitError(loc, "invalid schedule modifier");
642  }
643  // Only SIMD attribute is allowed here!
644  if (modifiers.size() > 1) {
645  assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
646  scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
647  }
648  }
649 
650  return success();
651 }
652 
653 /// Print schedule clause
655  ClauseScheduleKindAttr scheduleKind,
656  ScheduleModifierAttr scheduleMod,
657  UnitAttr scheduleSimd, Value scheduleChunk,
658  Type scheduleChunkType) {
659  p << stringifyClauseScheduleKind(scheduleKind.getValue());
660  if (scheduleChunk)
661  p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
662  if (scheduleMod)
663  p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
664  if (scheduleSimd)
665  p << ", simd";
666 }
667 
668 //===----------------------------------------------------------------------===//
669 // Parser and printer for Order Clause
670 //===----------------------------------------------------------------------===//
671 
672 // order ::= `order` `(` [order-modifier ':'] concurrent `)`
673 // order-modifier ::= reproducible | unconstrained
674 static ParseResult parseOrderClause(OpAsmParser &parser,
675  ClauseOrderKindAttr &order,
676  OrderModifierAttr &orderMod) {
677  StringRef enumStr;
678  SMLoc loc = parser.getCurrentLocation();
679  if (parser.parseKeyword(&enumStr))
680  return failure();
681  if (std::optional<OrderModifier> enumValue =
682  symbolizeOrderModifier(enumStr)) {
683  orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
684  if (parser.parseOptionalColon())
685  return failure();
686  loc = parser.getCurrentLocation();
687  if (parser.parseKeyword(&enumStr))
688  return failure();
689  }
690  if (std::optional<ClauseOrderKind> enumValue =
691  symbolizeClauseOrderKind(enumStr)) {
692  order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
693  return success();
694  }
695  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
696 }
697 
699  ClauseOrderKindAttr order,
700  OrderModifierAttr orderMod) {
701  if (orderMod)
702  p << stringifyOrderModifier(orderMod.getValue()) << ":";
703  if (order)
704  p << stringifyClauseOrderKind(order.getValue());
705 }
706 
707 template <typename ClauseTypeAttr, typename ClauseType>
708 static ParseResult
709 parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
710  std::optional<OpAsmParser::UnresolvedOperand> &operand,
711  Type &operandType,
712  std::optional<ClauseType> (*symbolizeClause)(StringRef),
713  StringRef clauseName) {
714  StringRef enumStr;
715  if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
716  if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
717  prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
718  if (parser.parseComma())
719  return failure();
720  } else {
721  return parser.emitError(parser.getCurrentLocation())
722  << "invalid " << clauseName << " modifier : '" << enumStr << "'";
723  ;
724  }
725  }
726 
728  if (succeeded(parser.parseOperand(var))) {
729  operand = var;
730  } else {
731  return parser.emitError(parser.getCurrentLocation())
732  << "expected " << clauseName << " operand";
733  }
734 
735  if (operand.has_value()) {
736  if (parser.parseColonType(operandType))
737  return failure();
738  }
739 
740  return success();
741 }
742 
743 template <typename ClauseTypeAttr, typename ClauseType>
744 static void
746  ClauseTypeAttr prescriptiveness, Value operand,
747  mlir::Type operandType,
748  StringRef (*stringifyClauseType)(ClauseType)) {
749 
750  if (prescriptiveness)
751  p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
752 
753  if (operand)
754  p << operand << ": " << operandType;
755 }
756 
757 //===----------------------------------------------------------------------===//
758 // Parser and printer for grainsize Clause
759 //===----------------------------------------------------------------------===//
760 
761 // grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
762 static ParseResult
763 parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
764  std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
765  Type &grainsizeType) {
766  return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
767  parser, grainsizeMod, grainsize, grainsizeType,
768  &symbolizeClauseGrainsizeType, "grainsize");
769 }
770 
772  ClauseGrainsizeTypeAttr grainsizeMod,
773  Value grainsize, mlir::Type grainsizeType) {
774  printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
775  p, op, grainsizeMod, grainsize, grainsizeType,
776  &stringifyClauseGrainsizeType);
777 }
778 
779 //===----------------------------------------------------------------------===//
780 // Parser and printer for num_tasks Clause
781 //===----------------------------------------------------------------------===//
782 
783 // numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
784 static ParseResult
785 parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
786  std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
787  Type &numTasksType) {
788  return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
789  parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
790  "num_tasks");
791 }
792 
794  ClauseNumTasksTypeAttr numTasksMod,
795  Value numTasks, mlir::Type numTasksType) {
796  printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
797  p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
798 }
799 
800 //===----------------------------------------------------------------------===//
801 // Parsers for operations including clauses that define entry block arguments.
802 //===----------------------------------------------------------------------===//
803 
804 namespace {
805 struct MapParseArgs {
807  SmallVectorImpl<Type> &types;
809  SmallVectorImpl<Type> &types)
810  : vars(vars), types(types) {}
811 };
812 struct PrivateParseArgs {
815  ArrayAttr &syms;
816  UnitAttr &needsBarrier;
817  DenseI64ArrayAttr *mapIndices;
819  SmallVectorImpl<Type> &types, ArrayAttr &syms,
820  UnitAttr &needsBarrier,
821  DenseI64ArrayAttr *mapIndices = nullptr)
822  : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
823  mapIndices(mapIndices) {}
824 };
825 
826 struct ReductionParseArgs {
828  SmallVectorImpl<Type> &types;
829  DenseBoolArrayAttr &byref;
830  ArrayAttr &syms;
831  ReductionModifierAttr *modifier;
832  ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
834  ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
835  : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
836 };
837 
/// Aggregate of optional parse-argument bundles, one field per clause kind.
/// Each field is empty (`std::nullopt`) by default.
/// NOTE(review): presumably a caller fills in only the clauses its operation
/// accepts and the entry-block-argument parsers consume the rest -- confirm
/// against the users of this struct, which are not visible here.
838 struct AllRegionParseArgs {
  // Map-style bundles (operands + types).
839  std::optional<MapParseArgs> hasDeviceAddrArgs;
840  std::optional<MapParseArgs> hostEvalArgs;
  // Reduction-style bundle for the in_reduction clause.
841  std::optional<ReductionParseArgs> inReductionArgs;
842  std::optional<MapParseArgs> mapArgs;
  // Bundle for the private clause.
843  std::optional<PrivateParseArgs> privateArgs;
  // Reduction-style bundles for the reduction and task_reduction clauses.
844  std::optional<ReductionParseArgs> reductionArgs;
845  std::optional<ReductionParseArgs> taskReductionArgs;
846  std::optional<MapParseArgs> useDeviceAddrArgs;
847  std::optional<MapParseArgs> useDevicePtrArgs;
848 };
849 } // namespace
850 
/// Returns the fixed keyword text "private_barrier".
/// NOTE(review): judging by the name, this is the textual spelling of the
/// private clause's needs-barrier flag -- confirm at the use sites.
851 static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
852  return "private_barrier";
853 }
854 
855 static ParseResult parseClauseWithRegionArgs(
856  OpAsmParser &parser,
858  SmallVectorImpl<Type> &types,
859  SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
860  ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
861  DenseBoolArrayAttr *byref = nullptr,
862  ReductionModifierAttr *modifier = nullptr,
863  UnitAttr *needsBarrier = nullptr) {
864  SmallVector<SymbolRefAttr> symbolVec;
865  SmallVector<int64_t> mapIndicesVec;
866  SmallVector<bool> isByRefVec;
867  unsigned regionArgOffset = regionPrivateArgs.size();
868 
869  if (parser.parseLParen())
870  return failure();
871 
872  if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
873  StringRef enumStr;
874  if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
875  parser.parseComma())
876  return failure();
877  std::optional<ReductionModifier> enumValue =
878  symbolizeReductionModifier(enumStr);
879  if (!enumValue.has_value())
880  return failure();
881  *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
882  if (!*modifier)
883  return failure();
884  }
885 
886  if (parser.parseCommaSeparatedList([&]() {
887  if (byref)
888  isByRefVec.push_back(
889  parser.parseOptionalKeyword("byref").succeeded());
890 
891  if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
892  return failure();
893 
894  if (parser.parseOperand(operands.emplace_back()) ||
895  parser.parseArrow() ||
896  parser.parseArgument(regionPrivateArgs.emplace_back()))
897  return failure();
898 
899  if (mapIndices) {
900  if (parser.parseOptionalLSquare().succeeded()) {
901  if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
902  parser.parseInteger(mapIndicesVec.emplace_back()) ||
903  parser.parseRSquare())
904  return failure();
905  } else {
906  mapIndicesVec.push_back(-1);
907  }
908  }
909 
910  return success();
911  }))
912  return failure();
913 
914  if (parser.parseColon())
915  return failure();
916 
917  if (parser.parseCommaSeparatedList([&]() {
918  if (parser.parseType(types.emplace_back()))
919  return failure();
920 
921  return success();
922  }))
923  return failure();
924 
925  if (operands.size() != types.size())
926  return failure();
927 
928  if (parser.parseRParen())
929  return failure();
930 
931  if (needsBarrier) {
933  .succeeded())
934  *needsBarrier = mlir::UnitAttr::get(parser.getContext());
935  }
936 
937  auto *argsBegin = regionPrivateArgs.begin();
938  MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
939  argsBegin + regionArgOffset + types.size());
940  for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
941  prv.type = type;
942  }
943 
944  if (symbols) {
945  SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
946  *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
947  }
948 
949  if (!mapIndicesVec.empty())
950  *mapIndices =
951  mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
952 
953  if (byref)
954  *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
955 
956  return success();
957 }
958 
959 static ParseResult parseBlockArgClause(
960  OpAsmParser &parser,
962  StringRef keyword, std::optional<MapParseArgs> mapArgs) {
963  if (succeeded(parser.parseOptionalKeyword(keyword))) {
964  if (!mapArgs)
965  return failure();
966 
967  if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
968  entryBlockArgs)))
969  return failure();
970  }
971  return success();
972 }
973 
974 static ParseResult parseBlockArgClause(
975  OpAsmParser &parser,
977  StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
978  if (succeeded(parser.parseOptionalKeyword(keyword))) {
979  if (!privateArgs)
980  return failure();
981 
983  parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
984  &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
985  /*modifier=*/nullptr, &privateArgs->needsBarrier)))
986  return failure();
987  }
988  return success();
989 }
990 
991 static ParseResult parseBlockArgClause(
992  OpAsmParser &parser,
994  StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
995  if (succeeded(parser.parseOptionalKeyword(keyword))) {
996  if (!reductionArgs)
997  return failure();
999  parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1000  &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
1001  reductionArgs->modifier)))
1002  return failure();
1003  }
1004  return success();
1005 }
1006 
1007 static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
1008  AllRegionParseArgs args) {
1010 
1011  if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",
1012  args.hasDeviceAddrArgs)))
1013  return parser.emitError(parser.getCurrentLocation())
1014  << "invalid `has_device_addr` format";
1015 
1016  if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
1017  args.hostEvalArgs)))
1018  return parser.emitError(parser.getCurrentLocation())
1019  << "invalid `host_eval` format";
1020 
1021  if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
1022  args.inReductionArgs)))
1023  return parser.emitError(parser.getCurrentLocation())
1024  << "invalid `in_reduction` format";
1025 
1026  if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
1027  args.mapArgs)))
1028  return parser.emitError(parser.getCurrentLocation())
1029  << "invalid `map_entries` format";
1030 
1031  if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
1032  args.privateArgs)))
1033  return parser.emitError(parser.getCurrentLocation())
1034  << "invalid `private` format";
1035 
1036  if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
1037  args.reductionArgs)))
1038  return parser.emitError(parser.getCurrentLocation())
1039  << "invalid `reduction` format";
1040 
1041  if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
1042  args.taskReductionArgs)))
1043  return parser.emitError(parser.getCurrentLocation())
1044  << "invalid `task_reduction` format";
1045 
1046  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
1047  args.useDeviceAddrArgs)))
1048  return parser.emitError(parser.getCurrentLocation())
1049  << "invalid `use_device_addr` format";
1050 
1051  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
1052  args.useDevicePtrArgs)))
1053  return parser.emitError(parser.getCurrentLocation())
1054  << "invalid `use_device_addr` format";
1055 
1056  return parser.parseRegion(region, entryBlockArgs);
1057 }
1058 
1059 // These parseXyz functions correspond to the custom<Xyz> definitions
1060 // in the .td file(s).
1061 static ParseResult parseTargetOpRegion(
1062  OpAsmParser &parser, Region &region,
1064  SmallVectorImpl<Type> &hasDeviceAddrTypes,
1066  SmallVectorImpl<Type> &hostEvalTypes,
1068  SmallVectorImpl<Type> &inReductionTypes,
1069  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1071  SmallVectorImpl<Type> &mapTypes,
1073  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1074  UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
1075  AllRegionParseArgs args;
1076  args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1077  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1078  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1079  inReductionByref, inReductionSyms);
1080  args.mapArgs.emplace(mapVars, mapTypes);
1081  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1082  privateNeedsBarrier, &privateMaps);
1083  return parseBlockArgRegion(parser, region, args);
1084 }
1085 
1087  OpAsmParser &parser, Region &region,
1089  SmallVectorImpl<Type> &inReductionTypes,
1090  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1092  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1093  UnitAttr &privateNeedsBarrier) {
1094  AllRegionParseArgs args;
1095  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1096  inReductionByref, inReductionSyms);
1097  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1098  privateNeedsBarrier);
1099  return parseBlockArgRegion(parser, region, args);
1100 }
1101 
1103  OpAsmParser &parser, Region &region,
1105  SmallVectorImpl<Type> &inReductionTypes,
1106  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1108  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1109  UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1111  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1112  ArrayAttr &reductionSyms) {
1113  AllRegionParseArgs args;
1114  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1115  inReductionByref, inReductionSyms);
1116  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1117  privateNeedsBarrier);
1118  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1119  reductionSyms, &reductionMod);
1120  return parseBlockArgRegion(parser, region, args);
1121 }
1122 
1123 static ParseResult parsePrivateRegion(
1124  OpAsmParser &parser, Region &region,
1126  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1127  UnitAttr &privateNeedsBarrier) {
1128  AllRegionParseArgs args;
1129  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1130  privateNeedsBarrier);
1131  return parseBlockArgRegion(parser, region, args);
1132 }
1133 
1134 static ParseResult parsePrivateReductionRegion(
1135  OpAsmParser &parser, Region &region,
1137  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1138  UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1140  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1141  ArrayAttr &reductionSyms) {
1142  AllRegionParseArgs args;
1143  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1144  privateNeedsBarrier);
1145  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1146  reductionSyms, &reductionMod);
1147  return parseBlockArgRegion(parser, region, args);
1148 }
1149 
1150 static ParseResult parseTaskReductionRegion(
1151  OpAsmParser &parser, Region &region,
1153  SmallVectorImpl<Type> &taskReductionTypes,
1154  DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
1155  AllRegionParseArgs args;
1156  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1157  taskReductionByref, taskReductionSyms);
1158  return parseBlockArgRegion(parser, region, args);
1159 }
1160 
1162  OpAsmParser &parser, Region &region,
1164  SmallVectorImpl<Type> &useDeviceAddrTypes,
1166  SmallVectorImpl<Type> &useDevicePtrTypes) {
1167  AllRegionParseArgs args;
1168  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1169  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1170  return parseBlockArgRegion(parser, region, args);
1171 }
1172 
1173 //===----------------------------------------------------------------------===//
1174 // Printers for operations including clauses that define entry block arguments.
1175 //===----------------------------------------------------------------------===//
1176 
1177 namespace {
1178 struct MapPrintArgs {
1179  ValueRange vars;
1180  TypeRange types;
1181  MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
1182 };
1183 struct PrivatePrintArgs {
1184  ValueRange vars;
1185  TypeRange types;
1186  ArrayAttr syms;
1187  UnitAttr needsBarrier;
1188  DenseI64ArrayAttr mapIndices;
1189  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
1190  UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
1191  : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1192  mapIndices(mapIndices) {}
1193 };
1194 struct ReductionPrintArgs {
1195  ValueRange vars;
1196  TypeRange types;
1197  DenseBoolArrayAttr byref;
1198  ArrayAttr syms;
1199  ReductionModifierAttr modifier;
1200  ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
1201  ArrayAttr syms, ReductionModifierAttr mod = nullptr)
1202  : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1203 };
1204 struct AllRegionPrintArgs {
1205  std::optional<MapPrintArgs> hasDeviceAddrArgs;
1206  std::optional<MapPrintArgs> hostEvalArgs;
1207  std::optional<ReductionPrintArgs> inReductionArgs;
1208  std::optional<MapPrintArgs> mapArgs;
1209  std::optional<PrivatePrintArgs> privateArgs;
1210  std::optional<ReductionPrintArgs> reductionArgs;
1211  std::optional<ReductionPrintArgs> taskReductionArgs;
1212  std::optional<MapPrintArgs> useDeviceAddrArgs;
1213  std::optional<MapPrintArgs> useDevicePtrArgs;
1214 };
1215 } // namespace
1216 
1218  OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1219  ValueRange argsSubrange, ValueRange operands, TypeRange types,
1220  ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
1221  DenseBoolArrayAttr byref = nullptr,
1222  ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
1223  if (argsSubrange.empty())
1224  return;
1225 
1226  p << clauseName << "(";
1227 
1228  if (modifier)
1229  p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
1230 
1231  if (!symbols) {
1232  llvm::SmallVector<Attribute> values(operands.size(), nullptr);
1233  symbols = ArrayAttr::get(ctx, values);
1234  }
1235 
1236  if (!mapIndices) {
1237  llvm::SmallVector<int64_t> values(operands.size(), -1);
1238  mapIndices = DenseI64ArrayAttr::get(ctx, values);
1239  }
1240 
1241  if (!byref) {
1242  mlir::SmallVector<bool> values(operands.size(), false);
1243  byref = DenseBoolArrayAttr::get(ctx, values);
1244  }
1245 
1246  llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1247  mapIndices.asArrayRef(),
1248  byref.asArrayRef()),
1249  p, [&p](auto t) {
1250  auto [op, arg, sym, map, isByRef] = t;
1251  if (isByRef)
1252  p << "byref ";
1253  if (sym)
1254  p << sym << " ";
1255 
1256  p << op << " -> " << arg;
1257 
1258  if (map != -1)
1259  p << " [map_idx=" << map << "]";
1260  });
1261  p << " : ";
1262  llvm::interleaveComma(types, p);
1263  p << ") ";
1264 
1265  if (needsBarrier)
1266  p << getPrivateNeedsBarrierSpelling() << " ";
1267 }
1268 
1270  StringRef clauseName, ValueRange argsSubrange,
1271  std::optional<MapPrintArgs> mapArgs) {
1272  if (mapArgs)
1273  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
1274  mapArgs->types);
1275 }
1276 
1278  StringRef clauseName, ValueRange argsSubrange,
1279  std::optional<PrivatePrintArgs> privateArgs) {
1280  if (privateArgs)
1282  p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1283  privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1284  /*modifier=*/nullptr, privateArgs->needsBarrier);
1285 }
1286 
1287 static void
1288 printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1289  ValueRange argsSubrange,
1290  std::optional<ReductionPrintArgs> reductionArgs) {
1291  if (reductionArgs)
1292  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1293  reductionArgs->vars, reductionArgs->types,
1294  reductionArgs->syms, /*mapIndices=*/nullptr,
1295  reductionArgs->byref, reductionArgs->modifier);
1296 }
1297 
1298 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
1299  const AllRegionPrintArgs &args) {
1300  auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1301  MLIRContext *ctx = op->getContext();
1302 
1303  printBlockArgClause(p, ctx, "has_device_addr",
1304  iface.getHasDeviceAddrBlockArgs(),
1305  args.hasDeviceAddrArgs);
1306  printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
1307  args.hostEvalArgs);
1308  printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
1309  args.inReductionArgs);
1310  printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
1311  args.mapArgs);
1312  printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
1313  args.privateArgs);
1314  printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
1315  args.reductionArgs);
1316  printBlockArgClause(p, ctx, "task_reduction",
1317  iface.getTaskReductionBlockArgs(),
1318  args.taskReductionArgs);
1319  printBlockArgClause(p, ctx, "use_device_addr",
1320  iface.getUseDeviceAddrBlockArgs(),
1321  args.useDeviceAddrArgs);
1322  printBlockArgClause(p, ctx, "use_device_ptr",
1323  iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1324 
1325  p.printRegion(region, /*printEntryBlockArgs=*/false);
1326 }
1327 
1328 // These parseXyz functions correspond to the custom<Xyz> definitions
1329 // in the .td file(s).
1331  OpAsmPrinter &p, Operation *op, Region &region,
1332  ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
1333  ValueRange hostEvalVars, TypeRange hostEvalTypes,
1334  ValueRange inReductionVars, TypeRange inReductionTypes,
1335  DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
1336  ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
1337  TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1338  DenseI64ArrayAttr privateMaps) {
1339  AllRegionPrintArgs args;
1340  args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1341  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1342  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1343  inReductionByref, inReductionSyms);
1344  args.mapArgs.emplace(mapVars, mapTypes);
1345  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1346  privateNeedsBarrier, privateMaps);
1347  printBlockArgRegion(p, op, region, args);
1348 }
1349 
1351  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1352  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1353  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1354  ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1355  AllRegionPrintArgs args;
1356  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1357  inReductionByref, inReductionSyms);
1358  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1359  privateNeedsBarrier,
1360  /*mapIndices=*/nullptr);
1361  printBlockArgRegion(p, op, region, args);
1362 }
1363 
1365  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1366  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1367  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1368  ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1369  ReductionModifierAttr reductionMod, ValueRange reductionVars,
1370  TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1371  ArrayAttr reductionSyms) {
1372  AllRegionPrintArgs args;
1373  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1374  inReductionByref, inReductionSyms);
1375  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1376  privateNeedsBarrier,
1377  /*mapIndices=*/nullptr);
1378  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1379  reductionSyms, reductionMod);
1380  printBlockArgRegion(p, op, region, args);
1381 }
1382 
1383 static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
1384  ValueRange privateVars, TypeRange privateTypes,
1385  ArrayAttr privateSyms,
1386  UnitAttr privateNeedsBarrier) {
1387  AllRegionPrintArgs args;
1388  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1389  privateNeedsBarrier,
1390  /*mapIndices=*/nullptr);
1391  printBlockArgRegion(p, op, region, args);
1392 }
1393 
1395  OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1396  TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1397  ReductionModifierAttr reductionMod, ValueRange reductionVars,
1398  TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1399  ArrayAttr reductionSyms) {
1400  AllRegionPrintArgs args;
1401  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1402  privateNeedsBarrier,
1403  /*mapIndices=*/nullptr);
1404  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1405  reductionSyms, reductionMod);
1406  printBlockArgRegion(p, op, region, args);
1407 }
1408 
1410  Region &region,
1411  ValueRange taskReductionVars,
1412  TypeRange taskReductionTypes,
1413  DenseBoolArrayAttr taskReductionByref,
1414  ArrayAttr taskReductionSyms) {
1415  AllRegionPrintArgs args;
1416  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1417  taskReductionByref, taskReductionSyms);
1418  printBlockArgRegion(p, op, region, args);
1419 }
1420 
1422  Region &region,
1423  ValueRange useDeviceAddrVars,
1424  TypeRange useDeviceAddrTypes,
1425  ValueRange useDevicePtrVars,
1426  TypeRange useDevicePtrTypes) {
1427  AllRegionPrintArgs args;
1428  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1429  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1430  printBlockArgRegion(p, op, region, args);
1431 }
1432 
1433 /// Verifies Reduction Clause
1434 static LogicalResult
1435 verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1436  OperandRange reductionVars,
1437  std::optional<ArrayRef<bool>> reductionByref) {
1438  if (!reductionVars.empty()) {
1439  if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1440  return op->emitOpError()
1441  << "expected as many reduction symbol references "
1442  "as reduction variables";
1443  if (reductionByref && reductionByref->size() != reductionVars.size())
1444  return op->emitError() << "expected as many reduction variable by "
1445  "reference attributes as reduction variables";
1446  } else {
1447  if (reductionSyms)
1448  return op->emitOpError() << "unexpected reduction symbol references";
1449  return success();
1450  }
1451 
1452  // TODO: The followings should be done in
1453  // SymbolUserOpInterface::verifySymbolUses.
1454  DenseSet<Value> accumulators;
1455  for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1456  Value accum = std::get<0>(args);
1457 
1458  if (!accumulators.insert(accum).second)
1459  return op->emitOpError() << "accumulator variable used more than once";
1460 
1461  Type varType = accum.getType();
1462  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1463  auto decl =
1464  SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1465  if (!decl)
1466  return op->emitOpError() << "expected symbol reference " << symbolRef
1467  << " to point to a reduction declaration";
1468 
1469  if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1470  return op->emitOpError()
1471  << "expected accumulator (" << varType
1472  << ") to be the same type as reduction declaration ("
1473  << decl.getAccumulatorType() << ")";
1474  }
1475 
1476  return success();
1477 }
1478 
1479 //===----------------------------------------------------------------------===//
1480 // Parser, printer and verifier for Copyprivate
1481 //===----------------------------------------------------------------------===//
1482 
1483 /// copyprivate-entry-list ::= copyprivate-entry
1484 /// | copyprivate-entry-list `,` copyprivate-entry
1485 /// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1486 static ParseResult parseCopyprivate(
1487  OpAsmParser &parser,
1489  SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1491  if (failed(parser.parseCommaSeparatedList([&]() {
1492  if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1493  parser.parseArrow() ||
1494  parser.parseAttribute(symsVec.emplace_back()) ||
1495  parser.parseColonType(copyprivateTypes.emplace_back()))
1496  return failure();
1497  return success();
1498  })))
1499  return failure();
1500  SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1501  copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1502  return success();
1503 }
1504 
1505 /// Print Copyprivate clause
1507  OperandRange copyprivateVars,
1508  TypeRange copyprivateTypes,
1509  std::optional<ArrayAttr> copyprivateSyms) {
1510  if (!copyprivateSyms.has_value())
1511  return;
1512  llvm::interleaveComma(
1513  llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1514  [&](const auto &args) {
1515  p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1516  << std::get<2>(args);
1517  });
1518 }
1519 
1520 /// Verifies CopyPrivate Clause
1521 static LogicalResult
1523  std::optional<ArrayAttr> copyprivateSyms) {
1524  size_t copyprivateSymsSize =
1525  copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1526  if (copyprivateSymsSize != copyprivateVars.size())
1527  return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1528  << copyprivateVars.size()
1529  << ") and functions (= " << copyprivateSymsSize
1530  << "), both must be equal";
1531  if (!copyprivateSyms.has_value())
1532  return success();
1533 
1534  for (auto copyprivateVarAndSym :
1535  llvm::zip(copyprivateVars, *copyprivateSyms)) {
1536  auto symbolRef =
1537  llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1538  std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1539  funcOp;
1540  if (mlir::func::FuncOp mlirFuncOp =
1541  SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1542  symbolRef))
1543  funcOp = mlirFuncOp;
1544  else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1545  SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1546  op, symbolRef))
1547  funcOp = llvmFuncOp;
1548 
1549  auto getNumArguments = [&] {
1550  return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1551  };
1552 
1553  auto getArgumentType = [&](unsigned i) {
1554  return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1555  *funcOp);
1556  };
1557 
1558  if (!funcOp)
1559  return op->emitOpError() << "expected symbol reference " << symbolRef
1560  << " to point to a copy function";
1561 
1562  if (getNumArguments() != 2)
1563  return op->emitOpError()
1564  << "expected copy function " << symbolRef << " to have 2 operands";
1565 
1566  Type argTy = getArgumentType(0);
1567  if (argTy != getArgumentType(1))
1568  return op->emitOpError() << "expected copy function " << symbolRef
1569  << " arguments to have the same type";
1570 
1571  Type varType = std::get<0>(copyprivateVarAndSym).getType();
1572  if (argTy != varType)
1573  return op->emitOpError()
1574  << "expected copy function arguments' type (" << argTy
1575  << ") to be the same as copyprivate variable's type (" << varType
1576  << ")";
1577  }
1578 
1579  return success();
1580 }
1581 
1582 //===----------------------------------------------------------------------===//
1583 // Parser, printer and verifier for DependVarList
1584 //===----------------------------------------------------------------------===//
1585 
1586 /// depend-entry-list ::= depend-entry
1587 /// | depend-entry-list `,` depend-entry
1588 /// depend-entry ::= depend-kind `->` ssa-id `:` type
1589 static ParseResult
1592  SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
1594  if (failed(parser.parseCommaSeparatedList([&]() {
1595  StringRef keyword;
1596  if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1597  parser.parseOperand(dependVars.emplace_back()) ||
1598  parser.parseColonType(dependTypes.emplace_back()))
1599  return failure();
1600  if (std::optional<ClauseTaskDepend> keywordDepend =
1601  (symbolizeClauseTaskDepend(keyword)))
1602  kindsVec.emplace_back(
1603  ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1604  else
1605  return failure();
1606  return success();
1607  })))
1608  return failure();
1609  SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1610  dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1611  return success();
1612 }
1613 
1614 /// Print Depend clause
1616  OperandRange dependVars, TypeRange dependTypes,
1617  std::optional<ArrayAttr> dependKinds) {
1618 
1619  for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1620  if (i != 0)
1621  p << ", ";
1622  p << stringifyClauseTaskDepend(
1623  llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1624  .getValue())
1625  << " -> " << dependVars[i] << " : " << dependTypes[i];
1626  }
1627 }
1628 
1629 /// Verifies Depend clause
1630 static LogicalResult verifyDependVarList(Operation *op,
1631  std::optional<ArrayAttr> dependKinds,
1632  OperandRange dependVars) {
1633  if (!dependVars.empty()) {
1634  if (!dependKinds || dependKinds->size() != dependVars.size())
1635  return op->emitOpError() << "expected as many depend values"
1636  " as depend variables";
1637  } else {
1638  if (dependKinds && !dependKinds->empty())
1639  return op->emitOpError() << "unexpected depend values";
1640  return success();
1641  }
1642 
1643  return success();
1644 }
1645 
1646 //===----------------------------------------------------------------------===//
1647 // Parser, printer and verifier for Synchronization Hint (2.17.12)
1648 //===----------------------------------------------------------------------===//
1649 
1650 /// Parses a Synchronization Hint clause. The value of hint is an integer
1651 /// which is a combination of different hints from `omp_sync_hint_t`.
1652 ///
1653 /// hint-clause = `hint` `(` hint-value `)`
1654 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1655  IntegerAttr &hintAttr) {
1656  StringRef hintKeyword;
1657  int64_t hint = 0;
1658  if (succeeded(parser.parseOptionalKeyword("none"))) {
1659  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1660  return success();
1661  }
1662  auto parseKeyword = [&]() -> ParseResult {
1663  if (failed(parser.parseKeyword(&hintKeyword)))
1664  return failure();
1665  if (hintKeyword == "uncontended")
1666  hint |= 1;
1667  else if (hintKeyword == "contended")
1668  hint |= 2;
1669  else if (hintKeyword == "nonspeculative")
1670  hint |= 4;
1671  else if (hintKeyword == "speculative")
1672  hint |= 8;
1673  else
1674  return parser.emitError(parser.getCurrentLocation())
1675  << hintKeyword << " is not a valid hint";
1676  return success();
1677  };
1678  if (parser.parseCommaSeparatedList(parseKeyword))
1679  return failure();
1680  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1681  return success();
1682 }
1683 
1684 /// Prints a Synchronization Hint clause
1686  IntegerAttr hintAttr) {
1687  int64_t hint = hintAttr.getInt();
1688 
1689  if (hint == 0) {
1690  p << "none";
1691  return;
1692  }
1693 
1694  // Helper function to get n-th bit from the right end of `value`
1695  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1696 
1697  bool uncontended = bitn(hint, 0);
1698  bool contended = bitn(hint, 1);
1699  bool nonspeculative = bitn(hint, 2);
1700  bool speculative = bitn(hint, 3);
1701 
1702  SmallVector<StringRef> hints;
1703  if (uncontended)
1704  hints.push_back("uncontended");
1705  if (contended)
1706  hints.push_back("contended");
1707  if (nonspeculative)
1708  hints.push_back("nonspeculative");
1709  if (speculative)
1710  hints.push_back("speculative");
1711 
1712  llvm::interleaveComma(hints, p);
1713 }
1714 
1715 /// Verifies a synchronization hint clause
1716 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1717 
1718  // Helper function to get n-th bit from the right end of `value`
1719  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1720 
1721  bool uncontended = bitn(hint, 0);
1722  bool contended = bitn(hint, 1);
1723  bool nonspeculative = bitn(hint, 2);
1724  bool speculative = bitn(hint, 3);
1725 
1726  if (uncontended && contended)
1727  return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1728  "omp_sync_hint_contended cannot be combined";
1729  if (nonspeculative && speculative)
1730  return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1731  "omp_sync_hint_speculative cannot be combined.";
1732  return success();
1733 }
1734 
1735 //===----------------------------------------------------------------------===//
1736 // Parser, printer and verifier for Target
1737 //===----------------------------------------------------------------------===//
1738 
1739 // Helper function to get bitwise AND of `value` and 'flag' then return it as a
1740 // boolean
1741 static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag) {
1742  return (value & flag) == flag;
1743 }
1744 
1745 /// Parses a map_entries map type from a string format back into its numeric
1746 /// value.
1747 ///
1748 /// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
1749 /// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
1750 static ParseResult parseMapClause(OpAsmParser &parser,
1751  ClauseMapFlagsAttr &mapType) {
1752  ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
1753  // This simply verifies the correct keyword is read in, the
1754  // keyword itself is stored inside of the operation
1755  auto parseTypeAndMod = [&]() -> ParseResult {
1756  StringRef mapTypeMod;
1757  if (parser.parseKeyword(&mapTypeMod))
1758  return failure();
1759 
1760  if (mapTypeMod == "always")
1761  mapTypeBits |= ClauseMapFlags::always;
1762 
1763  if (mapTypeMod == "implicit")
1764  mapTypeBits |= ClauseMapFlags::implicit;
1765 
1766  if (mapTypeMod == "ompx_hold")
1767  mapTypeBits |= ClauseMapFlags::ompx_hold;
1768 
1769  if (mapTypeMod == "close")
1770  mapTypeBits |= ClauseMapFlags::close;
1771 
1772  if (mapTypeMod == "present")
1773  mapTypeBits |= ClauseMapFlags::present;
1774 
1775  if (mapTypeMod == "to")
1776  mapTypeBits |= ClauseMapFlags::to;
1777 
1778  if (mapTypeMod == "from")
1779  mapTypeBits |= ClauseMapFlags::from;
1780 
1781  if (mapTypeMod == "tofrom")
1782  mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
1783 
1784  if (mapTypeMod == "delete")
1785  mapTypeBits |= ClauseMapFlags::del;
1786 
1787  if (mapTypeMod == "storage")
1788  mapTypeBits |= ClauseMapFlags::storage;
1789 
1790  if (mapTypeMod == "return_param")
1791  mapTypeBits |= ClauseMapFlags::return_param;
1792 
1793  if (mapTypeMod == "private")
1794  mapTypeBits |= ClauseMapFlags::priv;
1795 
1796  if (mapTypeMod == "literal")
1797  mapTypeBits |= ClauseMapFlags::literal;
1798 
1799  if (mapTypeMod == "attach")
1800  mapTypeBits |= ClauseMapFlags::attach;
1801 
1802  if (mapTypeMod == "attach_always")
1803  mapTypeBits |= ClauseMapFlags::attach_always;
1804 
1805  if (mapTypeMod == "attach_none")
1806  mapTypeBits |= ClauseMapFlags::attach_none;
1807 
1808  if (mapTypeMod == "attach_auto")
1809  mapTypeBits |= ClauseMapFlags::attach_auto;
1810 
1811  if (mapTypeMod == "ref_ptr")
1812  mapTypeBits |= ClauseMapFlags::ref_ptr;
1813 
1814  if (mapTypeMod == "ref_ptee")
1815  mapTypeBits |= ClauseMapFlags::ref_ptee;
1816 
1817  if (mapTypeMod == "ref_ptr_ptee")
1818  mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
1819 
1820  return success();
1821  };
1822 
1823  if (parser.parseCommaSeparatedList(parseTypeAndMod))
1824  return failure();
1825 
1826  mapType =
1827  parser.getBuilder().getAttr<mlir::omp::ClauseMapFlagsAttr>(mapTypeBits);
1828 
1829  return success();
1830 }
1831 
1832 /// Prints a map_entries map type from its numeric value out into its string
1833 /// format.
1835  ClauseMapFlagsAttr mapType) {
1837  ClauseMapFlags mapFlags = mapType.getValue();
1838 
1839  // handling of always, close, present placed at the beginning of the string
1840  // to aid readability
1841  if (mapTypeToBool(mapFlags, ClauseMapFlags::always))
1842  mapTypeStrs.push_back("always");
1843  if (mapTypeToBool(mapFlags, ClauseMapFlags::implicit))
1844  mapTypeStrs.push_back("implicit");
1845  if (mapTypeToBool(mapFlags, ClauseMapFlags::ompx_hold))
1846  mapTypeStrs.push_back("ompx_hold");
1847  if (mapTypeToBool(mapFlags, ClauseMapFlags::close))
1848  mapTypeStrs.push_back("close");
1849  if (mapTypeToBool(mapFlags, ClauseMapFlags::present))
1850  mapTypeStrs.push_back("present");
1851 
1852  // special handling of to/from/tofrom/delete and release/alloc, release +
1853  // alloc are the abscense of one of the other flags, whereas tofrom requires
1854  // both the to and from flag to be set.
1855  bool to = mapTypeToBool(mapFlags, ClauseMapFlags::to);
1856  bool from = mapTypeToBool(mapFlags, ClauseMapFlags::from);
1857 
1858  if (to && from)
1859  mapTypeStrs.push_back("tofrom");
1860  else if (from)
1861  mapTypeStrs.push_back("from");
1862  else if (to)
1863  mapTypeStrs.push_back("to");
1864 
1865  if (mapTypeToBool(mapFlags, ClauseMapFlags::del))
1866  mapTypeStrs.push_back("delete");
1867  if (mapTypeToBool(mapFlags, ClauseMapFlags::return_param))
1868  mapTypeStrs.push_back("return_param");
1869  if (mapTypeToBool(mapFlags, ClauseMapFlags::storage))
1870  mapTypeStrs.push_back("storage");
1871  if (mapTypeToBool(mapFlags, ClauseMapFlags::priv))
1872  mapTypeStrs.push_back("private");
1873  if (mapTypeToBool(mapFlags, ClauseMapFlags::literal))
1874  mapTypeStrs.push_back("literal");
1875  if (mapTypeToBool(mapFlags, ClauseMapFlags::attach))
1876  mapTypeStrs.push_back("attach");
1877  if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_always))
1878  mapTypeStrs.push_back("attach_always");
1879  if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_none))
1880  mapTypeStrs.push_back("attach_none");
1881  if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_auto))
1882  mapTypeStrs.push_back("attach_auto");
1883  if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr))
1884  mapTypeStrs.push_back("ref_ptr");
1885  if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptee))
1886  mapTypeStrs.push_back("ref_ptee");
1887  if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
1888  mapTypeStrs.push_back("ref_ptr_ptee");
1889  if (mapFlags == ClauseMapFlags::none)
1890  mapTypeStrs.push_back("none");
1891 
1892  for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1893  p << mapTypeStrs[i];
1894  if (i + 1 < mapTypeStrs.size()) {
1895  p << ", ";
1896  }
1897  }
1898 }
1899 
1900 static ParseResult parseMembersIndex(OpAsmParser &parser,
1901  ArrayAttr &membersIdx) {
1902  SmallVector<Attribute> values, memberIdxs;
1903 
1904  auto parseIndices = [&]() -> ParseResult {
1905  int64_t value;
1906  if (parser.parseInteger(value))
1907  return failure();
1908  values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1909  APInt(64, value, /*isSigned=*/false)));
1910  return success();
1911  };
1912 
1913  do {
1914  if (failed(parser.parseLSquare()))
1915  return failure();
1916 
1917  if (parser.parseCommaSeparatedList(parseIndices))
1918  return failure();
1919 
1920  if (failed(parser.parseRSquare()))
1921  return failure();
1922 
1923  memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1924  values.clear();
1925  } while (succeeded(parser.parseOptionalComma()));
1926 
1927  if (!memberIdxs.empty())
1928  membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
1929 
1930  return success();
1931 }
1932 
1933 static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1934  ArrayAttr membersIdx) {
1935  if (!membersIdx)
1936  return;
1937 
1938  llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
1939  p << "[";
1940  auto memberIdx = cast<ArrayAttr>(v);
1941  llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
1942  p << cast<IntegerAttr>(v2).getInt();
1943  });
1944  p << "]";
1945  });
1946 }
1947 
1949  VariableCaptureKindAttr mapCaptureType) {
1950  std::string typeCapStr;
1951  llvm::raw_string_ostream typeCap(typeCapStr);
1952  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1953  typeCap << "ByRef";
1954  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1955  typeCap << "ByCopy";
1956  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1957  typeCap << "VLAType";
1958  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1959  typeCap << "This";
1960  p << typeCapStr;
1961 }
1962 
1963 static ParseResult parseCaptureType(OpAsmParser &parser,
1964  VariableCaptureKindAttr &mapCaptureType) {
1965  StringRef mapCaptureKey;
1966  if (parser.parseKeyword(&mapCaptureKey))
1967  return failure();
1968 
1969  if (mapCaptureKey == "This")
1970  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1971  parser.getContext(), mlir::omp::VariableCaptureKind::This);
1972  if (mapCaptureKey == "ByRef")
1973  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1974  parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1975  if (mapCaptureKey == "ByCopy")
1976  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1977  parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1978  if (mapCaptureKey == "VLAType")
1979  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1980  parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1981 
1982  return success();
1983 }
1984 
1985 static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
1988 
1989  for (auto mapOp : mapVars) {
1990  if (!mapOp.getDefiningOp())
1991  return emitError(op->getLoc(), "missing map operation");
1992 
1993  if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
1994  mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
1995 
1996  bool to = mapTypeToBool(mapTypeBits, ClauseMapFlags::to);
1997  bool from = mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
1998  bool del = mapTypeToBool(mapTypeBits, ClauseMapFlags::del);
1999 
2000  bool always = mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2001  bool close = mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2002  bool implicit = mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2003 
2004  if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2005  return emitError(op->getLoc(),
2006  "to, from, tofrom and alloc map types are permitted");
2007 
2008  if (isa<TargetEnterDataOp>(op) && (from || del))
2009  return emitError(op->getLoc(), "to and alloc map types are permitted");
2010 
2011  if (isa<TargetExitDataOp>(op) && to)
2012  return emitError(op->getLoc(),
2013  "from, release and delete map types are permitted");
2014 
2015  if (isa<TargetUpdateOp>(op)) {
2016  if (del) {
2017  return emitError(op->getLoc(),
2018  "at least one of to or from map types must be "
2019  "specified, other map types are not permitted");
2020  }
2021 
2022  if (!to && !from) {
2023  return emitError(op->getLoc(),
2024  "at least one of to or from map types must be "
2025  "specified, other map types are not permitted");
2026  }
2027 
2028  auto updateVar = mapInfoOp.getVarPtr();
2029 
2030  if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2031  (from && updateToVars.contains(updateVar))) {
2032  return emitError(
2033  op->getLoc(),
2034  "either to or from map types can be specified, not both");
2035  }
2036 
2037  if (always || close || implicit) {
2038  return emitError(
2039  op->getLoc(),
2040  "present, mapper and iterator map type modifiers are permitted");
2041  }
2042 
2043  to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2044  }
2045  } else if (!isa<DeclareMapperInfoOp>(op)) {
2046  return emitError(op->getLoc(),
2047  "map argument is not a map entry operation");
2048  }
2049  }
2050 
2051  return success();
2052 }
2053 
2054 static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
2055  std::optional<DenseI64ArrayAttr> privateMapIndices =
2056  targetOp.getPrivateMapsAttr();
2057 
2058  // None of the private operands are mapped.
2059  if (!privateMapIndices.has_value() || !privateMapIndices.value())
2060  return success();
2061 
2062  OperandRange privateVars = targetOp.getPrivateVars();
2063 
2064  if (privateMapIndices.value().size() !=
2065  static_cast<int64_t>(privateVars.size()))
2066  return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
2067  "`private_maps` attribute mismatch");
2068 
2069  return success();
2070 }
2071 
2072 //===----------------------------------------------------------------------===//
2073 // MapInfoOp
2074 //===----------------------------------------------------------------------===//
2075 
2076 static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
2077  StringRef clauseName,
2078  OperandRange vars) {
2079  for (Value var : vars)
2080  if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2081  return op->emitOpError()
2082  << "'" << clauseName
2083  << "' arguments must be defined by 'omp.map.info' ops";
2084  return success();
2085 }
2086 
2087 LogicalResult MapInfoOp::verify() {
2088  if (getMapperId() &&
2089  !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
2090  *this, getMapperIdAttr())) {
2091  return emitError("invalid mapper id");
2092  }
2093 
2094  if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
2095  return failure();
2096 
2097  return success();
2098 }
2099 
2100 //===----------------------------------------------------------------------===//
2101 // TargetDataOp
2102 //===----------------------------------------------------------------------===//
2103 
2104 void TargetDataOp::build(OpBuilder &builder, OperationState &state,
2105  const TargetDataOperands &clauses) {
2106  TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2107  clauses.mapVars, clauses.useDeviceAddrVars,
2108  clauses.useDevicePtrVars);
2109 }
2110 
2111 LogicalResult TargetDataOp::verify() {
2112  if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2113  getUseDeviceAddrVars().empty()) {
2114  return ::emitError(this->getLoc(),
2115  "At least one of map, use_device_ptr_vars, or "
2116  "use_device_addr_vars operand must be present");
2117  }
2118 
2119  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
2120  getUseDevicePtrVars())))
2121  return failure();
2122 
2123  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
2124  getUseDeviceAddrVars())))
2125  return failure();
2126 
2127  return verifyMapClause(*this, getMapVars());
2128 }
2129 
2130 //===----------------------------------------------------------------------===//
2131 // TargetEnterDataOp
2132 //===----------------------------------------------------------------------===//
2133 
2134 void TargetEnterDataOp::build(
2135  OpBuilder &builder, OperationState &state,
2136  const TargetEnterExitUpdateDataOperands &clauses) {
2137  MLIRContext *ctx = builder.getContext();
2138  TargetEnterDataOp::build(builder, state,
2139  makeArrayAttr(ctx, clauses.dependKinds),
2140  clauses.dependVars, clauses.device, clauses.ifExpr,
2141  clauses.mapVars, clauses.nowait);
2142 }
2143 
2144 LogicalResult TargetEnterDataOp::verify() {
2145  LogicalResult verifyDependVars =
2146  verifyDependVarList(*this, getDependKinds(), getDependVars());
2147  return failed(verifyDependVars) ? verifyDependVars
2148  : verifyMapClause(*this, getMapVars());
2149 }
2150 
2151 //===----------------------------------------------------------------------===//
2152 // TargetExitDataOp
2153 //===----------------------------------------------------------------------===//
2154 
2155 void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
2156  const TargetEnterExitUpdateDataOperands &clauses) {
2157  MLIRContext *ctx = builder.getContext();
2158  TargetExitDataOp::build(builder, state,
2159  makeArrayAttr(ctx, clauses.dependKinds),
2160  clauses.dependVars, clauses.device, clauses.ifExpr,
2161  clauses.mapVars, clauses.nowait);
2162 }
2163 
2164 LogicalResult TargetExitDataOp::verify() {
2165  LogicalResult verifyDependVars =
2166  verifyDependVarList(*this, getDependKinds(), getDependVars());
2167  return failed(verifyDependVars) ? verifyDependVars
2168  : verifyMapClause(*this, getMapVars());
2169 }
2170 
2171 //===----------------------------------------------------------------------===//
2172 // TargetUpdateOp
2173 //===----------------------------------------------------------------------===//
2174 
2175 void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
2176  const TargetEnterExitUpdateDataOperands &clauses) {
2177  MLIRContext *ctx = builder.getContext();
2178  TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2179  clauses.dependVars, clauses.device, clauses.ifExpr,
2180  clauses.mapVars, clauses.nowait);
2181 }
2182 
2183 LogicalResult TargetUpdateOp::verify() {
2184  LogicalResult verifyDependVars =
2185  verifyDependVarList(*this, getDependKinds(), getDependVars());
2186  return failed(verifyDependVars) ? verifyDependVars
2187  : verifyMapClause(*this, getMapVars());
2188 }
2189 
2190 //===----------------------------------------------------------------------===//
2191 // TargetOp
2192 //===----------------------------------------------------------------------===//
2193 
2194 void TargetOp::build(OpBuilder &builder, OperationState &state,
2195  const TargetOperands &clauses) {
2196  MLIRContext *ctx = builder.getContext();
2197  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
2198  // inReductionByref, inReductionSyms.
2199  TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2200  clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
2201  clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
2202  clauses.hostEvalVars, clauses.ifExpr,
2203  /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
2204  /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
2205  clauses.mapVars, clauses.nowait, clauses.privateVars,
2206  makeArrayAttr(ctx, clauses.privateSyms),
2207  clauses.privateNeedsBarrier, clauses.threadLimit,
2208  /*private_maps=*/nullptr);
2209 }
2210 
2211 LogicalResult TargetOp::verify() {
2212  if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
2213  return failure();
2214 
2215  if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
2216  getHasDeviceAddrVars())))
2217  return failure();
2218 
2219  if (failed(verifyMapClause(*this, getMapVars())))
2220  return failure();
2221 
2222  return verifyPrivateVarsMapping(*this);
2223 }
2224 
2225 LogicalResult TargetOp::verifyRegions() {
2226  auto teamsOps = getOps<TeamsOp>();
2227  if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2228  return emitError("target containing multiple 'omp.teams' nested ops");
2229 
2230  // Check that host_eval values are only used in legal ways.
2231  Operation *capturedOp = getInnermostCapturedOmpOp();
2232  TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
2233  for (Value hostEvalArg :
2234  cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
2235  for (Operation *user : hostEvalArg.getUsers()) {
2236  if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
2237  if (llvm::is_contained({teamsOp.getNumTeamsLower(),
2238  teamsOp.getNumTeamsUpper(),
2239  teamsOp.getThreadLimit()},
2240  hostEvalArg))
2241  continue;
2242 
2243  return emitOpError() << "host_eval argument only legal as 'num_teams' "
2244  "and 'thread_limit' in 'omp.teams'";
2245  }
2246  if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
2247  if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
2248  parallelOp->isAncestor(capturedOp) &&
2249  hostEvalArg == parallelOp.getNumThreads())
2250  continue;
2251 
2252  return emitOpError()
2253  << "host_eval argument only legal as 'num_threads' in "
2254  "'omp.parallel' when representing target SPMD";
2255  }
2256  if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2257  if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2258  loopNestOp.getOperation() == capturedOp &&
2259  (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2260  llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2261  llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2262  continue;
2263 
2264  return emitOpError() << "host_eval argument only legal as loop bounds "
2265  "and steps in 'omp.loop_nest' when trip count "
2266  "must be evaluated in the host";
2267  }
2268 
2269  return emitOpError() << "host_eval argument illegal use in '"
2270  << user->getName() << "' operation";
2271  }
2272  }
2273  return success();
2274 }
2275 
2276 static Operation *
2277 findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
2278  llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
2279  assert(rootOp && "expected valid operation");
2280 
2281  Dialect *ompDialect = rootOp->getDialect();
2282  Operation *capturedOp = nullptr;
2283  DominanceInfo domInfo;
2284 
2285  // Process in pre-order to check operations from outermost to innermost,
2286  // ensuring we only enter the region of an operation if it meets the criteria
2287  // for being captured. We stop the exploration of nested operations as soon as
2288  // we process a region holding no operations to be captured.
2289  rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
2290  if (op == rootOp)
2291  return WalkResult::advance();
2292 
2293  // Ignore operations of other dialects or omp operations with no regions,
2294  // because these will only be checked if they are siblings of an omp
2295  // operation that can potentially be captured.
2296  bool isOmpDialect = op->getDialect() == ompDialect;
2297  bool hasRegions = op->getNumRegions() > 0;
2298  if (!isOmpDialect || !hasRegions)
2299  return WalkResult::skip();
2300 
2301  // This operation cannot be captured if it can be executed more than once
2302  // (i.e. its block's successors can reach it) or if it's not guaranteed to
2303  // be executed before all exits of the region (i.e. it doesn't dominate all
2304  // blocks with no successors reachable from the entry block).
2305  if (checkSingleMandatoryExec) {
2306  Region *parentRegion = op->getParentRegion();
2307  Block *parentBlock = op->getBlock();
2308 
2309  for (Block *successor : parentBlock->getSuccessors())
2310  if (successor->isReachable(parentBlock))
2311  return WalkResult::interrupt();
2312 
2313  for (Block &block : *parentRegion)
2314  if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2315  !domInfo.dominates(parentBlock, &block))
2316  return WalkResult::interrupt();
2317  }
2318 
2319  // Don't capture this op if it has a not-allowed sibling, and stop recursing
2320  // into nested operations.
2321  for (Operation &sibling : op->getParentRegion()->getOps())
2322  if (&sibling != op && !siblingAllowedFn(&sibling))
2323  return WalkResult::interrupt();
2324 
2325  // Don't continue capturing nested operations if we reach an omp.loop_nest.
2326  // Otherwise, process the contents of this operation.
2327  capturedOp = op;
2328  return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2329  : WalkResult::advance();
2330  });
2331 
2332  return capturedOp;
2333 }
2334 
2335 Operation *TargetOp::getInnermostCapturedOmpOp() {
2336  auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
2337 
2338  // Only allow OpenMP terminators and non-OpenMP ops that have known memory
2339  // effects, but don't include a memory write effect.
2340  return findCapturedOmpOp(
2341  *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2342  if (!sibling)
2343  return false;
2344 
2345  if (ompDialect == sibling->getDialect())
2346  return sibling->hasTrait<OpTrait::IsTerminator>();
2347 
2348  if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2350  effects;
2351  memOp.getEffects(effects);
2352  return !llvm::any_of(
2353  effects, [&](MemoryEffects::EffectInstance &effect) {
2354  return isa<MemoryEffects::Write>(effect.getEffect()) &&
2355  isa<SideEffects::AutomaticAllocationScopeResource>(
2356  effect.getResource());
2357  });
2358  }
2359  return true;
2360  });
2361 }
2362 
2363 /// Check if we can promote SPMD kernel to No-Loop kernel.
2364 static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp,
2365  WsloopOp *wsLoopOp) {
2366  // num_teams clause can break no-loop teams/threads assumption.
2367  if (teamsOp.getNumTeamsUpper())
2368  return false;
2369 
2370  // Reduction kernels are slower in no-loop mode.
2371  if (teamsOp.getNumReductionVars())
2372  return false;
2373  if (wsLoopOp->getNumReductionVars())
2374  return false;
2375 
2376  // Check if the user allows the promotion of kernels to no-loop mode.
2377  OffloadModuleInterface offloadMod =
2378  capturedOp->getParentOfType<omp::OffloadModuleInterface>();
2379  if (!offloadMod)
2380  return false;
2381  auto ompFlags = offloadMod.getFlags();
2382  if (!ompFlags)
2383  return false;
2384  return ompFlags.getAssumeTeamsOversubscription() &&
2385  ompFlags.getAssumeThreadsOversubscription();
2386 }
2387 
2388 TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2389  // A non-null captured op is only valid if it resides inside of a TargetOp
2390  // and is the result of calling getInnermostCapturedOmpOp() on it.
2391  TargetOp targetOp =
2392  capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
2393  assert((!capturedOp ||
2394  (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2395  "unexpected captured op");
2396 
2397  // If it's not capturing a loop, it's a default target region.
2398  if (!isa_and_present<LoopNestOp>(capturedOp))
2399  return TargetRegionFlags::generic;
2400 
2401  // Get the innermost non-simd loop wrapper.
2402  SmallVector<LoopWrapperInterface> loopWrappers;
2403  cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2404  assert(!loopWrappers.empty());
2405 
2406  LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2407  if (isa<SimdOp>(innermostWrapper))
2408  innermostWrapper = std::next(innermostWrapper);
2409 
2410  auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2411  if (numWrappers != 1 && numWrappers != 2)
2412  return TargetRegionFlags::generic;
2413 
2414  // Detect target-teams-distribute-parallel-wsloop[-simd].
2415  if (numWrappers == 2) {
2416  WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2417  if (!wsloopOp)
2418  return TargetRegionFlags::generic;
2419 
2420  innermostWrapper = std::next(innermostWrapper);
2421  if (!isa<DistributeOp>(innermostWrapper))
2422  return TargetRegionFlags::generic;
2423 
2424  Operation *parallelOp = (*innermostWrapper)->getParentOp();
2425  if (!isa_and_present<ParallelOp>(parallelOp))
2426  return TargetRegionFlags::generic;
2427 
2428  TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp());
2429  if (!teamsOp)
2430  return TargetRegionFlags::generic;
2431 
2432  if (teamsOp->getParentOp() == targetOp.getOperation()) {
2433  TargetRegionFlags result =
2434  TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2435  if (canPromoteToNoLoop(capturedOp, teamsOp, wsloopOp))
2436  result = result | TargetRegionFlags::no_loop;
2437  return result;
2438  }
2439  }
2440  // Detect target-teams-distribute[-simd] and target-teams-loop.
2441  else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2442  Operation *teamsOp = (*innermostWrapper)->getParentOp();
2443  if (!isa_and_present<TeamsOp>(teamsOp))
2444  return TargetRegionFlags::generic;
2445 
2446  if (teamsOp->getParentOp() != targetOp.getOperation())
2447  return TargetRegionFlags::generic;
2448 
2449  if (isa<LoopOp>(innermostWrapper))
2450  return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2451 
2452  // Find single immediately nested captured omp.parallel and add spmd flag
2453  // (generic-spmd case).
2454  //
2455  // TODO: This shouldn't have to be done here, as it is too easy to break.
2456  // The openmp-opt pass should be updated to be able to promote kernels like
2457  // this from "Generic" to "Generic-SPMD". However, the use of the
2458  // `kmpc_distribute_static_loop` family of functions produced by the
2459  // OMPIRBuilder for these kernels prevents that from working.
2460  Dialect *ompDialect = targetOp->getDialect();
2461  Operation *nestedCapture = findCapturedOmpOp(
2462  capturedOp, /*checkSingleMandatoryExec=*/false,
2463  [&](Operation *sibling) {
2464  return sibling && (ompDialect != sibling->getDialect() ||
2465  sibling->hasTrait<OpTrait::IsTerminator>());
2466  });
2467 
2468  TargetRegionFlags result =
2469  TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2470 
2471  if (!nestedCapture)
2472  return result;
2473 
2474  while (nestedCapture->getParentOp() != capturedOp)
2475  nestedCapture = nestedCapture->getParentOp();
2476 
2477  return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2478  : result;
2479  }
2480  // Detect target-parallel-wsloop[-simd].
2481  else if (isa<WsloopOp>(innermostWrapper)) {
2482  Operation *parallelOp = (*innermostWrapper)->getParentOp();
2483  if (!isa_and_present<ParallelOp>(parallelOp))
2484  return TargetRegionFlags::generic;
2485 
2486  if (parallelOp->getParentOp() == targetOp.getOperation())
2487  return TargetRegionFlags::spmd;
2488  }
2489 
2490  return TargetRegionFlags::generic;
2491 }
2492 
2493 //===----------------------------------------------------------------------===//
2494 // ParallelOp
2495 //===----------------------------------------------------------------------===//
2496 
2497 void ParallelOp::build(OpBuilder &builder, OperationState &state,
2498  ArrayRef<NamedAttribute> attributes) {
2499  ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
2500  /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2501  /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
2502  /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2503  /*proc_bind_kind=*/nullptr,
2504  /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2505  /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2506  state.addAttributes(attributes);
2507 }
2508 
2509 void ParallelOp::build(OpBuilder &builder, OperationState &state,
2510  const ParallelOperands &clauses) {
2511  MLIRContext *ctx = builder.getContext();
2512  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2513  clauses.ifExpr, clauses.numThreads, clauses.privateVars,
2514  makeArrayAttr(ctx, clauses.privateSyms),
2515  clauses.privateNeedsBarrier, clauses.procBindKind,
2516  clauses.reductionMod, clauses.reductionVars,
2517  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2518  makeArrayAttr(ctx, clauses.reductionSyms));
2519 }
2520 
2521 template <typename OpType>
2522 static LogicalResult verifyPrivateVarList(OpType &op) {
2523  auto privateVars = op.getPrivateVars();
2524  auto privateSyms = op.getPrivateSymsAttr();
2525 
2526  if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
2527  return success();
2528 
2529  auto numPrivateVars = privateVars.size();
2530  auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
2531 
2532  if (numPrivateVars != numPrivateSyms)
2533  return op.emitError() << "inconsistent number of private variables and "
2534  "privatizer op symbols, private vars: "
2535  << numPrivateVars
2536  << " vs. privatizer op symbols: " << numPrivateSyms;
2537 
2538  for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2539  Type varType = std::get<0>(privateVarInfo).getType();
2540  SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2541  PrivateClauseOp privatizerOp =
2542  SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
2543 
2544  if (privatizerOp == nullptr)
2545  return op.emitError() << "failed to lookup privatizer op with symbol: '"
2546  << privateSym << "'";
2547 
2548  Type privatizerType = privatizerOp.getArgType();
2549 
2550  if (privatizerType && (varType != privatizerType))
2551  return op.emitError()
2552  << "type mismatch between a "
2553  << (privatizerOp.getDataSharingType() ==
2554  DataSharingClauseType::Private
2555  ? "private"
2556  : "firstprivate")
2557  << " variable and its privatizer op, var type: " << varType
2558  << " vs. privatizer op type: " << privatizerType;
2559  }
2560 
2561  return success();
2562 }
2563 
2564 LogicalResult ParallelOp::verify() {
2565  if (getAllocateVars().size() != getAllocatorVars().size())
2566  return emitError(
2567  "expected equal sizes for allocate and allocator variables");
2568 
2569  if (failed(verifyPrivateVarList(*this)))
2570  return failure();
2571 
2572  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2573  getReductionByref());
2574 }
2575 
2576 LogicalResult ParallelOp::verifyRegions() {
2577  auto distChildOps = getOps<DistributeOp>();
2578  int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2579  if (numDistChildOps > 1)
2580  return emitError()
2581  << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2582 
2583  if (numDistChildOps == 1) {
2584  if (!isComposite())
2585  return emitError()
2586  << "'omp.composite' attribute missing from composite operation";
2587 
2588  auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2589  Operation &distributeOp = **distChildOps.begin();
2590  for (Operation &childOp : getOps()) {
2591  if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2592  continue;
2593 
2594  if (!childOp.hasTrait<OpTrait::IsTerminator>())
2595  return emitError() << "unexpected OpenMP operation inside of composite "
2596  "'omp.parallel': "
2597  << childOp.getName();
2598  }
2599  } else if (isComposite()) {
2600  return emitError()
2601  << "'omp.composite' attribute present in non-composite operation";
2602  }
2603  return success();
2604 }
2605 
2606 //===----------------------------------------------------------------------===//
2607 // TeamsOp
2608 //===----------------------------------------------------------------------===//
2609 
2611  while ((op = op->getParentOp()))
2612  if (isa<OpenMPDialect>(op->getDialect()))
2613  return false;
2614  return true;
2615 }
2616 
2617 void TeamsOp::build(OpBuilder &builder, OperationState &state,
2618  const TeamsOperands &clauses) {
2619  MLIRContext *ctx = builder.getContext();
2620  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2621  TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2622  clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2623  /*private_vars=*/{}, /*private_syms=*/nullptr,
2624  /*private_needs_barrier=*/nullptr, clauses.reductionMod,
2625  clauses.reductionVars,
2626  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2627  makeArrayAttr(ctx, clauses.reductionSyms),
2628  clauses.threadLimit);
2629 }
2630 
2631 LogicalResult TeamsOp::verify() {
2632  // Check parent region
2633  // TODO If nested inside of a target region, also check that it does not
2634  // contain any statements, declarations or directives other than this
2635  // omp.teams construct. The issue is how to support the initialization of
2636  // this operation's own arguments (allow SSA values across omp.target?).
2637  Operation *op = getOperation();
2638  if (!isa<TargetOp>(op->getParentOp()) &&
2640  return emitError("expected to be nested inside of omp.target or not nested "
2641  "in any OpenMP dialect operations");
2642 
2643  // Check for num_teams clause restrictions
2644  if (auto numTeamsLowerBound = getNumTeamsLower()) {
2645  auto numTeamsUpperBound = getNumTeamsUpper();
2646  if (!numTeamsUpperBound)
2647  return emitError("expected num_teams upper bound to be defined if the "
2648  "lower bound is defined");
2649  if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2650  return emitError(
2651  "expected num_teams upper bound and lower bound to be the same type");
2652  }
2653 
2654  // Check for allocate clause restrictions
2655  if (getAllocateVars().size() != getAllocatorVars().size())
2656  return emitError(
2657  "expected equal sizes for allocate and allocator variables");
2658 
2659  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2660  getReductionByref());
2661 }
2662 
2663 //===----------------------------------------------------------------------===//
2664 // SectionOp
2665 //===----------------------------------------------------------------------===//
2666 
2667 OperandRange SectionOp::getPrivateVars() {
2668  return getParentOp().getPrivateVars();
2669 }
2670 
2671 OperandRange SectionOp::getReductionVars() {
2672  return getParentOp().getReductionVars();
2673 }
2674 
2675 //===----------------------------------------------------------------------===//
2676 // SectionsOp
2677 //===----------------------------------------------------------------------===//
2678 
2679 void SectionsOp::build(OpBuilder &builder, OperationState &state,
2680  const SectionsOperands &clauses) {
2681  MLIRContext *ctx = builder.getContext();
2682  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2683  SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2684  clauses.nowait, /*private_vars=*/{},
2685  /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2686  clauses.reductionMod, clauses.reductionVars,
2687  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2688  makeArrayAttr(ctx, clauses.reductionSyms));
2689 }
2690 
2691 LogicalResult SectionsOp::verify() {
2692  if (getAllocateVars().size() != getAllocatorVars().size())
2693  return emitError(
2694  "expected equal sizes for allocate and allocator variables");
2695 
2696  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2697  getReductionByref());
2698 }
2699 
2700 LogicalResult SectionsOp::verifyRegions() {
2701  for (auto &inst : *getRegion().begin()) {
2702  if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2703  return emitOpError()
2704  << "expected omp.section op or terminator op inside region";
2705  }
2706  }
2707 
2708  return success();
2709 }
2710 
2711 //===----------------------------------------------------------------------===//
2712 // SingleOp
2713 //===----------------------------------------------------------------------===//
2714 
2715 void SingleOp::build(OpBuilder &builder, OperationState &state,
2716  const SingleOperands &clauses) {
2717  MLIRContext *ctx = builder.getContext();
2718  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2719  SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2720  clauses.copyprivateVars,
2721  makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2722  /*private_vars=*/{}, /*private_syms=*/nullptr,
2723  /*private_needs_barrier=*/nullptr);
2724 }
2725 
2726 LogicalResult SingleOp::verify() {
2727  // Check for allocate clause restrictions
2728  if (getAllocateVars().size() != getAllocatorVars().size())
2729  return emitError(
2730  "expected equal sizes for allocate and allocator variables");
2731 
2732  return verifyCopyprivateVarList(*this, getCopyprivateVars(),
2733  getCopyprivateSyms());
2734 }
2735 
2736 //===----------------------------------------------------------------------===//
2737 // WorkshareOp
2738 //===----------------------------------------------------------------------===//
2739 
2740 void WorkshareOp::build(OpBuilder &builder, OperationState &state,
2741  const WorkshareOperands &clauses) {
2742  WorkshareOp::build(builder, state, clauses.nowait);
2743 }
2744 
2745 //===----------------------------------------------------------------------===//
2746 // WorkshareLoopWrapperOp
2747 //===----------------------------------------------------------------------===//
2748 
2749 LogicalResult WorkshareLoopWrapperOp::verify() {
2750  if (!(*this)->getParentOfType<WorkshareOp>())
2751  return emitOpError() << "must be nested in an omp.workshare";
2752  return success();
2753 }
2754 
2755 LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2756  if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2757  getNestedWrapper())
2758  return emitOpError() << "expected to be a standalone loop wrapper";
2759 
2760  return success();
2761 }
2762 
2763 //===----------------------------------------------------------------------===//
2764 // LoopWrapperInterface
2765 //===----------------------------------------------------------------------===//
2766 
2767 LogicalResult LoopWrapperInterface::verifyImpl() {
2768  Operation *op = this->getOperation();
2769  if (!op->hasTrait<OpTrait::NoTerminator>() ||
2771  return emitOpError() << "loop wrapper must also have the `NoTerminator` "
2772  "and `SingleBlock` traits";
2773 
2774  if (op->getNumRegions() != 1)
2775  return emitOpError() << "loop wrapper does not contain exactly one region";
2776 
2777  Region &region = op->getRegion(0);
2778  if (range_size(region.getOps()) != 1)
2779  return emitOpError()
2780  << "loop wrapper does not contain exactly one nested op";
2781 
2782  Operation &firstOp = *region.op_begin();
2783  if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2784  return emitOpError() << "nested in loop wrapper is not another loop "
2785  "wrapper or `omp.loop_nest`";
2786 
2787  return success();
2788 }
2789 
2790 //===----------------------------------------------------------------------===//
2791 // LoopOp
2792 //===----------------------------------------------------------------------===//
2793 
2794 void LoopOp::build(OpBuilder &builder, OperationState &state,
2795  const LoopOperands &clauses) {
2796  MLIRContext *ctx = builder.getContext();
2797 
2798  LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2799  makeArrayAttr(ctx, clauses.privateSyms),
2800  clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2801  clauses.reductionMod, clauses.reductionVars,
2802  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2803  makeArrayAttr(ctx, clauses.reductionSyms));
2804 }
2805 
2806 LogicalResult LoopOp::verify() {
2807  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2808  getReductionByref());
2809 }
2810 
2811 LogicalResult LoopOp::verifyRegions() {
2812  if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2813  getNestedWrapper())
2814  return emitOpError() << "expected to be a standalone loop wrapper";
2815 
2816  return success();
2817 }
2818 
2819 //===----------------------------------------------------------------------===//
2820 // WsloopOp
2821 //===----------------------------------------------------------------------===//
2822 
2823 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2824  ArrayRef<NamedAttribute> attributes) {
2825  build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2826  /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
2827  /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
2828  /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
2829  /*private_needs_barrier=*/false,
2830  /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
2831  /*reduction_byref=*/nullptr,
2832  /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
2833  /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
2834  /*schedule_simd=*/false);
2835  state.addAttributes(attributes);
2836 }
2837 
2838 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2839  const WsloopOperands &clauses) {
2840  MLIRContext *ctx = builder.getContext();
2841  // TODO: Store clauses in op: allocateVars, allocatorVars
2842  WsloopOp::build(
2843  builder, state,
2844  /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
2845  clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
2846  clauses.ordered, clauses.privateVars,
2847  makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2848  clauses.reductionMod, clauses.reductionVars,
2849  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2850  makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
2851  clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
2852 }
2853 
2854 LogicalResult WsloopOp::verify() {
2855  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2856  getReductionByref());
2857 }
2858 
2859 LogicalResult WsloopOp::verifyRegions() {
2860  bool isCompositeChildLeaf =
2861  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2862 
2863  if (LoopWrapperInterface nested = getNestedWrapper()) {
2864  if (!isComposite())
2865  return emitError()
2866  << "'omp.composite' attribute missing from composite wrapper";
2867 
2868  // Check for the allowed leaf constructs that may appear in a composite
2869  // construct directly after DO/FOR.
2870  if (!isa<SimdOp>(nested))
2871  return emitError() << "only supported nested wrapper is 'omp.simd'";
2872 
2873  } else if (isComposite() && !isCompositeChildLeaf) {
2874  return emitError()
2875  << "'omp.composite' attribute present in non-composite wrapper";
2876  } else if (!isComposite() && isCompositeChildLeaf) {
2877  return emitError()
2878  << "'omp.composite' attribute missing from composite wrapper";
2879  }
2880 
2881  return success();
2882 }
2883 
2884 //===----------------------------------------------------------------------===//
2885 // Simd construct [2.9.3.1]
2886 //===----------------------------------------------------------------------===//
2887 
2888 void SimdOp::build(OpBuilder &builder, OperationState &state,
2889  const SimdOperands &clauses) {
2890  MLIRContext *ctx = builder.getContext();
2891  // TODO Store clauses in op: linearVars, linearStepVars
2892  SimdOp::build(builder, state, clauses.alignedVars,
2893  makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
2894  /*linear_vars=*/{}, /*linear_step_vars=*/{},
2895  clauses.nontemporalVars, clauses.order, clauses.orderMod,
2896  clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
2897  clauses.privateNeedsBarrier, clauses.reductionMod,
2898  clauses.reductionVars,
2899  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2900  makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
2901  clauses.simdlen);
2902 }
2903 
2904 LogicalResult SimdOp::verify() {
2905  if (getSimdlen().has_value() && getSafelen().has_value() &&
2906  getSimdlen().value() > getSafelen().value())
2907  return emitOpError()
2908  << "simdlen clause and safelen clause are both present, but the "
2909  "simdlen value is not less than or equal to safelen value";
2910 
2911  if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
2912  return failure();
2913 
2914  if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
2915  return failure();
2916 
2917  bool isCompositeChildLeaf =
2918  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2919 
2920  if (!isComposite() && isCompositeChildLeaf)
2921  return emitError()
2922  << "'omp.composite' attribute missing from composite wrapper";
2923 
2924  if (isComposite() && !isCompositeChildLeaf)
2925  return emitError()
2926  << "'omp.composite' attribute present in non-composite wrapper";
2927 
2928  // Firstprivate is not allowed for SIMD in the standard. Check that none of
2929  // the private decls are for firstprivate.
2930  std::optional<ArrayAttr> privateSyms = getPrivateSyms();
2931  if (privateSyms) {
2932  for (const Attribute &sym : *privateSyms) {
2933  auto symRef = cast<SymbolRefAttr>(sym);
2934  omp::PrivateClauseOp privatizer =
2935  SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
2936  getOperation(), symRef);
2937  if (!privatizer)
2938  return emitError() << "Cannot find privatizer '" << symRef << "'";
2939  if (privatizer.getDataSharingType() ==
2940  DataSharingClauseType::FirstPrivate)
2941  return emitError() << "FIRSTPRIVATE cannot be used with SIMD";
2942  }
2943  }
2944 
2945  return success();
2946 }
2947 
2948 LogicalResult SimdOp::verifyRegions() {
2949  if (getNestedWrapper())
2950  return emitOpError() << "must wrap an 'omp.loop_nest' directly";
2951 
2952  return success();
2953 }
2954 
2955 //===----------------------------------------------------------------------===//
2956 // Distribute construct [2.9.4.1]
2957 //===----------------------------------------------------------------------===//
2958 
2959 void DistributeOp::build(OpBuilder &builder, OperationState &state,
2960  const DistributeOperands &clauses) {
2961  DistributeOp::build(builder, state, clauses.allocateVars,
2962  clauses.allocatorVars, clauses.distScheduleStatic,
2963  clauses.distScheduleChunkSize, clauses.order,
2964  clauses.orderMod, clauses.privateVars,
2965  makeArrayAttr(builder.getContext(), clauses.privateSyms),
2966  clauses.privateNeedsBarrier);
2967 }
2968 
2969 LogicalResult DistributeOp::verify() {
2970  if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2971  return emitOpError() << "chunk size set without "
2972  "dist_schedule_static being present";
2973 
2974  if (getAllocateVars().size() != getAllocatorVars().size())
2975  return emitError(
2976  "expected equal sizes for allocate and allocator variables");
2977 
2978  return success();
2979 }
2980 
2981 LogicalResult DistributeOp::verifyRegions() {
2982  if (LoopWrapperInterface nested = getNestedWrapper()) {
2983  if (!isComposite())
2984  return emitError()
2985  << "'omp.composite' attribute missing from composite wrapper";
2986  // Check for the allowed leaf constructs that may appear in a composite
2987  // construct directly after DISTRIBUTE.
2988  if (isa<WsloopOp>(nested)) {
2989  Operation *parentOp = (*this)->getParentOp();
2990  if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
2991  !cast<ComposableOpInterface>(parentOp).isComposite()) {
2992  return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
2993  "when a composite 'omp.parallel' is the direct "
2994  "parent";
2995  }
2996  } else if (!isa<SimdOp>(nested))
2997  return emitError() << "only supported nested wrappers are 'omp.simd' and "
2998  "'omp.wsloop'";
2999  } else if (isComposite()) {
3000  return emitError()
3001  << "'omp.composite' attribute present in non-composite wrapper";
3002  }
3003 
3004  return success();
3005 }
3006 
3007 //===----------------------------------------------------------------------===//
3008 // DeclareMapperOp / DeclareMapperInfoOp
3009 //===----------------------------------------------------------------------===//
3010 
3011 LogicalResult DeclareMapperInfoOp::verify() {
3012  return verifyMapClause(*this, getMapVars());
3013 }
3014 
3015 LogicalResult DeclareMapperOp::verifyRegions() {
3016  if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3017  getRegion().getBlocks().front().getTerminator()))
3018  return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
3019 
3020  return success();
3021 }
3022 
3023 //===----------------------------------------------------------------------===//
3024 // DeclareReductionOp
3025 //===----------------------------------------------------------------------===//
3026 
3027 LogicalResult DeclareReductionOp::verifyRegions() {
3028  if (!getAllocRegion().empty()) {
3029  for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3030  if (yieldOp.getResults().size() != 1 ||
3031  yieldOp.getResults().getTypes()[0] != getType())
3032  return emitOpError() << "expects alloc region to yield a value "
3033  "of the reduction type";
3034  }
3035  }
3036 
3037  if (getInitializerRegion().empty())
3038  return emitOpError() << "expects non-empty initializer region";
3039  Block &initializerEntryBlock = getInitializerRegion().front();
3040 
3041  if (initializerEntryBlock.getNumArguments() == 1) {
3042  if (!getAllocRegion().empty())
3043  return emitOpError() << "expects two arguments to the initializer region "
3044  "when an allocation region is used";
3045  } else if (initializerEntryBlock.getNumArguments() == 2) {
3046  if (getAllocRegion().empty())
3047  return emitOpError() << "expects one argument to the initializer region "
3048  "when no allocation region is used";
3049  } else {
3050  return emitOpError()
3051  << "expects one or two arguments to the initializer region";
3052  }
3053 
3054  for (mlir::Value arg : initializerEntryBlock.getArguments())
3055  if (arg.getType() != getType())
3056  return emitOpError() << "expects initializer region argument to match "
3057  "the reduction type";
3058 
3059  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3060  if (yieldOp.getResults().size() != 1 ||
3061  yieldOp.getResults().getTypes()[0] != getType())
3062  return emitOpError() << "expects initializer region to yield a value "
3063  "of the reduction type";
3064  }
3065 
3066  if (getReductionRegion().empty())
3067  return emitOpError() << "expects non-empty reduction region";
3068  Block &reductionEntryBlock = getReductionRegion().front();
3069  if (reductionEntryBlock.getNumArguments() != 2 ||
3070  reductionEntryBlock.getArgumentTypes()[0] !=
3071  reductionEntryBlock.getArgumentTypes()[1] ||
3072  reductionEntryBlock.getArgumentTypes()[0] != getType())
3073  return emitOpError() << "expects reduction region with two arguments of "
3074  "the reduction type";
3075  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3076  if (yieldOp.getResults().size() != 1 ||
3077  yieldOp.getResults().getTypes()[0] != getType())
3078  return emitOpError() << "expects reduction region to yield a value "
3079  "of the reduction type";
3080  }
3081 
3082  if (!getAtomicReductionRegion().empty()) {
3083  Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
3084  if (atomicReductionEntryBlock.getNumArguments() != 2 ||
3085  atomicReductionEntryBlock.getArgumentTypes()[0] !=
3086  atomicReductionEntryBlock.getArgumentTypes()[1])
3087  return emitOpError() << "expects atomic reduction region with two "
3088  "arguments of the same type";
3089  auto ptrType = llvm::dyn_cast<PointerLikeType>(
3090  atomicReductionEntryBlock.getArgumentTypes()[0]);
3091  if (!ptrType ||
3092  (ptrType.getElementType() && ptrType.getElementType() != getType()))
3093  return emitOpError() << "expects atomic reduction region arguments to "
3094  "be accumulators containing the reduction type";
3095  }
3096 
3097  if (getCleanupRegion().empty())
3098  return success();
3099  Block &cleanupEntryBlock = getCleanupRegion().front();
3100  if (cleanupEntryBlock.getNumArguments() != 1 ||
3101  cleanupEntryBlock.getArgument(0).getType() != getType())
3102  return emitOpError() << "expects cleanup region with one argument "
3103  "of the reduction type";
3104 
3105  return success();
3106 }
3107 
3108 //===----------------------------------------------------------------------===//
3109 // TaskOp
3110 //===----------------------------------------------------------------------===//
3111 
3112 void TaskOp::build(OpBuilder &builder, OperationState &state,
3113  const TaskOperands &clauses) {
3114  MLIRContext *ctx = builder.getContext();
3115  TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3116  makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3117  clauses.final, clauses.ifExpr, clauses.inReductionVars,
3118  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3119  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3120  clauses.priority, /*private_vars=*/clauses.privateVars,
3121  /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3122  clauses.privateNeedsBarrier, clauses.untied,
3123  clauses.eventHandle);
3124 }
3125 
3126 LogicalResult TaskOp::verify() {
3127  LogicalResult verifyDependVars =
3128  verifyDependVarList(*this, getDependKinds(), getDependVars());
3129  return failed(verifyDependVars)
3130  ? verifyDependVars
3131  : verifyReductionVarList(*this, getInReductionSyms(),
3132  getInReductionVars(),
3133  getInReductionByref());
3134 }
3135 
3136 //===----------------------------------------------------------------------===//
3137 // TaskgroupOp
3138 //===----------------------------------------------------------------------===//
3139 
3140 void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
3141  const TaskgroupOperands &clauses) {
3142  MLIRContext *ctx = builder.getContext();
3143  TaskgroupOp::build(builder, state, clauses.allocateVars,
3144  clauses.allocatorVars, clauses.taskReductionVars,
3145  makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
3146  makeArrayAttr(ctx, clauses.taskReductionSyms));
3147 }
3148 
3149 LogicalResult TaskgroupOp::verify() {
3150  return verifyReductionVarList(*this, getTaskReductionSyms(),
3151  getTaskReductionVars(),
3152  getTaskReductionByref());
3153 }
3154 
3155 //===----------------------------------------------------------------------===//
3156 // TaskloopOp
3157 //===----------------------------------------------------------------------===//
3158 
3159 void TaskloopOp::build(OpBuilder &builder, OperationState &state,
3160  const TaskloopOperands &clauses) {
3161  MLIRContext *ctx = builder.getContext();
3162  TaskloopOp::build(
3163  builder, state, clauses.allocateVars, clauses.allocatorVars,
3164  clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3165  clauses.inReductionVars,
3166  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3167  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3168  clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3169  /*private_vars=*/clauses.privateVars,
3170  /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3171  clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3172  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3173  makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
3174 }
3175 
3176 LogicalResult TaskloopOp::verify() {
3177  if (getAllocateVars().size() != getAllocatorVars().size())
3178  return emitError(
3179  "expected equal sizes for allocate and allocator variables");
3180  if (failed(verifyReductionVarList(*this, getReductionSyms(),
3181  getReductionVars(), getReductionByref())) ||
3182  failed(verifyReductionVarList(*this, getInReductionSyms(),
3183  getInReductionVars(),
3184  getInReductionByref())))
3185  return failure();
3186 
3187  if (!getReductionVars().empty() && getNogroup())
3188  return emitError("if a reduction clause is present on the taskloop "
3189  "directive, the nogroup clause must not be specified");
3190  for (auto var : getReductionVars()) {
3191  if (llvm::is_contained(getInReductionVars(), var))
3192  return emitError("the same list item cannot appear in both a reduction "
3193  "and an in_reduction clause");
3194  }
3195 
3196  if (getGrainsize() && getNumTasks()) {
3197  return emitError(
3198  "the grainsize clause and num_tasks clause are mutually exclusive and "
3199  "may not appear on the same taskloop directive");
3200  }
3201 
3202  return success();
3203 }
3204 
3205 LogicalResult TaskloopOp::verifyRegions() {
3206  if (LoopWrapperInterface nested = getNestedWrapper()) {
3207  if (!isComposite())
3208  return emitError()
3209  << "'omp.composite' attribute missing from composite wrapper";
3210 
3211  // Check for the allowed leaf constructs that may appear in a composite
3212  // construct directly after TASKLOOP.
3213  if (!isa<SimdOp>(nested))
3214  return emitError() << "only supported nested wrapper is 'omp.simd'";
3215  } else if (isComposite()) {
3216  return emitError()
3217  << "'omp.composite' attribute present in non-composite wrapper";
3218  }
3219 
3220  return success();
3221 }
3222 
3223 //===----------------------------------------------------------------------===//
3224 // LoopNestOp
3225 //===----------------------------------------------------------------------===//
3226 
3227 ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
3228  // Parse an opening `(` followed by induction variables followed by `)`
3231  Type loopVarType;
3232  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
3233  parser.parseColonType(loopVarType) ||
3234  // Parse loop bounds.
3235  parser.parseEqual() ||
3236  parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
3237  parser.parseKeyword("to") ||
3238  parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
3239  return failure();
3240 
3241  for (auto &iv : ivs)
3242  iv.type = loopVarType;
3243 
3244  auto *ctx = parser.getBuilder().getContext();
3245  // Parse "inclusive" flag.
3246  if (succeeded(parser.parseOptionalKeyword("inclusive")))
3247  result.addAttribute("loop_inclusive", UnitAttr::get(ctx));
3248 
3249  // Parse step values.
3251  if (parser.parseKeyword("step") ||
3252  parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
3253  return failure();
3254 
3255  // Parse collapse
3256  int64_t value = 0;
3257  if (!parser.parseOptionalKeyword("collapse") &&
3258  (parser.parseLParen() || parser.parseInteger(value) ||
3259  parser.parseRParen()))
3260  return failure();
3261  if (value > 1)
3262  result.addAttribute(
3263  "collapse_num_loops",
3264  IntegerAttr::get(parser.getBuilder().getI64Type(), value));
3265 
3266  // Parse tiles
3267  SmallVector<int64_t> tiles;
3268  auto parseTiles = [&]() -> ParseResult {
3269  int64_t tile;
3270  if (parser.parseInteger(tile))
3271  return failure();
3272  tiles.push_back(tile);
3273  return success();
3274  };
3275 
3276  if (!parser.parseOptionalKeyword("tiles") &&
3277  (parser.parseLParen() || parser.parseCommaSeparatedList(parseTiles) ||
3278  parser.parseRParen()))
3279  return failure();
3280 
3281  if (tiles.size() > 0)
3282  result.addAttribute("tile_sizes", DenseI64ArrayAttr::get(ctx, tiles));
3283 
3284  // Parse the body.
3285  Region *region = result.addRegion();
3286  if (parser.parseRegion(*region, ivs))
3287  return failure();
3288 
3289  // Resolve operands.
3290  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
3291  parser.resolveOperands(ubs, loopVarType, result.operands) ||
3292  parser.resolveOperands(steps, loopVarType, result.operands))
3293  return failure();
3294 
3295  // Parse the optional attribute list.
3296  return parser.parseOptionalAttrDict(result.attributes);
3297 }
3298 
3300  Region &region = getRegion();
3301  auto args = region.getArguments();
3302  p << " (" << args << ") : " << args[0].getType() << " = ("
3303  << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
3304  if (getLoopInclusive())
3305  p << "inclusive ";
3306  p << "step (" << getLoopSteps() << ") ";
3307  if (int64_t numCollapse = getCollapseNumLoops())
3308  if (numCollapse > 1)
3309  p << "collapse(" << numCollapse << ") ";
3310 
3311  if (const auto tiles = getTileSizes())
3312  p << "tiles(" << tiles.value() << ") ";
3313 
3314  p.printRegion(region, /*printEntryBlockArgs=*/false);
3315 }
3316 
3317 void LoopNestOp::build(OpBuilder &builder, OperationState &state,
3318  const LoopNestOperands &clauses) {
3319  MLIRContext *ctx = builder.getContext();
3320  LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3321  clauses.loopLowerBounds, clauses.loopUpperBounds,
3322  clauses.loopSteps, clauses.loopInclusive,
3323  makeDenseI64ArrayAttr(ctx, clauses.tileSizes));
3324 }
3325 
3326 LogicalResult LoopNestOp::verify() {
3327  if (getLoopLowerBounds().empty())
3328  return emitOpError() << "must represent at least one loop";
3329 
3330  if (getLoopLowerBounds().size() != getIVs().size())
3331  return emitOpError() << "number of range arguments and IVs do not match";
3332 
3333  for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3334  if (lb.getType() != iv.getType())
3335  return emitOpError()
3336  << "range argument type does not match corresponding IV type";
3337  }
3338 
3339  uint64_t numIVs = getIVs().size();
3340 
3341  if (const auto &numCollapse = getCollapseNumLoops())
3342  if (numCollapse > numIVs)
3343  return emitOpError()
3344  << "collapse value is larger than the number of loops";
3345 
3346  if (const auto &tiles = getTileSizes())
3347  if (tiles.value().size() > numIVs)
3348  return emitOpError() << "too few canonical loops for tile dimensions";
3349 
3350  if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3351  return emitOpError() << "expects parent op to be a loop wrapper";
3352 
3353  return success();
3354 }
3355 
3356 void LoopNestOp::gatherWrappers(
3358  Operation *parent = (*this)->getParentOp();
3359  while (auto wrapper =
3360  llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3361  wrappers.push_back(wrapper);
3362  parent = parent->getParentOp();
3363  }
3364 }
3365 
3366 //===----------------------------------------------------------------------===//
3367 // OpenMP canonical loop handling
3368 //===----------------------------------------------------------------------===//
3369 
3370 std::tuple<NewCliOp, OpOperand *, OpOperand *>
3372 
3373  // Defining a CLI for a generated loop is optional; if there is none then
3374  // there is no followup-tranformation
3375  if (!cli)
3376  return {{}, nullptr, nullptr};
3377 
3378  assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3379  "Unexpected type of cli");
3380 
3381  NewCliOp create = cast<NewCliOp>(cli.getDefiningOp());
3382  OpOperand *gen = nullptr;
3383  OpOperand *cons = nullptr;
3384  for (OpOperand &use : cli.getUses()) {
3385  auto op = cast<LoopTransformationInterface>(use.getOwner());
3386 
3387  unsigned opnum = use.getOperandNumber();
3388  if (op.isGeneratee(opnum)) {
3389  assert(!gen && "Each CLI may have at most one def");
3390  gen = &use;
3391  } else if (op.isApplyee(opnum)) {
3392  assert(!cons && "Each CLI may have at most one consumer");
3393  cons = &use;
3394  } else {
3395  llvm_unreachable("Unexpected operand for a CLI");
3396  }
3397  }
3398 
3399  return {create, gen, cons};
3400 }
3401 
3402 void NewCliOp::build(::mlir::OpBuilder &odsBuilder,
3403  ::mlir::OperationState &odsState) {
3404  odsState.addTypes(CanonicalLoopInfoType::get(odsBuilder.getContext()));
3405 }
3406 
3407 void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3408  Value result = getResult();
3409  auto [newCli, gen, cons] = decodeCli(result);
3410 
3411  // Structured binding `gen` cannot be captured in lambdas before C++20
3412  OpOperand *generator = gen;
3413 
3414  // Derive the CLI variable name from its generator:
3415  // * "canonloop" for omp.canonical_loop
3416  // * custom name for loop transformation generatees
3417  // * "cli" as fallback if no generator
3418  // * "_r<idx>" suffix for nested loops, where <idx> is the sequential order
3419  // at that level
3420  // * "_s<idx>" suffix for operations with multiple regions, where <idx> is
3421  // the index of that region
3422  std::string cliName{"cli"};
3423  if (gen) {
3424  cliName =
3425  TypeSwitch<Operation *, std::string>(gen->getOwner())
3426  .Case([&](CanonicalLoopOp op) {
3427  return generateLoopNestingName("canonloop", op);
3428  })
3429  .Case([&](UnrollHeuristicOp op) -> std::string {
3430  llvm_unreachable("heuristic unrolling does not generate a loop");
3431  })
3432  .Case([&](TileOp op) -> std::string {
3433  auto [generateesFirst, generateesCount] =
3434  op.getGenerateesODSOperandIndexAndLength();
3435  unsigned firstGrid = generateesFirst;
3436  unsigned firstIntratile = generateesFirst + generateesCount / 2;
3437  unsigned end = generateesFirst + generateesCount;
3438  unsigned opnum = generator->getOperandNumber();
3439  // In the OpenMP apply and looprange clauses, indices are 1-based
3440  if (firstGrid <= opnum && opnum < firstIntratile) {
3441  unsigned gridnum = opnum - firstGrid + 1;
3442  return ("grid" + Twine(gridnum)).str();
3443  }
3444  if (firstIntratile <= opnum && opnum < end) {
3445  unsigned intratilenum = opnum - firstIntratile + 1;
3446  return ("intratile" + Twine(intratilenum)).str();
3447  }
3448  llvm_unreachable("Unexpected generatee argument");
3449  })
3450  .DefaultUnreachable("TODO: Custom name for this operation");
3451  }
3452 
3453  setNameFn(result, cliName);
3454 }
3455 
3456 LogicalResult NewCliOp::verify() {
3457  Value cli = getResult();
3458 
3459  assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3460  "Unexpected type of cli");
3461 
3462  // Check that the CLI is used in at most generator and one consumer
3463  OpOperand *gen = nullptr;
3464  OpOperand *cons = nullptr;
3465  for (mlir::OpOperand &use : cli.getUses()) {
3466  auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3467 
3468  unsigned opnum = use.getOperandNumber();
3469  if (op.isGeneratee(opnum)) {
3470  if (gen) {
3471  InFlightDiagnostic error =
3472  emitOpError("CLI must have at most one generator");
3473  error.attachNote(gen->getOwner()->getLoc())
3474  .append("first generator here:");
3475  error.attachNote(use.getOwner()->getLoc())
3476  .append("second generator here:");
3477  return error;
3478  }
3479 
3480  gen = &use;
3481  } else if (op.isApplyee(opnum)) {
3482  if (cons) {
3483  InFlightDiagnostic error =
3484  emitOpError("CLI must have at most one consumer");
3485  error.attachNote(cons->getOwner()->getLoc())
3486  .append("first consumer here:")
3487  .appendOp(*cons->getOwner(),
3488  OpPrintingFlags().printGenericOpForm());
3489  error.attachNote(use.getOwner()->getLoc())
3490  .append("second consumer here:")
3491  .appendOp(*use.getOwner(), OpPrintingFlags().printGenericOpForm());
3492  return error;
3493  }
3494 
3495  cons = &use;
3496  } else {
3497  llvm_unreachable("Unexpected operand for a CLI");
3498  }
3499  }
3500 
3501  // If the CLI is source of a transformation, it must have a generator
3502  if (cons && !gen) {
3503  InFlightDiagnostic error = emitOpError("CLI has no generator");
3504  error.attachNote(cons->getOwner()->getLoc())
3505  .append("see consumer here: ")
3506  .appendOp(*cons->getOwner(), OpPrintingFlags().printGenericOpForm());
3507  return error;
3508  }
3509 
3510  return success();
3511 }
3512 
3513 void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3514  Value tripCount) {
3515  odsState.addOperands(tripCount);
3516  odsState.addOperands(Value());
3517  (void)odsState.addRegion();
3518 }
3519 
3520 void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3521  Value tripCount, ::mlir::Value cli) {
3522  odsState.addOperands(tripCount);
3523  odsState.addOperands(cli);
3524  (void)odsState.addRegion();
3525 }
3526 
3527 void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
3528  setNameFn(&getRegion().front(), "body_entry");
3529 }
3530 
3531 void CanonicalLoopOp::getAsmBlockArgumentNames(Region &region,
3532  OpAsmSetValueNameFn setNameFn) {
3533  std::string ivName = generateLoopNestingName("iv", *this);
3534  setNameFn(region.getArgument(0), ivName);
3535 }
3536 
// Custom printer for omp.canonical_loop. Emits, in order: the optional CLI in
// parentheses, the induction variable with its type, `in range(<tripcount>)`,
// the body region (without entry block arguments, with terminators), and the
// op's attribute dictionary.
// NOTE(review): the signature line is missing from this listing; `p` is
// presumably the OpAsmPrinter parameter of CanonicalLoopOp::print.
3538  if (getCli())
3539  p << '(' << getCli() << ')';
3540  p << ' ' << getInductionVar() << " : " << getInductionVar().getType()
3541  << " in range(" << getTripCount() << ") ";
3542 
3543  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3544  /*printBlockTerminators=*/true);
3545 
3546  p.printOptionalAttrDict((*this)->getAttrs());
3547 }
3548 
// Custom parser for omp.canonical_loop, the inverse of the printer:
//   [`(` cli `)`] iv `:` type `in` `range` `(` tripcount `)` region [attr-dict]
// The trip count operand is resolved with the induction variable's type and
// added first; the optional CLI operand is appended after it.
// NOTE(review): several declaration lines (e.g. for `cli`, `tripcount` and the
// initializer of `cliType`) are missing from this listing.
3549 mlir::ParseResult CanonicalLoopOp::parse(::mlir::OpAsmParser &parser,
3550  ::mlir::OperationState &result) {
3551  CanonicalLoopInfoType cliType =
3553 
3554  // Parse (optional) omp.cli identifier
3556  SmallVector<mlir::Value, 1> cliOperand;
3557  if (!parser.parseOptionalLParen()) {
3558  if (parser.parseOperand(cli) ||
3559  parser.resolveOperand(cli, cliType, cliOperand) || parser.parseRParen())
3560  return failure();
3561  }
3562 
3563  // We derive the type of tripCount from inductionVariable. MLIR requires the
3564  // type of tripCount to be known when calling resolveOperand, so we have to
3565  // parse the type before processing the inductionVariable.
3566  OpAsmParser::Argument inductionVariable;
3568  if (parser.parseArgument(inductionVariable, /*allowType*/ true) ||
3569  parser.parseKeyword("in") || parser.parseKeyword("range") ||
3570  parser.parseLParen() || parser.parseOperand(tripcount) ||
3571  parser.parseRParen() ||
3572  parser.resolveOperand(tripcount, inductionVariable.type, result.operands))
3573  return failure();
3574 
3575  // Parse the loop body.
3576  Region *region = result.addRegion();
3577  if (parser.parseRegion(*region, {inductionVariable}))
3578  return failure();
3579 
3580  // We parsed the cli operand first, but because it is optional, it must be
3581  // last in the operand list.
3582  result.operands.append(cliOperand);
3583 
3584  // Parse the optional attribute list.
3585  if (parser.parseOptionalAttrDict(result.attributes))
3586  return failure();
3587 
3588  return mlir::success();
3589 }
3590 
3591 LogicalResult CanonicalLoopOp::verify() {
3592  // The region's entry must accept the induction variable
3593  // It can also be empty if just created
3594  if (!getRegion().empty()) {
3595  Region &region = getRegion();
3596  if (region.getNumArguments() != 1)
3597  return emitOpError(
3598  "Canonical loop region must have exactly one argument");
3599 
3600  if (getInductionVar().getType() != getTripCount().getType())
3601  return emitOpError(
3602  "Region argument must be the same type as the trip count");
3603  }
3604 
3605  return success();
3606 }
3607 
3608 Value CanonicalLoopOp::getInductionVar() { return getRegion().getArgument(0); }
3609 
3610 std::pair<unsigned, unsigned>
3611 CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3612  // No applyees
3613  return {0, 0};
3614 }
3615 
3616 std::pair<unsigned, unsigned>
3617 CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3618  return getODSOperandIndexAndLength(odsIndex_cli);
3619 }
3620 
3621 //===----------------------------------------------------------------------===//
3622 // UnrollHeuristicOp
3623 //===----------------------------------------------------------------------===//
3624 
3625 void UnrollHeuristicOp::build(::mlir::OpBuilder &odsBuilder,
3626  ::mlir::OperationState &odsState,
3627  ::mlir::Value cli) {
3628  odsState.addOperands(cli);
3629 }
3630 
// Custom printer for omp.unroll_heuristic: prints the applyee CLI in
// parentheses followed by the op's attribute dictionary.
// NOTE(review): the signature line is missing from this listing; `p` is
// presumably the OpAsmPrinter parameter of UnrollHeuristicOp::print.
3632  p << '(' << getApplyee() << ')';
3633 
3634  p.printOptionalAttrDict((*this)->getAttrs());
3635 }
3636 
// Custom parser for omp.unroll_heuristic:
//   `(` applyee `)` [`->` `(` `)`] [attr-dict]
// The applyee is resolved as a canonical loop info value and becomes the only
// operand. The optional arrow may only be followed by an empty pair of
// parentheses.
// NOTE(review): the declaration of `applyee` is missing from this listing.
3637 mlir::ParseResult UnrollHeuristicOp::parse(::mlir::OpAsmParser &parser,
3638  ::mlir::OperationState &result) {
3639  auto cliType = CanonicalLoopInfoType::get(parser.getContext());
3640 
3641  if (parser.parseLParen())
3642  return failure();
3643 
3645  if (parser.parseOperand(applyee) ||
3646  parser.resolveOperand(applyee, cliType, result.operands))
3647  return failure();
3648 
3649  if (parser.parseRParen())
3650  return failure();
3651 
3652  // Optional output loop (full unrolling has none)
3653  if (!parser.parseOptionalArrow()) {
3654  if (parser.parseLParen() || parser.parseRParen())
3655  return failure();
3656  }
3657 
3658  // Parse the optional attribute list.
3659  if (parser.parseOptionalAttrDict(result.attributes))
3660  return failure();
3661 
3662  return mlir::success();
3663 }
3664 
3665 std::pair<unsigned, unsigned>
3666 UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3667  return getODSOperandIndexAndLength(odsIndex_applyee);
3668 }
3669 
3670 std::pair<unsigned, unsigned>
3671 UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3672  return {0, 0};
3673 }
3674 
3675 //===----------------------------------------------------------------------===//
3676 // TileOp
3677 //===----------------------------------------------------------------------===//
3678 
3679 static void printLoopTransformClis(OpAsmPrinter &p, TileOp op,
3680  OperandRange generatees,
3681  OperandRange applyees) {
3682  if (!generatees.empty())
3683  p << '(' << llvm::interleaved(generatees) << ')';
3684 
3685  if (!applyees.empty())
3686  p << " <- (" << llvm::interleaved(applyees) << ')';
3687 }
3688 
// Parses the CLI lists of a loop transformation, the inverse of
// printLoopTransformClis:
//   [generatee-list] `<-` applyee-list
// parseOptionalLess() reports failure when the next token is not `<`; in that
// case a generatee list comes first and the `<` is parsed afterwards.
// NOTE(review): the output parameter declarations and the trailing arguments
// of both parseOperandList calls are missing from this listing.
3689 static ParseResult parseLoopTransformClis(
3690  OpAsmParser &parser,
3693  if (parser.parseOptionalLess()) {
3694  // Syntax 1: generatees present
3695 
3696  if (parser.parseOperandList(generateesOperands,
3698  return failure();
3699 
3700  if (parser.parseLess())
3701  return failure();
3702  } else {
3703  // Syntax 2: generatees omitted
3704  }
3705 
3706  // Parse `<-` (`<` has already been parsed)
3707  if (parser.parseMinus())
3708  return failure();
3709 
3710  if (parser.parseOperandList(applyeesOperands,
3712  return failure();
3713 
3714  return success();
3715 }
3716 
3717 LogicalResult TileOp::verify() {
3718  if (getApplyees().empty())
3719  return emitOpError() << "must apply to at least one loop";
3720 
3721  if (getSizes().size() != getApplyees().size())
3722  return emitOpError() << "there must be one tile size for each applyee";
3723 
3724  if (!getGeneratees().empty() &&
3725  2 * getSizes().size() != getGeneratees().size())
3726  return emitOpError()
3727  << "expecting two times the number of generatees than applyees";
3728 
3729  DenseSet<Value> parentIVs;
3730 
3731  Value parent = getApplyees().front();
3732  for (auto &&applyee : llvm::drop_begin(getApplyees())) {
3733  auto [parentCreate, parentGen, parentCons] = decodeCli(parent);
3734  auto [create, gen, cons] = decodeCli(applyee);
3735 
3736  if (!parentGen)
3737  return emitOpError() << "applyee CLI has no generator";
3738 
3739  auto parentLoop = dyn_cast_or_null<CanonicalLoopOp>(parentGen->getOwner());
3740  if (!parentGen)
3741  return emitOpError()
3742  << "currently only supports omp.canonical_loop as applyee";
3743 
3744  parentIVs.insert(parentLoop.getInductionVar());
3745 
3746  if (!gen)
3747  return emitOpError() << "applyee CLI has no generator";
3748  auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
3749  if (!loop)
3750  return emitOpError()
3751  << "currently only supports omp.canonical_loop as applyee";
3752 
3753  // Canonical loop must be perfectly nested, i.e. the body of the parent must
3754  // only contain the omp.canonical_loop of the nested loops, and
3755  // omp.terminator
3756  bool isPerfectlyNested = [&]() {
3757  auto &parentBody = parentLoop.getRegion();
3758  if (!parentBody.hasOneBlock())
3759  return false;
3760  auto &parentBlock = parentBody.getBlocks().front();
3761 
3762  auto nestedLoopIt = parentBlock.begin();
3763  if (nestedLoopIt == parentBlock.end() ||
3764  (&*nestedLoopIt != loop.getOperation()))
3765  return false;
3766 
3767  auto termIt = std::next(nestedLoopIt);
3768  if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
3769  return false;
3770 
3771  if (std::next(termIt) != parentBlock.end())
3772  return false;
3773 
3774  return true;
3775  }();
3776  if (!isPerfectlyNested)
3777  return emitOpError() << "tiled loop nest must be perfectly nested";
3778 
3779  if (parentIVs.contains(loop.getTripCount()))
3780  return emitOpError() << "tiled loop nest must be rectangular";
3781 
3782  parent = applyee;
3783  }
3784 
3785  // TODO: The tile sizes must be computed before the loop, but checking this
3786  // requires dominance analysis. For instance:
3787  //
3788  // %canonloop = omp.new_cli
3789  // omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
3790  // // write to %x
3791  // omp.terminator
3792  // }
3793  // %ts = llvm.load %x
3794  // omp.tile <- (%canonloop) sizes(%ts : i32)
3795 
3796  return success();
3797 }
3798 
3799 std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
3800  return getODSOperandIndexAndLength(odsIndex_applyees);
3801 }
3802 
3803 std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
3804  return getODSOperandIndexAndLength(odsIndex_generatees);
3805 }
3806 
3807 //===----------------------------------------------------------------------===//
3808 // Critical construct (2.17.1)
3809 //===----------------------------------------------------------------------===//
3810 
3811 void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
3812  const CriticalDeclareOperands &clauses) {
3813  CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
3814 }
3815 
3816 LogicalResult CriticalDeclareOp::verify() {
3817  return verifySynchronizationHint(*this, getHint());
3818 }
3819 
3820 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3821  if (getNameAttr()) {
3822  SymbolRefAttr symbolRef = getNameAttr();
3823  auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
3824  *this, symbolRef);
3825  if (!decl) {
3826  return emitOpError() << "expected symbol reference " << symbolRef
3827  << " to point to a critical declaration";
3828  }
3829  }
3830 
3831  return success();
3832 }
3833 
3834 //===----------------------------------------------------------------------===//
3835 // Ordered construct
3836 //===----------------------------------------------------------------------===//
3837 
3838 static LogicalResult verifyOrderedParent(Operation &op) {
3839  bool hasRegion = op.getNumRegions() > 0;
3840  auto loopOp = op.getParentOfType<LoopNestOp>();
3841  if (!loopOp) {
3842  if (hasRegion)
3843  return success();
3844 
3845  // TODO: Consider if this needs to be the case only for the standalone
3846  // variant of the ordered construct.
3847  return op.emitOpError() << "must be nested inside of a loop";
3848  }
3849 
3850  Operation *wrapper = loopOp->getParentOp();
3851  if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
3852  IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
3853  if (!orderedAttr)
3854  return op.emitOpError() << "the enclosing worksharing-loop region must "
3855  "have an ordered clause";
3856 
3857  if (hasRegion && orderedAttr.getInt() != 0)
3858  return op.emitOpError() << "the enclosing loop's ordered clause must not "
3859  "have a parameter present";
3860 
3861  if (!hasRegion && orderedAttr.getInt() == 0)
3862  return op.emitOpError() << "the enclosing loop's ordered clause must "
3863  "have a parameter present";
3864  } else if (!isa<SimdOp>(wrapper)) {
3865  return op.emitOpError() << "must be nested inside of a worksharing, simd "
3866  "or worksharing simd loop";
3867  }
3868  return success();
3869 }
3870 
3871 void OrderedOp::build(OpBuilder &builder, OperationState &state,
3872  const OrderedOperands &clauses) {
3873  OrderedOp::build(builder, state, clauses.doacrossDependType,
3874  clauses.doacrossNumLoops, clauses.doacrossDependVars);
3875 }
3876 
3877 LogicalResult OrderedOp::verify() {
3878  if (failed(verifyOrderedParent(**this)))
3879  return failure();
3880 
3881  auto wrapper = (*this)->getParentOfType<WsloopOp>();
3882  if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
3883  return emitOpError() << "number of variables in depend clause does not "
3884  << "match number of iteration variables in the "
3885  << "doacross loop";
3886 
3887  return success();
3888 }
3889 
3890 void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
3891  const OrderedRegionOperands &clauses) {
3892  OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
3893 }
3894 
3895 LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
3896 
3897 //===----------------------------------------------------------------------===//
3898 // TaskwaitOp
3899 //===----------------------------------------------------------------------===//
3900 
3901 void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
3902  const TaskwaitOperands &clauses) {
3903  // TODO Store clauses in op: dependKinds, dependVars, nowait.
3904  TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
3905  /*depend_vars=*/{}, /*nowait=*/nullptr);
3906 }
3907 
3908 //===----------------------------------------------------------------------===//
3909 // Verifier for AtomicReadOp
3910 //===----------------------------------------------------------------------===//
3911 
3912 LogicalResult AtomicReadOp::verify() {
3913  if (verifyCommon().failed())
3914  return mlir::failure();
3915 
3916  if (auto mo = getMemoryOrder()) {
3917  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3918  *mo == ClauseMemoryOrderKind::Release) {
3919  return emitError(
3920  "memory-order must not be acq_rel or release for atomic reads");
3921  }
3922  }
3923  return verifySynchronizationHint(*this, getHint());
3924 }
3925 
3926 //===----------------------------------------------------------------------===//
3927 // Verifier for AtomicWriteOp
3928 //===----------------------------------------------------------------------===//
3929 
3930 LogicalResult AtomicWriteOp::verify() {
3931  if (verifyCommon().failed())
3932  return mlir::failure();
3933 
3934  if (auto mo = getMemoryOrder()) {
3935  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3936  *mo == ClauseMemoryOrderKind::Acquire) {
3937  return emitError(
3938  "memory-order must not be acq_rel or acquire for atomic writes");
3939  }
3940  }
3941  return verifySynchronizationHint(*this, getHint());
3942 }
3943 
3944 //===----------------------------------------------------------------------===//
3945 // Verifier for AtomicUpdateOp
3946 //===----------------------------------------------------------------------===//
3947 
3948 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3949  PatternRewriter &rewriter) {
3950  if (op.isNoOp()) {
3951  rewriter.eraseOp(op);
3952  return success();
3953  }
3954  if (Value writeVal = op.getWriteOpVal()) {
3955  rewriter.replaceOpWithNewOp<AtomicWriteOp>(
3956  op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
3957  return success();
3958  }
3959  return failure();
3960 }
3961 
3962 LogicalResult AtomicUpdateOp::verify() {
3963  if (verifyCommon().failed())
3964  return mlir::failure();
3965 
3966  if (auto mo = getMemoryOrder()) {
3967  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3968  *mo == ClauseMemoryOrderKind::Acquire) {
3969  return emitError(
3970  "memory-order must not be acq_rel or acquire for atomic updates");
3971  }
3972  }
3973 
3974  return verifySynchronizationHint(*this, getHint());
3975 }
3976 
3977 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3978 
3979 //===----------------------------------------------------------------------===//
3980 // Verifier for AtomicCaptureOp
3981 //===----------------------------------------------------------------------===//
3982 
3983 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3984  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3985  return op;
3986  return dyn_cast<AtomicReadOp>(getSecondOp());
3987 }
3988 
3989 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3990  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3991  return op;
3992  return dyn_cast<AtomicWriteOp>(getSecondOp());
3993 }
3994 
3995 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3996  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3997  return op;
3998  return dyn_cast<AtomicUpdateOp>(getSecondOp());
3999 }
4000 
4001 LogicalResult AtomicCaptureOp::verify() {
4002  return verifySynchronizationHint(*this, getHint());
4003 }
4004 
4005 LogicalResult AtomicCaptureOp::verifyRegions() {
4006  if (verifyRegionsCommon().failed())
4007  return mlir::failure();
4008 
4009  if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
4010  return emitOpError(
4011  "operations inside capture region must not have hint clause");
4012 
4013  if (getFirstOp()->getAttr("memory_order") ||
4014  getSecondOp()->getAttr("memory_order"))
4015  return emitOpError(
4016  "operations inside capture region must not have memory_order clause");
4017  return success();
4018 }
4019 
4020 //===----------------------------------------------------------------------===//
4021 // CancelOp
4022 //===----------------------------------------------------------------------===//
4023 
4024 void CancelOp::build(OpBuilder &builder, OperationState &state,
4025  const CancelOperands &clauses) {
4026  CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4027 }
4028 
// Walks up the parent chain of `thisOp` and returns the closest ancestor that
// belongs to the same dialect as `thisOp`, or nullptr if there is none.
// NOTE(review): the signature line is missing from this listing; callers pass
// an Operation * and use the result as an Operation *.
4030  Operation *parent = thisOp->getParentOp();
4031  while (parent) {
4032  if (parent->getDialect() == thisOp->getDialect())
4033  return parent;
4034  parent = parent->getParentOp();
4035  }
4036  return nullptr;
4037 }
4038 
4039 LogicalResult CancelOp::verify() {
4040  ClauseCancellationConstructType cct = getCancelDirective();
4041  // The next OpenMP operation in the chain of parents
4042  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4043  if (!structuralParent)
4044  return emitOpError() << "Orphaned cancel construct";
4045 
4046  if ((cct == ClauseCancellationConstructType::Parallel) &&
4047  !mlir::isa<ParallelOp>(structuralParent)) {
4048  return emitOpError() << "cancel parallel must appear "
4049  << "inside a parallel region";
4050  }
4051  if (cct == ClauseCancellationConstructType::Loop) {
4052  // structural parent will be omp.loop_nest, directly nested inside
4053  // omp.wsloop
4054  auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
4055 
4056  if (!wsloopOp) {
4057  return emitOpError()
4058  << "cancel loop must appear inside a worksharing-loop region";
4059  }
4060  if (wsloopOp.getNowaitAttr()) {
4061  return emitError() << "A worksharing construct that is canceled "
4062  << "must not have a nowait clause";
4063  }
4064  if (wsloopOp.getOrderedAttr()) {
4065  return emitError() << "A worksharing construct that is canceled "
4066  << "must not have an ordered clause";
4067  }
4068 
4069  } else if (cct == ClauseCancellationConstructType::Sections) {
4070  // structural parent will be an omp.section, directly nested inside
4071  // omp.sections
4072  auto sectionsOp =
4073  mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
4074  if (!sectionsOp) {
4075  return emitOpError() << "cancel sections must appear "
4076  << "inside a sections region";
4077  }
4078  if (sectionsOp.getNowait()) {
4079  return emitError() << "A sections construct that is canceled "
4080  << "must not have a nowait clause";
4081  }
4082  }
4083  if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4084  (!mlir::isa<omp::TaskOp>(structuralParent) &&
4085  !mlir::isa<omp::TaskloopOp>(structuralParent->getParentOp()))) {
4086  return emitOpError() << "cancel taskgroup must appear "
4087  << "inside a task region";
4088  }
4089  return success();
4090 }
4091 
4092 //===----------------------------------------------------------------------===//
4093 // CancellationPointOp
4094 //===----------------------------------------------------------------------===//
4095 
4096 void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
4097  const CancellationPointOperands &clauses) {
4098  CancellationPointOp::build(builder, state, clauses.cancelDirective);
4099 }
4100 
4101 LogicalResult CancellationPointOp::verify() {
4102  ClauseCancellationConstructType cct = getCancelDirective();
4103  // The next OpenMP operation in the chain of parents
4104  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4105  if (!structuralParent)
4106  return emitOpError() << "Orphaned cancellation point";
4107 
4108  if ((cct == ClauseCancellationConstructType::Parallel) &&
4109  !mlir::isa<ParallelOp>(structuralParent)) {
4110  return emitOpError() << "cancellation point parallel must appear "
4111  << "inside a parallel region";
4112  }
4113  // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
4114  // find the wsloop
4115  if ((cct == ClauseCancellationConstructType::Loop) &&
4116  !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
4117  return emitOpError() << "cancellation point loop must appear "
4118  << "inside a worksharing-loop region";
4119  }
4120  if ((cct == ClauseCancellationConstructType::Sections) &&
4121  !mlir::isa<omp::SectionOp>(structuralParent)) {
4122  return emitOpError() << "cancellation point sections must appear "
4123  << "inside a sections region";
4124  }
4125  if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4126  !mlir::isa<omp::TaskOp>(structuralParent)) {
4127  return emitOpError() << "cancellation point taskgroup must appear "
4128  << "inside a task region";
4129  }
4130  return success();
4131 }
4132 
4133 //===----------------------------------------------------------------------===//
4134 // MapBoundsOp
4135 //===----------------------------------------------------------------------===//
4136 
4137 LogicalResult MapBoundsOp::verify() {
4138  auto extent = getExtent();
4139  auto upperbound = getUpperBound();
4140  if (!extent && !upperbound)
4141  return emitError("expected extent or upperbound.");
4142  return success();
4143 }
4144 
// Convenience builder for a privatizer: forwards the symbol name and type to
// the full builder together with a data-sharing attribute for
// DataSharingClauseType::Private. The result types parameter is unused.
// NOTE(review): the line constructing the data-sharing attribute is partly
// missing from this listing.
4145 void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4146  TypeRange /*result_types*/, StringAttr symName,
4147  TypeAttr type) {
4148  PrivateClauseOp::build(
4149  odsBuilder, odsState, symName, type,
4151  DataSharingClauseType::Private));
4152 }
4153 
4154 LogicalResult PrivateClauseOp::verifyRegions() {
4155  Type argType = getArgType();
4156  auto verifyTerminator = [&](Operation *terminator,
4157  bool yieldsValue) -> LogicalResult {
4158  if (!terminator->getBlock()->getSuccessors().empty())
4159  return success();
4160 
4161  if (!llvm::isa<YieldOp>(terminator))
4162  return mlir::emitError(terminator->getLoc())
4163  << "expected exit block terminator to be an `omp.yield` op.";
4164 
4165  YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4166  TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4167 
4168  if (!yieldsValue) {
4169  if (yieldedTypes.empty())
4170  return success();
4171 
4172  return mlir::emitError(terminator->getLoc())
4173  << "Did not expect any values to be yielded.";
4174  }
4175 
4176  if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4177  return success();
4178 
4179  auto error = mlir::emitError(yieldOp.getLoc())
4180  << "Invalid yielded value. Expected type: " << argType
4181  << ", got: ";
4182 
4183  if (yieldedTypes.empty())
4184  error << "None";
4185  else
4186  error << yieldedTypes;
4187 
4188  return error;
4189  };
4190 
4191  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
4192  StringRef regionName,
4193  bool yieldsValue) -> LogicalResult {
4194  assert(!region.empty());
4195 
4196  if (region.getNumArguments() != expectedNumArgs)
4197  return mlir::emitError(region.getLoc())
4198  << "`" << regionName << "`: "
4199  << "expected " << expectedNumArgs
4200  << " region arguments, got: " << region.getNumArguments();
4201 
4202  for (Block &block : region) {
4203  // MLIR will verify the absence of the terminator for us.
4204  if (!block.mightHaveTerminator())
4205  continue;
4206 
4207  if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
4208  return failure();
4209  }
4210 
4211  return success();
4212  };
4213 
4214  // Ensure all of the region arguments have the same type
4215  for (Region *region : getRegions())
4216  for (Type ty : region->getArgumentTypes())
4217  if (ty != argType)
4218  return emitError() << "Region argument type mismatch: got " << ty
4219  << " expected " << argType << ".";
4220 
4221  mlir::Region &initRegion = getInitRegion();
4222  if (!initRegion.empty() &&
4223  failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
4224  /*yieldsValue=*/true)))
4225  return failure();
4226 
4227  DataSharingClauseType dsType = getDataSharingType();
4228 
4229  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4230  return emitError("`private` clauses do not require a `copy` region.");
4231 
4232  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4233  return emitError(
4234  "`firstprivate` clauses require at least a `copy` region.");
4235 
4236  if (dsType == DataSharingClauseType::FirstPrivate &&
4237  failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
4238  /*yieldsValue=*/true)))
4239  return failure();
4240 
4241  if (!getDeallocRegion().empty() &&
4242  failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
4243  /*yieldsValue=*/false)))
4244  return failure();
4245 
4246  return success();
4247 }
4248 
4249 //===----------------------------------------------------------------------===//
4250 // Spec 5.2: Masked construct (10.5)
4251 //===----------------------------------------------------------------------===//
4252 
4253 void MaskedOp::build(OpBuilder &builder, OperationState &state,
4254  const MaskedOperands &clauses) {
4255  MaskedOp::build(builder, state, clauses.filteredThreadId);
4256 }
4257 
4258 //===----------------------------------------------------------------------===//
4259 // Spec 5.2: Scan construct (5.6)
4260 //===----------------------------------------------------------------------===//
4261 
4262 void ScanOp::build(OpBuilder &builder, OperationState &state,
4263  const ScanOperands &clauses) {
4264  ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4265 }
4266 
4267 LogicalResult ScanOp::verify() {
4268  if (hasExclusiveVars() == hasInclusiveVars())
4269  return emitError(
4270  "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4271  if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4272  if (parentWsLoopOp.getReductionModAttr() &&
4273  parentWsLoopOp.getReductionModAttr().getValue() ==
4274  ReductionModifier::inscan)
4275  return success();
4276  }
4277  if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4278  if (parentSimdOp.getReductionModAttr() &&
4279  parentSimdOp.getReductionModAttr().getValue() ==
4280  ReductionModifier::inscan)
4281  return success();
4282  }
4283  return emitError("SCAN directive needs to be enclosed within a parent "
4284  "worksharing loop construct or SIMD construct with INSCAN "
4285  "reduction modifier");
4286 }
4287 
4288 /// Verifies align clause in allocate directive
4289 
4290 LogicalResult AllocateDirOp::verify() {
4291  std::optional<uint64_t> align = this->getAlign();
4292 
4293  if (align.has_value()) {
4294  if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
4295  return emitError() << "ALIGN value : " << align.value()
4296  << " must be power of 2";
4297  }
4298 
4299  return success();
4300 }
4301 
4302 //===----------------------------------------------------------------------===//
4303 // TargetAllocMemOp
4304 //===----------------------------------------------------------------------===//
4305 
4306 mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
4307  return getInTypeAttr().getValue();
4308 }
4309 
4310 /// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype,
4311 /// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
4312 /// attr-dict-without-keyword
4313 static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
4314  mlir::OperationState &result) {
4315  auto &builder = parser.getBuilder();
4316  bool hasOperands = false;
4317  std::int32_t typeparamsSize = 0;
4318 
4319  // Parse device number as a new operand
4321  mlir::Type deviceType;
4322  if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
4323  return mlir::failure();
4324  if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
4325  return mlir::failure();
4326  if (parser.parseComma())
4327  return mlir::failure();
4328 
4329  mlir::Type intype;
4330  if (parser.parseType(intype))
4331  return mlir::failure();
4332  result.addAttribute("in_type", mlir::TypeAttr::get(intype));
4335  if (!parser.parseOptionalLParen()) {
4336  // parse the LEN params of the derived type. (<params> : <types>)
4337  if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
4338  parser.parseColonTypeList(typeVec) || parser.parseRParen())
4339  return mlir::failure();
4340  typeparamsSize = operands.size();
4341  hasOperands = true;
4342  }
4343  std::int32_t shapeSize = 0;
4344  if (!parser.parseOptionalComma()) {
4345  // parse size to scale by, vector of n dimensions of type index
4347  return mlir::failure();
4348  shapeSize = operands.size() - typeparamsSize;
4349  auto idxTy = builder.getIndexType();
4350  for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
4351  typeVec.push_back(idxTy);
4352  hasOperands = true;
4353  }
4354  if (hasOperands &&
4355  parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
4356  result.operands))
4357  return mlir::failure();
4358 
4359  mlir::Type restype = builder.getIntegerType(64);
4360  if (!restype) {
4361  parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
4362  return mlir::failure();
4363  }
4364  llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
4365  result.addAttribute("operandSegmentSizes",
4366  builder.getDenseI32ArrayAttr(segmentSizes));
4367  if (parser.parseOptionalAttrDict(result.attributes) ||
4368  parser.addTypeToList(restype, result.types))
4369  return mlir::failure();
4370  return mlir::success();
4371 }
4372 
4373 mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
4374  mlir::OperationState &result) {
4375  return parseTargetAllocMemOp(parser, result);
4376 }
4377 
4379  p << " ";
4380  p.printOperand(getDevice());
4381  p << " : ";
4382  p << getDevice().getType();
4383  p << ", ";
4384  p << getInType();
4385  if (!getTypeparams().empty()) {
4386  p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
4387  }
4388  for (auto sh : getShape()) {
4389  p << ", ";
4390  p.printOperand(sh);
4391  }
4392  p.printOptionalAttrDict((*this)->getAttrs(),
4393  {"in_type", "operandSegmentSizes"});
4394 }
4395 
4396 llvm::LogicalResult omp::TargetAllocMemOp::verify() {
4397  mlir::Type outType = getType();
4398  if (!mlir::dyn_cast<IntegerType>(outType))
4399  return emitOpError("must be a integer type");
4400  return mlir::success();
4401 }
4402 
4403 //===----------------------------------------------------------------------===//
4404 // WorkdistributeOp
4405 //===----------------------------------------------------------------------===//
4406 
4407 LogicalResult WorkdistributeOp::verify() {
4408  // Check that region exists and is not empty
4409  Region &region = getRegion();
4410  if (region.empty())
4411  return emitOpError("region cannot be empty");
4412  // Verify single entry point.
4413  Block &entryBlock = region.front();
4414  if (entryBlock.empty())
4415  return emitOpError("region must contain a structured block");
4416  // Verify single exit point.
4417  bool hasTerminator = false;
4418  for (Block &block : region) {
4419  if (isa<TerminatorOp>(block.back())) {
4420  if (hasTerminator) {
4421  return emitOpError("region must have exactly one terminator");
4422  }
4423  hasTerminator = true;
4424  }
4425  }
4426  if (!hasTerminator) {
4427  return emitOpError("region must be terminated with omp.terminator");
4428  }
4429  auto walkResult = region.walk([&](Operation *op) -> WalkResult {
4430  // No implicit barrier at end
4431  if (isa<BarrierOp>(op)) {
4432  return emitOpError(
4433  "explicit barriers are not allowed in workdistribute region");
4434  }
4435  // Check for invalid nested constructs
4436  if (isa<ParallelOp>(op)) {
4437  return emitOpError(
4438  "nested parallel constructs not allowed in workdistribute");
4439  }
4440  if (isa<TeamsOp>(op)) {
4441  return emitOpError(
4442  "nested teams constructs not allowed in workdistribute");
4443  }
4444  return WalkResult::advance();
4445  });
4446  if (walkResult.wasInterrupted())
4447  return failure();
4448 
4449  Operation *parentOp = (*this)->getParentOp();
4450  if (!llvm::dyn_cast<TeamsOp>(parentOp))
4451  return emitOpError("workdistribute must be nested under teams");
4452  return success();
4453 }
4454 
4455 #define GET_ATTRDEF_CLASSES
4456 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
4457 
4458 #define GET_OP_CLASSES
4459 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
4460 
4461 #define GET_TYPEDEF_CLASSES
4462 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
Definition: AMXDialect.cpp:70
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:757
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1367
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:62
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
operation ::= res = (omp.target_alloc_mem) $device : devicetype, $in_type ( ( $typeparams ) )?...
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static Operation * getParentInSameDialect(Operation *thisOp)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static void printMapClause(OpAsmPrinter &p, Operation *op, ClauseMapFlagsAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseMembersIndex(OpAsmParser &parser, ArrayAttr &membersIdx)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseMapClause(OpAsmParser &parser, ClauseMapFlagsAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static bool isUnique(It begin, It end)
Definition: ShardOps.cpp:161
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
SuccessorRange getSuccessors()
Definition: Block.h:270
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:163
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
MLIRContext * getContext() const
Definition: Builders.h:56
IndexType getIndexType()
Definition: Builders.cpp:51
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:98
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Definition: Diagnostics.h:230
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:316
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
Definition: Diagnostics.h:354
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
Definition: Builders.h:207
This class represents an operand of an operation.
Definition: Value.h:257
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:773
This class indicates that the regions associated with this op don't have terminators.
Definition: OpDefinition.h:769
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
BlockArgListType getArguments()
Definition: Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
BlockListType & getBlocks()
Definition: Region.h:45
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:108
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Definition: Dominance.cpp:306
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
bool isPerfectlyNested(ArrayRef< AffineForOp > loops)
Returns true if loops is a perfectly nested loop nest, where loops appear in it from outermost to inn...
Definition: LoopUtils.cpp:1361
SmallVector< SmallVector< AffineForOp, 8 >, 8 > tile(ArrayRef< AffineForOp > forOps, ArrayRef< uint64_t > sizes, ArrayRef< AffineForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: LoopUtils.cpp:1584
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Runtime
Potential runtimes for AMD GPU kernels.
Definition: Runtimes.h:15
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
Definition: OpDefinition.h:881
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.