MLIR  22.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Attributes.h"
23 #include "mlir/IR/SymbolTable.h"
25 
26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/PostOrderIterator.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/STLForwardCompat.h"
30 #include "llvm/ADT/SmallString.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/StringRef.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/ADT/bit.h"
35 #include "llvm/Frontend/OpenMP/OMPConstants.h"
36 #include <cstddef>
37 #include <iterator>
38 #include <optional>
39 #include <variant>
40 
41 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
42 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
43 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
44 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
45 
46 using namespace mlir;
47 using namespace mlir::omp;
48 
49 static ArrayAttr makeArrayAttr(MLIRContext *context,
51  return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
52 }
53 
54 static DenseBoolArrayAttr
56  return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
57 }
58 
59 namespace {
60 struct MemRefPointerLikeModel
61  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
62  MemRefType> {
63  Type getElementType(Type pointer) const {
64  return llvm::cast<MemRefType>(pointer).getElementType();
65  }
66 };
67 
68 struct LLVMPointerPointerLikeModel
69  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
70  LLVM::LLVMPointerType> {
71  Type getElementType(Type pointer) const { return Type(); }
72 };
73 } // namespace
74 
75 void OpenMPDialect::initialize() {
76  addOperations<
77 #define GET_OP_LIST
78 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
79  >();
80  addAttributes<
81 #define GET_ATTRDEF_LIST
82 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
83  >();
84  addTypes<
85 #define GET_TYPEDEF_LIST
86 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
87  >();
88 
89  declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
90 
91  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
92  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
93  *getContext());
94 
95  // Attach default offload module interface to module op to access
96  // offload functionality through
97  mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
98  *getContext());
99 
100  // Attach default declare target interfaces to operations which can be marked
101  // as declare target (Global Operations and Functions/Subroutines in dialects
102  // that Fortran (or other languages that lower to MLIR) translates too
103  mlir::LLVM::GlobalOp::attachInterface<
105  *getContext());
106  mlir::LLVM::LLVMFuncOp::attachInterface<
108  *getContext());
109  mlir::func::FuncOp::attachInterface<
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // Parser and printer for Allocate Clause
115 //===----------------------------------------------------------------------===//
116 
117 /// Parse an allocate clause with allocators and a list of operands with types.
118 ///
119 /// allocate-operand-list ::= allocate-operand |
120 ///                           allocate-operand `,` allocate-operand-list
121 /// allocate-operand ::= ssa-id-and-type `->` ssa-id-and-type
122 /// ssa-id-and-type ::= ssa-id `:` type
123 static ParseResult parseAllocateAndAllocator(
124  OpAsmParser &parser,
126  SmallVectorImpl<Type> &allocateTypes,
128  SmallVectorImpl<Type> &allocatorTypes) {
129 
130  return parser.parseCommaSeparatedList([&]() {
132  Type type;
133  if (parser.parseOperand(operand) || parser.parseColonType(type))
134  return failure();
135  allocatorVars.push_back(operand);
136  allocatorTypes.push_back(type);
137  if (parser.parseArrow())
138  return failure();
139  if (parser.parseOperand(operand) || parser.parseColonType(type))
140  return failure();
141 
142  allocateVars.push_back(operand);
143  allocateTypes.push_back(type);
144  return success();
145  });
146 }
147 
148 /// Print allocate clause
150  OperandRange allocateVars,
151  TypeRange allocateTypes,
152  OperandRange allocatorVars,
153  TypeRange allocatorTypes) {
154  for (unsigned i = 0; i < allocateVars.size(); ++i) {
155  std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
156  p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
157  p << allocateVars[i] << " : " << allocateTypes[i] << separator;
158  }
159 }
160 
161 //===----------------------------------------------------------------------===//
162 // Parser and printer for a clause attribute (StringEnumAttr)
163 //===----------------------------------------------------------------------===//
164 
165 template <typename ClauseAttr>
166 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
167  using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
168  StringRef enumStr;
169  SMLoc loc = parser.getCurrentLocation();
170  if (parser.parseKeyword(&enumStr))
171  return failure();
172  if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
173  attr = ClauseAttr::get(parser.getContext(), *enumValue);
174  return success();
175  }
176  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
177 }
178 
179 template <typename ClauseAttr>
180 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
181  p << stringifyEnum(attr.getValue());
182 }
183 
184 //===----------------------------------------------------------------------===//
185 // Parser and printer for Linear Clause
186 //===----------------------------------------------------------------------===//
187 
188 /// linear ::= `linear` `(` linear-list `)`
189 /// linear-list ::= linear-val | linear-val `,` linear-list
190 /// linear-val ::= ssa-id `=` ssa-id-and-type
191 static ParseResult parseLinearClause(
192  OpAsmParser &parser,
194  SmallVectorImpl<Type> &linearTypes,
196  return parser.parseCommaSeparatedList([&]() {
198  Type type;
200  if (parser.parseOperand(var) || parser.parseEqual() ||
201  parser.parseOperand(stepVar) || parser.parseColonType(type))
202  return failure();
203 
204  linearVars.push_back(var);
205  linearTypes.push_back(type);
206  linearStepVars.push_back(stepVar);
207  return success();
208  });
209 }
210 
211 /// Print Linear Clause
213  ValueRange linearVars, TypeRange linearTypes,
214  ValueRange linearStepVars) {
215  size_t linearVarsSize = linearVars.size();
216  for (unsigned i = 0; i < linearVarsSize; ++i) {
217  std::string separator = i == linearVarsSize - 1 ? "" : ", ";
218  p << linearVars[i];
219  if (linearStepVars.size() > i)
220  p << " = " << linearStepVars[i];
221  p << " : " << linearVars[i].getType() << separator;
222  }
223 }
224 
225 //===----------------------------------------------------------------------===//
226 // Verifier for Nontemporal Clause
227 //===----------------------------------------------------------------------===//
228 
229 static LogicalResult verifyNontemporalClause(Operation *op,
230  OperandRange nontemporalVars) {
231 
232  // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
233  DenseSet<Value> nontemporalItems;
234  for (const auto &it : nontemporalVars)
235  if (!nontemporalItems.insert(it).second)
236  return op->emitOpError() << "nontemporal variable used more than once";
237 
238  return success();
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // Parser, verifier and printer for Aligned Clause
243 //===----------------------------------------------------------------------===//
244 static LogicalResult verifyAlignedClause(Operation *op,
245  std::optional<ArrayAttr> alignments,
246  OperandRange alignedVars) {
247  // Check if number of alignment values equals to number of aligned variables
248  if (!alignedVars.empty()) {
249  if (!alignments || alignments->size() != alignedVars.size())
250  return op->emitOpError()
251  << "expected as many alignment values as aligned variables";
252  } else {
253  if (alignments)
254  return op->emitOpError() << "unexpected alignment values attribute";
255  return success();
256  }
257 
258  // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
259  DenseSet<Value> alignedItems;
260  for (auto it : alignedVars)
261  if (!alignedItems.insert(it).second)
262  return op->emitOpError() << "aligned variable used more than once";
263 
264  if (!alignments)
265  return success();
266 
267  // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
268  for (unsigned i = 0; i < (*alignments).size(); ++i) {
269  if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
270  if (intAttr.getValue().sle(0))
271  return op->emitOpError() << "alignment should be greater than 0";
272  } else {
273  return op->emitOpError() << "expected integer alignment";
274  }
275  }
276 
277  return success();
278 }
279 
280 /// aligned ::= `aligned` `(` aligned-list `)`
281 /// aligned-list := aligned-val | aligned-val aligned-list
282 /// aligned-val := ssa-id-and-type `->` alignment
283 static ParseResult
286  SmallVectorImpl<Type> &alignedTypes,
287  ArrayAttr &alignmentsAttr) {
288  SmallVector<Attribute> alignmentVec;
289  if (failed(parser.parseCommaSeparatedList([&]() {
290  if (parser.parseOperand(alignedVars.emplace_back()) ||
291  parser.parseColonType(alignedTypes.emplace_back()) ||
292  parser.parseArrow() ||
293  parser.parseAttribute(alignmentVec.emplace_back())) {
294  return failure();
295  }
296  return success();
297  })))
298  return failure();
299  SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
300  alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
301  return success();
302 }
303 
304 /// Print Aligned Clause
306  ValueRange alignedVars, TypeRange alignedTypes,
307  std::optional<ArrayAttr> alignments) {
308  for (unsigned i = 0; i < alignedVars.size(); ++i) {
309  if (i != 0)
310  p << ", ";
311  p << alignedVars[i] << " : " << alignedVars[i].getType();
312  p << " -> " << (*alignments)[i];
313  }
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // Parser, printer and verifier for Schedule Clause
318 //===----------------------------------------------------------------------===//
319 
320 static ParseResult
322  SmallVectorImpl<SmallString<12>> &modifiers) {
323  if (modifiers.size() > 2)
324  return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
325  for (const auto &mod : modifiers) {
326  // Translate the string. If it has no value, then it was not a valid
327  // modifier!
328  auto symbol = symbolizeScheduleModifier(mod);
329  if (!symbol)
330  return parser.emitError(parser.getNameLoc())
331  << " unknown modifier type: " << mod;
332  }
333 
334  // If we have one modifier that is "simd", then stick a "none" modiifer in
335  // index 0.
336  if (modifiers.size() == 1) {
337  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
338  modifiers.push_back(modifiers[0]);
339  modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
340  }
341  } else if (modifiers.size() == 2) {
342  // If there are two modifier:
343  // First modifier should not be simd, second one should be simd
344  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
345  symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
346  return parser.emitError(parser.getNameLoc())
347  << " incorrect modifier order";
348  }
349  return success();
350 }
351 
352 /// schedule ::= `schedule` `(` sched-list `)`
353 /// sched-list ::= sched-val | sched-val sched-list |
354 /// sched-val `,` sched-modifier
355 /// sched-val ::= sched-with-chunk | sched-wo-chunk
356 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
357 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
358 /// sched-wo-chunk ::= `auto` | `runtime`
359 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
360 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
361 static ParseResult
362 parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
363  ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
364  std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
365  Type &chunkType) {
366  StringRef keyword;
367  if (parser.parseKeyword(&keyword))
368  return failure();
369  std::optional<mlir::omp::ClauseScheduleKind> schedule =
370  symbolizeClauseScheduleKind(keyword);
371  if (!schedule)
372  return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
373 
374  scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
375  switch (*schedule) {
376  case ClauseScheduleKind::Static:
377  case ClauseScheduleKind::Dynamic:
378  case ClauseScheduleKind::Guided:
379  if (succeeded(parser.parseOptionalEqual())) {
380  chunkSize = OpAsmParser::UnresolvedOperand{};
381  if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
382  return failure();
383  } else {
384  chunkSize = std::nullopt;
385  }
386  break;
387  case ClauseScheduleKind::Auto:
389  chunkSize = std::nullopt;
390  }
391 
392  // If there is a comma, we have one or more modifiers..
393  SmallVector<SmallString<12>> modifiers;
394  while (succeeded(parser.parseOptionalComma())) {
395  StringRef mod;
396  if (parser.parseKeyword(&mod))
397  return failure();
398  modifiers.push_back(mod);
399  }
400 
401  if (verifyScheduleModifiers(parser, modifiers))
402  return failure();
403 
404  if (!modifiers.empty()) {
405  SMLoc loc = parser.getCurrentLocation();
406  if (std::optional<ScheduleModifier> mod =
407  symbolizeScheduleModifier(modifiers[0])) {
408  scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
409  } else {
410  return parser.emitError(loc, "invalid schedule modifier");
411  }
412  // Only SIMD attribute is allowed here!
413  if (modifiers.size() > 1) {
414  assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
415  scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
416  }
417  }
418 
419  return success();
420 }
421 
422 /// Print schedule clause
424  ClauseScheduleKindAttr scheduleKind,
425  ScheduleModifierAttr scheduleMod,
426  UnitAttr scheduleSimd, Value scheduleChunk,
427  Type scheduleChunkType) {
428  p << stringifyClauseScheduleKind(scheduleKind.getValue());
429  if (scheduleChunk)
430  p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
431  if (scheduleMod)
432  p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
433  if (scheduleSimd)
434  p << ", simd";
435 }
436 
437 //===----------------------------------------------------------------------===//
438 // Parser and printer for Order Clause
439 //===----------------------------------------------------------------------===//
440 
441 // order ::= `order` `(` [order-modifier ':'] concurrent `)`
442 // order-modifier ::= reproducible | unconstrained
443 static ParseResult parseOrderClause(OpAsmParser &parser,
444  ClauseOrderKindAttr &order,
445  OrderModifierAttr &orderMod) {
446  StringRef enumStr;
447  SMLoc loc = parser.getCurrentLocation();
448  if (parser.parseKeyword(&enumStr))
449  return failure();
450  if (std::optional<OrderModifier> enumValue =
451  symbolizeOrderModifier(enumStr)) {
452  orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
453  if (parser.parseOptionalColon())
454  return failure();
455  loc = parser.getCurrentLocation();
456  if (parser.parseKeyword(&enumStr))
457  return failure();
458  }
459  if (std::optional<ClauseOrderKind> enumValue =
460  symbolizeClauseOrderKind(enumStr)) {
461  order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
462  return success();
463  }
464  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
465 }
466 
468  ClauseOrderKindAttr order,
469  OrderModifierAttr orderMod) {
470  if (orderMod)
471  p << stringifyOrderModifier(orderMod.getValue()) << ":";
472  if (order)
473  p << stringifyClauseOrderKind(order.getValue());
474 }
475 
476 template <typename ClauseTypeAttr, typename ClauseType>
477 static ParseResult
478 parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
479  std::optional<OpAsmParser::UnresolvedOperand> &operand,
480  Type &operandType,
481  std::optional<ClauseType> (*symbolizeClause)(StringRef),
482  StringRef clauseName) {
483  StringRef enumStr;
484  if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
485  if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
486  prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
487  if (parser.parseComma())
488  return failure();
489  } else {
490  return parser.emitError(parser.getCurrentLocation())
491  << "invalid " << clauseName << " modifier : '" << enumStr << "'";
492  ;
493  }
494  }
495 
497  if (succeeded(parser.parseOperand(var))) {
498  operand = var;
499  } else {
500  return parser.emitError(parser.getCurrentLocation())
501  << "expected " << clauseName << " operand";
502  }
503 
504  if (operand.has_value()) {
505  if (parser.parseColonType(operandType))
506  return failure();
507  }
508 
509  return success();
510 }
511 
512 template <typename ClauseTypeAttr, typename ClauseType>
513 static void
515  ClauseTypeAttr prescriptiveness, Value operand,
516  mlir::Type operandType,
517  StringRef (*stringifyClauseType)(ClauseType)) {
518 
519  if (prescriptiveness)
520  p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
521 
522  if (operand)
523  p << operand << ": " << operandType;
524 }
525 
526 //===----------------------------------------------------------------------===//
527 // Parser and printer for grainsize Clause
528 //===----------------------------------------------------------------------===//
529 
530 // grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
531 static ParseResult
532 parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
533  std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
534  Type &grainsizeType) {
535  return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
536  parser, grainsizeMod, grainsize, grainsizeType,
537  &symbolizeClauseGrainsizeType, "grainsize");
538 }
539 
541  ClauseGrainsizeTypeAttr grainsizeMod,
542  Value grainsize, mlir::Type grainsizeType) {
543  printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
544  p, op, grainsizeMod, grainsize, grainsizeType,
545  &stringifyClauseGrainsizeType);
546 }
547 
548 //===----------------------------------------------------------------------===//
549 // Parser and printer for num_tasks Clause
550 //===----------------------------------------------------------------------===//
551 
552 // numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
553 static ParseResult
554 parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
555  std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
556  Type &numTasksType) {
557  return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
558  parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
559  "num_tasks");
560 }
561 
563  ClauseNumTasksTypeAttr numTasksMod,
564  Value numTasks, mlir::Type numTasksType) {
565  printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
566  p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
567 }
568 
569 //===----------------------------------------------------------------------===//
570 // Parsers for operations including clauses that define entry block arguments.
571 //===----------------------------------------------------------------------===//
572 
573 namespace {
574 struct MapParseArgs {
576  SmallVectorImpl<Type> &types;
578  SmallVectorImpl<Type> &types)
579  : vars(vars), types(types) {}
580 };
581 struct PrivateParseArgs {
584  ArrayAttr &syms;
585  UnitAttr &needsBarrier;
586  DenseI64ArrayAttr *mapIndices;
588  SmallVectorImpl<Type> &types, ArrayAttr &syms,
589  UnitAttr &needsBarrier,
590  DenseI64ArrayAttr *mapIndices = nullptr)
591  : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
592  mapIndices(mapIndices) {}
593 };
594 
595 struct ReductionParseArgs {
597  SmallVectorImpl<Type> &types;
598  DenseBoolArrayAttr &byref;
599  ArrayAttr &syms;
600  ReductionModifierAttr *modifier;
601  ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
603  ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
604  : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
605 };
606 
607 struct AllRegionParseArgs {
608  std::optional<MapParseArgs> hasDeviceAddrArgs;
609  std::optional<MapParseArgs> hostEvalArgs;
610  std::optional<ReductionParseArgs> inReductionArgs;
611  std::optional<MapParseArgs> mapArgs;
612  std::optional<PrivateParseArgs> privateArgs;
613  std::optional<ReductionParseArgs> reductionArgs;
614  std::optional<ReductionParseArgs> taskReductionArgs;
615  std::optional<MapParseArgs> useDeviceAddrArgs;
616  std::optional<MapParseArgs> useDevicePtrArgs;
617 };
618 } // namespace
619 
620 static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
621  return "private_barrier";
622 }
623 
624 static ParseResult parseClauseWithRegionArgs(
625  OpAsmParser &parser,
627  SmallVectorImpl<Type> &types,
628  SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
629  ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
630  DenseBoolArrayAttr *byref = nullptr,
631  ReductionModifierAttr *modifier = nullptr,
632  UnitAttr *needsBarrier = nullptr) {
633  SmallVector<SymbolRefAttr> symbolVec;
634  SmallVector<int64_t> mapIndicesVec;
635  SmallVector<bool> isByRefVec;
636  unsigned regionArgOffset = regionPrivateArgs.size();
637 
638  if (parser.parseLParen())
639  return failure();
640 
641  if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
642  StringRef enumStr;
643  if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
644  parser.parseComma())
645  return failure();
646  std::optional<ReductionModifier> enumValue =
647  symbolizeReductionModifier(enumStr);
648  if (!enumValue.has_value())
649  return failure();
650  *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
651  if (!*modifier)
652  return failure();
653  }
654 
655  if (parser.parseCommaSeparatedList([&]() {
656  if (byref)
657  isByRefVec.push_back(
658  parser.parseOptionalKeyword("byref").succeeded());
659 
660  if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
661  return failure();
662 
663  if (parser.parseOperand(operands.emplace_back()) ||
664  parser.parseArrow() ||
665  parser.parseArgument(regionPrivateArgs.emplace_back()))
666  return failure();
667 
668  if (mapIndices) {
669  if (parser.parseOptionalLSquare().succeeded()) {
670  if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
671  parser.parseInteger(mapIndicesVec.emplace_back()) ||
672  parser.parseRSquare())
673  return failure();
674  } else {
675  mapIndicesVec.push_back(-1);
676  }
677  }
678 
679  return success();
680  }))
681  return failure();
682 
683  if (parser.parseColon())
684  return failure();
685 
686  if (parser.parseCommaSeparatedList([&]() {
687  if (parser.parseType(types.emplace_back()))
688  return failure();
689 
690  return success();
691  }))
692  return failure();
693 
694  if (operands.size() != types.size())
695  return failure();
696 
697  if (parser.parseRParen())
698  return failure();
699 
700  if (needsBarrier) {
702  .succeeded())
703  *needsBarrier = mlir::UnitAttr::get(parser.getContext());
704  }
705 
706  auto *argsBegin = regionPrivateArgs.begin();
707  MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
708  argsBegin + regionArgOffset + types.size());
709  for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
710  prv.type = type;
711  }
712 
713  if (symbols) {
714  SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
715  *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
716  }
717 
718  if (!mapIndicesVec.empty())
719  *mapIndices =
720  mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
721 
722  if (byref)
723  *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
724 
725  return success();
726 }
727 
728 static ParseResult parseBlockArgClause(
729  OpAsmParser &parser,
731  StringRef keyword, std::optional<MapParseArgs> mapArgs) {
732  if (succeeded(parser.parseOptionalKeyword(keyword))) {
733  if (!mapArgs)
734  return failure();
735 
736  if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
737  entryBlockArgs)))
738  return failure();
739  }
740  return success();
741 }
742 
743 static ParseResult parseBlockArgClause(
744  OpAsmParser &parser,
746  StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
747  if (succeeded(parser.parseOptionalKeyword(keyword))) {
748  if (!privateArgs)
749  return failure();
750 
751  if (failed(parseClauseWithRegionArgs(
752  parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
753  &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
754  /*modifier=*/nullptr, &privateArgs->needsBarrier)))
755  return failure();
756  }
757  return success();
758 }
759 
760 static ParseResult parseBlockArgClause(
761  OpAsmParser &parser,
763  StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
764  if (succeeded(parser.parseOptionalKeyword(keyword))) {
765  if (!reductionArgs)
766  return failure();
767  if (failed(parseClauseWithRegionArgs(
768  parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
769  &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
770  reductionArgs->modifier)))
771  return failure();
772  }
773  return success();
774 }
775 
776 static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
777  AllRegionParseArgs args) {
779 
780  if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",
781  args.hasDeviceAddrArgs)))
782  return parser.emitError(parser.getCurrentLocation())
783  << "invalid `has_device_addr` format";
784 
785  if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
786  args.hostEvalArgs)))
787  return parser.emitError(parser.getCurrentLocation())
788  << "invalid `host_eval` format";
789 
790  if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
791  args.inReductionArgs)))
792  return parser.emitError(parser.getCurrentLocation())
793  << "invalid `in_reduction` format";
794 
795  if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
796  args.mapArgs)))
797  return parser.emitError(parser.getCurrentLocation())
798  << "invalid `map_entries` format";
799 
800  if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
801  args.privateArgs)))
802  return parser.emitError(parser.getCurrentLocation())
803  << "invalid `private` format";
804 
805  if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
806  args.reductionArgs)))
807  return parser.emitError(parser.getCurrentLocation())
808  << "invalid `reduction` format";
809 
810  if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
811  args.taskReductionArgs)))
812  return parser.emitError(parser.getCurrentLocation())
813  << "invalid `task_reduction` format";
814 
815  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
816  args.useDeviceAddrArgs)))
817  return parser.emitError(parser.getCurrentLocation())
818  << "invalid `use_device_addr` format";
819 
820  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
821  args.useDevicePtrArgs)))
822  return parser.emitError(parser.getCurrentLocation())
823  << "invalid `use_device_addr` format";
824 
825  return parser.parseRegion(region, entryBlockArgs);
826 }
827 
828 // These parseXyz functions correspond to the custom<Xyz> definitions
829 // in the .td file(s).
830 static ParseResult parseTargetOpRegion(
831  OpAsmParser &parser, Region &region,
833  SmallVectorImpl<Type> &hasDeviceAddrTypes,
835  SmallVectorImpl<Type> &hostEvalTypes,
837  SmallVectorImpl<Type> &inReductionTypes,
838  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
840  SmallVectorImpl<Type> &mapTypes,
842  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
843  UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
844  AllRegionParseArgs args;
845  args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
846  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
847  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
848  inReductionByref, inReductionSyms);
849  args.mapArgs.emplace(mapVars, mapTypes);
850  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
851  privateNeedsBarrier, &privateMaps);
852  return parseBlockArgRegion(parser, region, args);
853 }
854 
855 static ParseResult parseInReductionPrivateRegion(
856  OpAsmParser &parser, Region &region,
858  SmallVectorImpl<Type> &inReductionTypes,
859  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
861  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
862  UnitAttr &privateNeedsBarrier) {
863  AllRegionParseArgs args;
864  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
865  inReductionByref, inReductionSyms);
866  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
867  privateNeedsBarrier);
868  return parseBlockArgRegion(parser, region, args);
869 }
870 
872  OpAsmParser &parser, Region &region,
874  SmallVectorImpl<Type> &inReductionTypes,
875  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
877  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
878  UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
880  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
881  ArrayAttr &reductionSyms) {
882  AllRegionParseArgs args;
883  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
884  inReductionByref, inReductionSyms);
885  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
886  privateNeedsBarrier);
887  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
888  reductionSyms, &reductionMod);
889  return parseBlockArgRegion(parser, region, args);
890 }
891 
892 static ParseResult parsePrivateRegion(
893  OpAsmParser &parser, Region &region,
895  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
896  UnitAttr &privateNeedsBarrier) {
897  AllRegionParseArgs args;
898  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
899  privateNeedsBarrier);
900  return parseBlockArgRegion(parser, region, args);
901 }
902 
903 static ParseResult parsePrivateReductionRegion(
904  OpAsmParser &parser, Region &region,
906  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
907  UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
909  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
910  ArrayAttr &reductionSyms) {
911  AllRegionParseArgs args;
912  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
913  privateNeedsBarrier);
914  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
915  reductionSyms, &reductionMod);
916  return parseBlockArgRegion(parser, region, args);
917 }
918 
919 static ParseResult parseTaskReductionRegion(
920  OpAsmParser &parser, Region &region,
922  SmallVectorImpl<Type> &taskReductionTypes,
923  DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
924  AllRegionParseArgs args;
925  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
926  taskReductionByref, taskReductionSyms);
927  return parseBlockArgRegion(parser, region, args);
928 }
929 
931  OpAsmParser &parser, Region &region,
933  SmallVectorImpl<Type> &useDeviceAddrTypes,
935  SmallVectorImpl<Type> &useDevicePtrTypes) {
936  AllRegionParseArgs args;
937  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
938  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
939  return parseBlockArgRegion(parser, region, args);
940 }
941 
942 //===----------------------------------------------------------------------===//
943 // Printers for operations including clauses that define entry block arguments.
944 //===----------------------------------------------------------------------===//
945 
946 namespace {
947 struct MapPrintArgs {
948  ValueRange vars;
949  TypeRange types;
950  MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
951 };
952 struct PrivatePrintArgs {
953  ValueRange vars;
954  TypeRange types;
955  ArrayAttr syms;
956  UnitAttr needsBarrier;
957  DenseI64ArrayAttr mapIndices;
958  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
959  UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
960  : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
961  mapIndices(mapIndices) {}
962 };
963 struct ReductionPrintArgs {
964  ValueRange vars;
965  TypeRange types;
966  DenseBoolArrayAttr byref;
967  ArrayAttr syms;
968  ReductionModifierAttr modifier;
969  ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
970  ArrayAttr syms, ReductionModifierAttr mod = nullptr)
971  : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
972 };
973 struct AllRegionPrintArgs {
974  std::optional<MapPrintArgs> hasDeviceAddrArgs;
975  std::optional<MapPrintArgs> hostEvalArgs;
976  std::optional<ReductionPrintArgs> inReductionArgs;
977  std::optional<MapPrintArgs> mapArgs;
978  std::optional<PrivatePrintArgs> privateArgs;
979  std::optional<ReductionPrintArgs> reductionArgs;
980  std::optional<ReductionPrintArgs> taskReductionArgs;
981  std::optional<MapPrintArgs> useDeviceAddrArgs;
982  std::optional<MapPrintArgs> useDevicePtrArgs;
983 };
984 } // namespace
985 
987  OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
988  ValueRange argsSubrange, ValueRange operands, TypeRange types,
989  ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
990  DenseBoolArrayAttr byref = nullptr,
991  ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
992  if (argsSubrange.empty())
993  return;
994 
995  p << clauseName << "(";
996 
997  if (modifier)
998  p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
999 
1000  if (!symbols) {
1001  llvm::SmallVector<Attribute> values(operands.size(), nullptr);
1002  symbols = ArrayAttr::get(ctx, values);
1003  }
1004 
1005  if (!mapIndices) {
1006  llvm::SmallVector<int64_t> values(operands.size(), -1);
1007  mapIndices = DenseI64ArrayAttr::get(ctx, values);
1008  }
1009 
1010  if (!byref) {
1011  mlir::SmallVector<bool> values(operands.size(), false);
1012  byref = DenseBoolArrayAttr::get(ctx, values);
1013  }
1014 
1015  llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1016  mapIndices.asArrayRef(),
1017  byref.asArrayRef()),
1018  p, [&p](auto t) {
1019  auto [op, arg, sym, map, isByRef] = t;
1020  if (isByRef)
1021  p << "byref ";
1022  if (sym)
1023  p << sym << " ";
1024 
1025  p << op << " -> " << arg;
1026 
1027  if (map != -1)
1028  p << " [map_idx=" << map << "]";
1029  });
1030  p << " : ";
1031  llvm::interleaveComma(types, p);
1032  p << ") ";
1033 
1034  if (needsBarrier)
1035  p << getPrivateNeedsBarrierSpelling() << " ";
1036 }
1037 
1039  StringRef clauseName, ValueRange argsSubrange,
1040  std::optional<MapPrintArgs> mapArgs) {
1041  if (mapArgs)
1042  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
1043  mapArgs->types);
1044 }
1045 
1047  StringRef clauseName, ValueRange argsSubrange,
1048  std::optional<PrivatePrintArgs> privateArgs) {
1049  if (privateArgs)
1051  p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1052  privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1053  /*modifier=*/nullptr, privateArgs->needsBarrier);
1054 }
1055 
1056 static void
1057 printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1058  ValueRange argsSubrange,
1059  std::optional<ReductionPrintArgs> reductionArgs) {
1060  if (reductionArgs)
1061  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1062  reductionArgs->vars, reductionArgs->types,
1063  reductionArgs->syms, /*mapIndices=*/nullptr,
1064  reductionArgs->byref, reductionArgs->modifier);
1065 }
1066 
1067 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
1068  const AllRegionPrintArgs &args) {
1069  auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1070  MLIRContext *ctx = op->getContext();
1071 
1072  printBlockArgClause(p, ctx, "has_device_addr",
1073  iface.getHasDeviceAddrBlockArgs(),
1074  args.hasDeviceAddrArgs);
1075  printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
1076  args.hostEvalArgs);
1077  printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
1078  args.inReductionArgs);
1079  printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
1080  args.mapArgs);
1081  printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
1082  args.privateArgs);
1083  printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
1084  args.reductionArgs);
1085  printBlockArgClause(p, ctx, "task_reduction",
1086  iface.getTaskReductionBlockArgs(),
1087  args.taskReductionArgs);
1088  printBlockArgClause(p, ctx, "use_device_addr",
1089  iface.getUseDeviceAddrBlockArgs(),
1090  args.useDeviceAddrArgs);
1091  printBlockArgClause(p, ctx, "use_device_ptr",
1092  iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1093 
1094  p.printRegion(region, /*printEntryBlockArgs=*/false);
1095 }
1096 
// These printXyz functions correspond to the custom<Xyz> definitions
// in the .td file(s).
1100  OpAsmPrinter &p, Operation *op, Region &region,
1101  ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
1102  ValueRange hostEvalVars, TypeRange hostEvalTypes,
1103  ValueRange inReductionVars, TypeRange inReductionTypes,
1104  DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
1105  ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
1106  TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1107  DenseI64ArrayAttr privateMaps) {
1108  AllRegionPrintArgs args;
1109  args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1110  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1111  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1112  inReductionByref, inReductionSyms);
1113  args.mapArgs.emplace(mapVars, mapTypes);
1114  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1115  privateNeedsBarrier, privateMaps);
1116  printBlockArgRegion(p, op, region, args);
1117 }
1118 
1120  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1121  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1122  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1123  ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1124  AllRegionPrintArgs args;
1125  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1126  inReductionByref, inReductionSyms);
1127  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1128  privateNeedsBarrier,
1129  /*mapIndices=*/nullptr);
1130  printBlockArgRegion(p, op, region, args);
1131 }
1132 
1134  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1135  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1136  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1137  ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1138  ReductionModifierAttr reductionMod, ValueRange reductionVars,
1139  TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1140  ArrayAttr reductionSyms) {
1141  AllRegionPrintArgs args;
1142  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1143  inReductionByref, inReductionSyms);
1144  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1145  privateNeedsBarrier,
1146  /*mapIndices=*/nullptr);
1147  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1148  reductionSyms, reductionMod);
1149  printBlockArgRegion(p, op, region, args);
1150 }
1151 
1152 static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
1153  ValueRange privateVars, TypeRange privateTypes,
1154  ArrayAttr privateSyms,
1155  UnitAttr privateNeedsBarrier) {
1156  AllRegionPrintArgs args;
1157  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1158  privateNeedsBarrier,
1159  /*mapIndices=*/nullptr);
1160  printBlockArgRegion(p, op, region, args);
1161 }
1162 
1164  OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1165  TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1166  ReductionModifierAttr reductionMod, ValueRange reductionVars,
1167  TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1168  ArrayAttr reductionSyms) {
1169  AllRegionPrintArgs args;
1170  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1171  privateNeedsBarrier,
1172  /*mapIndices=*/nullptr);
1173  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1174  reductionSyms, reductionMod);
1175  printBlockArgRegion(p, op, region, args);
1176 }
1177 
1179  Region &region,
1180  ValueRange taskReductionVars,
1181  TypeRange taskReductionTypes,
1182  DenseBoolArrayAttr taskReductionByref,
1183  ArrayAttr taskReductionSyms) {
1184  AllRegionPrintArgs args;
1185  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1186  taskReductionByref, taskReductionSyms);
1187  printBlockArgRegion(p, op, region, args);
1188 }
1189 
1191  Region &region,
1192  ValueRange useDeviceAddrVars,
1193  TypeRange useDeviceAddrTypes,
1194  ValueRange useDevicePtrVars,
1195  TypeRange useDevicePtrTypes) {
1196  AllRegionPrintArgs args;
1197  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1198  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1199  printBlockArgRegion(p, op, region, args);
1200 }
1201 
1202 /// Verifies Reduction Clause
1203 static LogicalResult
1204 verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1205  OperandRange reductionVars,
1206  std::optional<ArrayRef<bool>> reductionByref) {
1207  if (!reductionVars.empty()) {
1208  if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1209  return op->emitOpError()
1210  << "expected as many reduction symbol references "
1211  "as reduction variables";
1212  if (reductionByref && reductionByref->size() != reductionVars.size())
1213  return op->emitError() << "expected as many reduction variable by "
1214  "reference attributes as reduction variables";
1215  } else {
1216  if (reductionSyms)
1217  return op->emitOpError() << "unexpected reduction symbol references";
1218  return success();
1219  }
1220 
1221  // TODO: The followings should be done in
1222  // SymbolUserOpInterface::verifySymbolUses.
1223  DenseSet<Value> accumulators;
1224  for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1225  Value accum = std::get<0>(args);
1226 
1227  if (!accumulators.insert(accum).second)
1228  return op->emitOpError() << "accumulator variable used more than once";
1229 
1230  Type varType = accum.getType();
1231  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1232  auto decl =
1233  SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1234  if (!decl)
1235  return op->emitOpError() << "expected symbol reference " << symbolRef
1236  << " to point to a reduction declaration";
1237 
1238  if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1239  return op->emitOpError()
1240  << "expected accumulator (" << varType
1241  << ") to be the same type as reduction declaration ("
1242  << decl.getAccumulatorType() << ")";
1243  }
1244 
1245  return success();
1246 }
1247 
1248 //===----------------------------------------------------------------------===//
1249 // Parser, printer and verifier for Copyprivate
1250 //===----------------------------------------------------------------------===//
1251 
1252 /// copyprivate-entry-list ::= copyprivate-entry
1253 /// | copyprivate-entry-list `,` copyprivate-entry
1254 /// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1255 static ParseResult parseCopyprivate(
1256  OpAsmParser &parser,
1258  SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1260  if (failed(parser.parseCommaSeparatedList([&]() {
1261  if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1262  parser.parseArrow() ||
1263  parser.parseAttribute(symsVec.emplace_back()) ||
1264  parser.parseColonType(copyprivateTypes.emplace_back()))
1265  return failure();
1266  return success();
1267  })))
1268  return failure();
1269  SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1270  copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1271  return success();
1272 }
1273 
1274 /// Print Copyprivate clause
1276  OperandRange copyprivateVars,
1277  TypeRange copyprivateTypes,
1278  std::optional<ArrayAttr> copyprivateSyms) {
1279  if (!copyprivateSyms.has_value())
1280  return;
1281  llvm::interleaveComma(
1282  llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1283  [&](const auto &args) {
1284  p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1285  << std::get<2>(args);
1286  });
1287 }
1288 
1289 /// Verifies CopyPrivate Clause
1290 static LogicalResult
1292  std::optional<ArrayAttr> copyprivateSyms) {
1293  size_t copyprivateSymsSize =
1294  copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1295  if (copyprivateSymsSize != copyprivateVars.size())
1296  return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1297  << copyprivateVars.size()
1298  << ") and functions (= " << copyprivateSymsSize
1299  << "), both must be equal";
1300  if (!copyprivateSyms.has_value())
1301  return success();
1302 
1303  for (auto copyprivateVarAndSym :
1304  llvm::zip(copyprivateVars, *copyprivateSyms)) {
1305  auto symbolRef =
1306  llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1307  std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1308  funcOp;
1309  if (mlir::func::FuncOp mlirFuncOp =
1310  SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1311  symbolRef))
1312  funcOp = mlirFuncOp;
1313  else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1314  SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1315  op, symbolRef))
1316  funcOp = llvmFuncOp;
1317 
1318  auto getNumArguments = [&] {
1319  return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1320  };
1321 
1322  auto getArgumentType = [&](unsigned i) {
1323  return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1324  *funcOp);
1325  };
1326 
1327  if (!funcOp)
1328  return op->emitOpError() << "expected symbol reference " << symbolRef
1329  << " to point to a copy function";
1330 
1331  if (getNumArguments() != 2)
1332  return op->emitOpError()
1333  << "expected copy function " << symbolRef << " to have 2 operands";
1334 
1335  Type argTy = getArgumentType(0);
1336  if (argTy != getArgumentType(1))
1337  return op->emitOpError() << "expected copy function " << symbolRef
1338  << " arguments to have the same type";
1339 
1340  Type varType = std::get<0>(copyprivateVarAndSym).getType();
1341  if (argTy != varType)
1342  return op->emitOpError()
1343  << "expected copy function arguments' type (" << argTy
1344  << ") to be the same as copyprivate variable's type (" << varType
1345  << ")";
1346  }
1347 
1348  return success();
1349 }
1350 
1351 //===----------------------------------------------------------------------===//
1352 // Parser, printer and verifier for DependVarList
1353 //===----------------------------------------------------------------------===//
1354 
1355 /// depend-entry-list ::= depend-entry
1356 /// | depend-entry-list `,` depend-entry
1357 /// depend-entry ::= depend-kind `->` ssa-id `:` type
1358 static ParseResult
1361  SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
1363  if (failed(parser.parseCommaSeparatedList([&]() {
1364  StringRef keyword;
1365  if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1366  parser.parseOperand(dependVars.emplace_back()) ||
1367  parser.parseColonType(dependTypes.emplace_back()))
1368  return failure();
1369  if (std::optional<ClauseTaskDepend> keywordDepend =
1370  (symbolizeClauseTaskDepend(keyword)))
1371  kindsVec.emplace_back(
1372  ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1373  else
1374  return failure();
1375  return success();
1376  })))
1377  return failure();
1378  SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1379  dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1380  return success();
1381 }
1382 
1383 /// Print Depend clause
1385  OperandRange dependVars, TypeRange dependTypes,
1386  std::optional<ArrayAttr> dependKinds) {
1387 
1388  for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1389  if (i != 0)
1390  p << ", ";
1391  p << stringifyClauseTaskDepend(
1392  llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1393  .getValue())
1394  << " -> " << dependVars[i] << " : " << dependTypes[i];
1395  }
1396 }
1397 
1398 /// Verifies Depend clause
1399 static LogicalResult verifyDependVarList(Operation *op,
1400  std::optional<ArrayAttr> dependKinds,
1401  OperandRange dependVars) {
1402  if (!dependVars.empty()) {
1403  if (!dependKinds || dependKinds->size() != dependVars.size())
1404  return op->emitOpError() << "expected as many depend values"
1405  " as depend variables";
1406  } else {
1407  if (dependKinds && !dependKinds->empty())
1408  return op->emitOpError() << "unexpected depend values";
1409  return success();
1410  }
1411 
1412  return success();
1413 }
1414 
1415 //===----------------------------------------------------------------------===//
1416 // Parser, printer and verifier for Synchronization Hint (2.17.12)
1417 //===----------------------------------------------------------------------===//
1418 
1419 /// Parses a Synchronization Hint clause. The value of hint is an integer
1420 /// which is a combination of different hints from `omp_sync_hint_t`.
1421 ///
1422 /// hint-clause = `hint` `(` hint-value `)`
1423 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1424  IntegerAttr &hintAttr) {
1425  StringRef hintKeyword;
1426  int64_t hint = 0;
1427  if (succeeded(parser.parseOptionalKeyword("none"))) {
1428  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1429  return success();
1430  }
1431  auto parseKeyword = [&]() -> ParseResult {
1432  if (failed(parser.parseKeyword(&hintKeyword)))
1433  return failure();
1434  if (hintKeyword == "uncontended")
1435  hint |= 1;
1436  else if (hintKeyword == "contended")
1437  hint |= 2;
1438  else if (hintKeyword == "nonspeculative")
1439  hint |= 4;
1440  else if (hintKeyword == "speculative")
1441  hint |= 8;
1442  else
1443  return parser.emitError(parser.getCurrentLocation())
1444  << hintKeyword << " is not a valid hint";
1445  return success();
1446  };
1447  if (parser.parseCommaSeparatedList(parseKeyword))
1448  return failure();
1449  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1450  return success();
1451 }
1452 
1453 /// Prints a Synchronization Hint clause
1455  IntegerAttr hintAttr) {
1456  int64_t hint = hintAttr.getInt();
1457 
1458  if (hint == 0) {
1459  p << "none";
1460  return;
1461  }
1462 
1463  // Helper function to get n-th bit from the right end of `value`
1464  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1465 
1466  bool uncontended = bitn(hint, 0);
1467  bool contended = bitn(hint, 1);
1468  bool nonspeculative = bitn(hint, 2);
1469  bool speculative = bitn(hint, 3);
1470 
1471  SmallVector<StringRef> hints;
1472  if (uncontended)
1473  hints.push_back("uncontended");
1474  if (contended)
1475  hints.push_back("contended");
1476  if (nonspeculative)
1477  hints.push_back("nonspeculative");
1478  if (speculative)
1479  hints.push_back("speculative");
1480 
1481  llvm::interleaveComma(hints, p);
1482 }
1483 
1484 /// Verifies a synchronization hint clause
1485 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1486 
1487  // Helper function to get n-th bit from the right end of `value`
1488  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1489 
1490  bool uncontended = bitn(hint, 0);
1491  bool contended = bitn(hint, 1);
1492  bool nonspeculative = bitn(hint, 2);
1493  bool speculative = bitn(hint, 3);
1494 
1495  if (uncontended && contended)
1496  return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1497  "omp_sync_hint_contended cannot be combined";
1498  if (nonspeculative && speculative)
1499  return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1500  "omp_sync_hint_speculative cannot be combined.";
1501  return success();
1502 }
1503 
1504 //===----------------------------------------------------------------------===//
1505 // Parser, printer and verifier for Target
1506 //===----------------------------------------------------------------------===//
1507 
1508 // Helper function to get bitwise AND of `value` and 'flag'
1509 uint64_t mapTypeToBitFlag(uint64_t value,
1510  llvm::omp::OpenMPOffloadMappingFlags flag) {
1511  return value & llvm::to_underlying(flag);
1512 }
1513 
1514 /// Parses a map_entries map type from a string format back into its numeric
1515 /// value.
1516 ///
1517 /// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
1518 /// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
1519 static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
1520  llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1521  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1522 
1523  // This simply verifies the correct keyword is read in, the
1524  // keyword itself is stored inside of the operation
1525  auto parseTypeAndMod = [&]() -> ParseResult {
1526  StringRef mapTypeMod;
1527  if (parser.parseKeyword(&mapTypeMod))
1528  return failure();
1529 
1530  if (mapTypeMod == "always")
1531  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1532 
1533  if (mapTypeMod == "implicit")
1534  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1535 
1536  if (mapTypeMod == "ompx_hold")
1537  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
1538 
1539  if (mapTypeMod == "close")
1540  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1541 
1542  if (mapTypeMod == "present")
1543  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1544 
1545  if (mapTypeMod == "to")
1546  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1547 
1548  if (mapTypeMod == "from")
1549  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1550 
1551  if (mapTypeMod == "tofrom")
1552  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1553  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1554 
1555  if (mapTypeMod == "delete")
1556  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1557 
1558  if (mapTypeMod == "return_param")
1559  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
1560 
1561  return success();
1562  };
1563 
1564  if (parser.parseCommaSeparatedList(parseTypeAndMod))
1565  return failure();
1566 
1567  mapType = parser.getBuilder().getIntegerAttr(
1568  parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
1569  llvm::to_underlying(mapTypeBits));
1570 
1571  return success();
1572 }
1573 
1574 /// Prints a map_entries map type from its numeric value out into its string
1575 /// format.
1577  IntegerAttr mapType) {
1578  uint64_t mapTypeBits = mapType.getUInt();
1579 
1580  bool emitAllocRelease = true;
1582 
  // Modifiers (always, implicit, ompx_hold, close, present) are emitted at
  // the beginning of the string to aid readability.
1585  if (mapTypeToBitFlag(mapTypeBits,
1586  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1587  mapTypeStrs.push_back("always");
1588  if (mapTypeToBitFlag(mapTypeBits,
1589  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1590  mapTypeStrs.push_back("implicit");
1591  if (mapTypeToBitFlag(mapTypeBits,
1592  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD))
1593  mapTypeStrs.push_back("ompx_hold");
1594  if (mapTypeToBitFlag(mapTypeBits,
1595  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1596  mapTypeStrs.push_back("close");
1597  if (mapTypeToBitFlag(mapTypeBits,
1598  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1599  mapTypeStrs.push_back("present");
1600 
  // Special handling of to/from/tofrom/delete and release/alloc: release and
  // alloc correspond to the absence of the other flags, whereas tofrom
  // requires both the to and from flags to be set.
1604  bool to = mapTypeToBitFlag(mapTypeBits,
1605  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1606  bool from = mapTypeToBitFlag(
1607  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1608  if (to && from) {
1609  emitAllocRelease = false;
1610  mapTypeStrs.push_back("tofrom");
1611  } else if (from) {
1612  emitAllocRelease = false;
1613  mapTypeStrs.push_back("from");
1614  } else if (to) {
1615  emitAllocRelease = false;
1616  mapTypeStrs.push_back("to");
1617  }
1618  if (mapTypeToBitFlag(mapTypeBits,
1619  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1620  emitAllocRelease = false;
1621  mapTypeStrs.push_back("delete");
1622  }
1623  if (mapTypeToBitFlag(
1624  mapTypeBits,
1625  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
1626  emitAllocRelease = false;
1627  mapTypeStrs.push_back("return_param");
1628  }
1629  if (emitAllocRelease)
1630  mapTypeStrs.push_back("exit_release_or_enter_alloc");
1631 
1632  for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1633  p << mapTypeStrs[i];
1634  if (i + 1 < mapTypeStrs.size()) {
1635  p << ", ";
1636  }
1637  }
1638 }
1639 
1640 static ParseResult parseMembersIndex(OpAsmParser &parser,
1641  ArrayAttr &membersIdx) {
1642  SmallVector<Attribute> values, memberIdxs;
1643 
1644  auto parseIndices = [&]() -> ParseResult {
1645  int64_t value;
1646  if (parser.parseInteger(value))
1647  return failure();
1648  values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1649  APInt(64, value, /*isSigned=*/false)));
1650  return success();
1651  };
1652 
1653  do {
1654  if (failed(parser.parseLSquare()))
1655  return failure();
1656 
1657  if (parser.parseCommaSeparatedList(parseIndices))
1658  return failure();
1659 
1660  if (failed(parser.parseRSquare()))
1661  return failure();
1662 
1663  memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1664  values.clear();
1665  } while (succeeded(parser.parseOptionalComma()));
1666 
1667  if (!memberIdxs.empty())
1668  membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
1669 
1670  return success();
1671 }
1672 
1673 static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1674  ArrayAttr membersIdx) {
1675  if (!membersIdx)
1676  return;
1677 
1678  llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
1679  p << "[";
1680  auto memberIdx = cast<ArrayAttr>(v);
1681  llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
1682  p << cast<IntegerAttr>(v2).getInt();
1683  });
1684  p << "]";
1685  });
1686 }
1687 
1689  VariableCaptureKindAttr mapCaptureType) {
1690  std::string typeCapStr;
1691  llvm::raw_string_ostream typeCap(typeCapStr);
1692  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1693  typeCap << "ByRef";
1694  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1695  typeCap << "ByCopy";
1696  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1697  typeCap << "VLAType";
1698  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1699  typeCap << "This";
1700  p << typeCapStr;
1701 }
1702 
1703 static ParseResult parseCaptureType(OpAsmParser &parser,
1704  VariableCaptureKindAttr &mapCaptureType) {
1705  StringRef mapCaptureKey;
1706  if (parser.parseKeyword(&mapCaptureKey))
1707  return failure();
1708 
1709  if (mapCaptureKey == "This")
1710  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1711  parser.getContext(), mlir::omp::VariableCaptureKind::This);
1712  if (mapCaptureKey == "ByRef")
1713  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1714  parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1715  if (mapCaptureKey == "ByCopy")
1716  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1717  parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1718  if (mapCaptureKey == "VLAType")
1719  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1720  parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1721 
1722  return success();
1723 }
1724 
1725 static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
1728 
1729  for (auto mapOp : mapVars) {
1730  if (!mapOp.getDefiningOp())
1731  return emitError(op->getLoc(), "missing map operation");
1732 
1733  if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
1734  uint64_t mapTypeBits = mapInfoOp.getMapType();
1735 
1736  bool to = mapTypeToBitFlag(
1737  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1738  bool from = mapTypeToBitFlag(
1739  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1740  bool del = mapTypeToBitFlag(
1741  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1742 
1743  bool always = mapTypeToBitFlag(
1744  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1745  bool close = mapTypeToBitFlag(
1746  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1747  bool implicit = mapTypeToBitFlag(
1748  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1749 
1750  if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1751  return emitError(op->getLoc(),
1752  "to, from, tofrom and alloc map types are permitted");
1753 
1754  if (isa<TargetEnterDataOp>(op) && (from || del))
1755  return emitError(op->getLoc(), "to and alloc map types are permitted");
1756 
1757  if (isa<TargetExitDataOp>(op) && to)
1758  return emitError(op->getLoc(),
1759  "from, release and delete map types are permitted");
1760 
1761  if (isa<TargetUpdateOp>(op)) {
1762  if (del) {
1763  return emitError(op->getLoc(),
1764  "at least one of to or from map types must be "
1765  "specified, other map types are not permitted");
1766  }
1767 
1768  if (!to && !from) {
1769  return emitError(op->getLoc(),
1770  "at least one of to or from map types must be "
1771  "specified, other map types are not permitted");
1772  }
1773 
1774  auto updateVar = mapInfoOp.getVarPtr();
1775 
1776  if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1777  (from && updateToVars.contains(updateVar))) {
1778  return emitError(
1779  op->getLoc(),
1780  "either to or from map types can be specified, not both");
1781  }
1782 
1783  if (always || close || implicit) {
1784  return emitError(
1785  op->getLoc(),
1786  "present, mapper and iterator map type modifiers are permitted");
1787  }
1788 
1789  to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1790  }
1791  } else if (!isa<DeclareMapperInfoOp>(op)) {
1792  return emitError(op->getLoc(),
1793  "map argument is not a map entry operation");
1794  }
1795  }
1796 
1797  return success();
1798 }
1799 
1800 static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
1801  std::optional<DenseI64ArrayAttr> privateMapIndices =
1802  targetOp.getPrivateMapsAttr();
1803 
1804  // None of the private operands are mapped.
1805  if (!privateMapIndices.has_value() || !privateMapIndices.value())
1806  return success();
1807 
1808  OperandRange privateVars = targetOp.getPrivateVars();
1809 
1810  if (privateMapIndices.value().size() !=
1811  static_cast<int64_t>(privateVars.size()))
1812  return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
1813  "`private_maps` attribute mismatch");
1814 
1815  return success();
1816 }
1817 
1818 //===----------------------------------------------------------------------===//
1819 // MapInfoOp
1820 //===----------------------------------------------------------------------===//
1821 
1822 static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
1823  StringRef clauseName,
1824  OperandRange vars) {
1825  for (Value var : vars)
1826  if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
1827  return op->emitOpError()
1828  << "'" << clauseName
1829  << "' arguments must be defined by 'omp.map.info' ops";
1830  return success();
1831 }
1832 
1833 LogicalResult MapInfoOp::verify() {
1834  if (getMapperId() &&
1835  !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
1836  *this, getMapperIdAttr())) {
1837  return emitError("invalid mapper id");
1838  }
1839 
1840  if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
1841  return failure();
1842 
1843  return success();
1844 }
1845 
1846 //===----------------------------------------------------------------------===//
1847 // TargetDataOp
1848 //===----------------------------------------------------------------------===//
1849 
1850 void TargetDataOp::build(OpBuilder &builder, OperationState &state,
1851  const TargetDataOperands &clauses) {
1852  TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
1853  clauses.mapVars, clauses.useDeviceAddrVars,
1854  clauses.useDevicePtrVars);
1855 }
1856 
1857 LogicalResult TargetDataOp::verify() {
1858  if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
1859  getUseDeviceAddrVars().empty()) {
1860  return ::emitError(this->getLoc(),
1861  "At least one of map, use_device_ptr_vars, or "
1862  "use_device_addr_vars operand must be present");
1863  }
1864 
1865  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
1866  getUseDevicePtrVars())))
1867  return failure();
1868 
1869  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
1870  getUseDeviceAddrVars())))
1871  return failure();
1872 
1873  return verifyMapClause(*this, getMapVars());
1874 }
1875 
1876 //===----------------------------------------------------------------------===//
1877 // TargetEnterDataOp
1878 //===----------------------------------------------------------------------===//
1879 
1880 void TargetEnterDataOp::build(
1881  OpBuilder &builder, OperationState &state,
1882  const TargetEnterExitUpdateDataOperands &clauses) {
1883  MLIRContext *ctx = builder.getContext();
1884  TargetEnterDataOp::build(builder, state,
1885  makeArrayAttr(ctx, clauses.dependKinds),
1886  clauses.dependVars, clauses.device, clauses.ifExpr,
1887  clauses.mapVars, clauses.nowait);
1888 }
1889 
1890 LogicalResult TargetEnterDataOp::verify() {
1891  LogicalResult verifyDependVars =
1892  verifyDependVarList(*this, getDependKinds(), getDependVars());
1893  return failed(verifyDependVars) ? verifyDependVars
1894  : verifyMapClause(*this, getMapVars());
1895 }
1896 
1897 //===----------------------------------------------------------------------===//
1898 // TargetExitDataOp
1899 //===----------------------------------------------------------------------===//
1900 
1901 void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
1902  const TargetEnterExitUpdateDataOperands &clauses) {
1903  MLIRContext *ctx = builder.getContext();
1904  TargetExitDataOp::build(builder, state,
1905  makeArrayAttr(ctx, clauses.dependKinds),
1906  clauses.dependVars, clauses.device, clauses.ifExpr,
1907  clauses.mapVars, clauses.nowait);
1908 }
1909 
1910 LogicalResult TargetExitDataOp::verify() {
1911  LogicalResult verifyDependVars =
1912  verifyDependVarList(*this, getDependKinds(), getDependVars());
1913  return failed(verifyDependVars) ? verifyDependVars
1914  : verifyMapClause(*this, getMapVars());
1915 }
1916 
1917 //===----------------------------------------------------------------------===//
1918 // TargetUpdateOp
1919 //===----------------------------------------------------------------------===//
1920 
1921 void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
1922  const TargetEnterExitUpdateDataOperands &clauses) {
1923  MLIRContext *ctx = builder.getContext();
1924  TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
1925  clauses.dependVars, clauses.device, clauses.ifExpr,
1926  clauses.mapVars, clauses.nowait);
1927 }
1928 
1929 LogicalResult TargetUpdateOp::verify() {
1930  LogicalResult verifyDependVars =
1931  verifyDependVarList(*this, getDependKinds(), getDependVars());
1932  return failed(verifyDependVars) ? verifyDependVars
1933  : verifyMapClause(*this, getMapVars());
1934 }
1935 
1936 //===----------------------------------------------------------------------===//
1937 // TargetOp
1938 //===----------------------------------------------------------------------===//
1939 
1940 void TargetOp::build(OpBuilder &builder, OperationState &state,
1941  const TargetOperands &clauses) {
1942  MLIRContext *ctx = builder.getContext();
1943  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
1944  // inReductionByref, inReductionSyms.
1945  TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
1946  clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
1947  clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
1948  clauses.hostEvalVars, clauses.ifExpr,
1949  /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
1950  /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
1951  clauses.mapVars, clauses.nowait, clauses.privateVars,
1952  makeArrayAttr(ctx, clauses.privateSyms),
1953  clauses.privateNeedsBarrier, clauses.threadLimit,
1954  /*private_maps=*/nullptr);
1955 }
1956 
1957 LogicalResult TargetOp::verify() {
1958  if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
1959  return failure();
1960 
1961  if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
1962  getHasDeviceAddrVars())))
1963  return failure();
1964 
1965  if (failed(verifyMapClause(*this, getMapVars())))
1966  return failure();
1967 
1968  return verifyPrivateVarsMapping(*this);
1969 }
1970 
1971 LogicalResult TargetOp::verifyRegions() {
1972  auto teamsOps = getOps<TeamsOp>();
1973  if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
1974  return emitError("target containing multiple 'omp.teams' nested ops");
1975 
1976  // Check that host_eval values are only used in legal ways.
1977  Operation *capturedOp = getInnermostCapturedOmpOp();
1978  TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
1979  for (Value hostEvalArg :
1980  cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
1981  for (Operation *user : hostEvalArg.getUsers()) {
1982  if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
1983  if (llvm::is_contained({teamsOp.getNumTeamsLower(),
1984  teamsOp.getNumTeamsUpper(),
1985  teamsOp.getThreadLimit()},
1986  hostEvalArg))
1987  continue;
1988 
1989  return emitOpError() << "host_eval argument only legal as 'num_teams' "
1990  "and 'thread_limit' in 'omp.teams'";
1991  }
1992  if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1993  if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
1994  parallelOp->isAncestor(capturedOp) &&
1995  hostEvalArg == parallelOp.getNumThreads())
1996  continue;
1997 
1998  return emitOpError()
1999  << "host_eval argument only legal as 'num_threads' in "
2000  "'omp.parallel' when representing target SPMD";
2001  }
2002  if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2003  if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2004  loopNestOp.getOperation() == capturedOp &&
2005  (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2006  llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2007  llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2008  continue;
2009 
2010  return emitOpError() << "host_eval argument only legal as loop bounds "
2011  "and steps in 'omp.loop_nest' when trip count "
2012  "must be evaluated in the host";
2013  }
2014 
2015  return emitOpError() << "host_eval argument illegal use in '"
2016  << user->getName() << "' operation";
2017  }
2018  }
2019  return success();
2020 }
2021 
/// Walks the operations nested in `rootOp` in pre-order. Returns the last
/// (innermost) op that qualified as "captured", or nullptr if none did.
///
/// Only ops of the same dialect as `rootOp` that have at least one region are
/// candidates. A candidate is captured when both of these hold:
/// - if `checkSingleMandatoryExec` is set, its block cannot reach itself
///   through its successors, and the block dominates every reachable exit
///   block (no successors) of its parent region;
/// - every other op in its parent region satisfies `siblingAllowedFn`.
///
/// The walk only descends into captured ops. It stops entirely on the first
/// failed check, or right after capturing an omp.loop_nest.
///
/// \param rootOp must be non-null (asserted); it is never itself captured.
static Operation *
findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
                  llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
  assert(rootOp && "expected valid operation");

  // The "OpenMP dialect" here is simply the dialect of the root op.
  Dialect *ompDialect = rootOp->getDialect();
  Operation *capturedOp = nullptr;
  DominanceInfo domInfo;

  // Process in pre-order to check operations from outermost to innermost,
  // ensuring we only enter the region of an operation if it meets the criteria
  // for being captured. We stop the exploration of nested operations as soon as
  // we process a region holding no operations to be captured.
  rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (op == rootOp)
      return WalkResult::advance();

    // Ignore operations of other dialects or omp operations with no regions,
    // because these will only be checked if they are siblings of an omp
    // operation that can potentially be captured. `skip` also prevents the
    // walk from entering their nested regions.
    bool isOmpDialect = op->getDialect() == ompDialect;
    bool hasRegions = op->getNumRegions() > 0;
    if (!isOmpDialect || !hasRegions)
      return WalkResult::skip();

    // This operation cannot be captured if it can be executed more than once
    // (i.e. its block's successors can reach it) or if it's not guaranteed to
    // be executed before all exits of the region (i.e. it doesn't dominate all
    // blocks with no successors reachable from the entry block).
    if (checkSingleMandatoryExec) {
      Region *parentRegion = op->getParentRegion();
      Block *parentBlock = op->getBlock();

      for (Block *successor : parentBlock->getSuccessors())
        if (successor->isReachable(parentBlock))
          return WalkResult::interrupt();

      for (Block &block : *parentRegion)
        if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
            !domInfo.dominates(parentBlock, &block))
          return WalkResult::interrupt();
    }

    // Don't capture this op if it has a not-allowed sibling, and stop recursing
    // into nested operations. Siblings are all other ops in every block of the
    // parent region.
    for (Operation &sibling : op->getParentRegion()->getOps())
      if (&sibling != op && !siblingAllowedFn(&sibling))
        return WalkResult::interrupt();

    // Don't continue capturing nested operations if we reach an omp.loop_nest.
    // Otherwise, process the contents of this operation.
    capturedOp = op;
    return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
                                     : WalkResult::advance();
  });

  return capturedOp;
}
2080 
2081 Operation *TargetOp::getInnermostCapturedOmpOp() {
2082  auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
2083 
2084  // Only allow OpenMP terminators and non-OpenMP ops that have known memory
2085  // effects, but don't include a memory write effect.
2086  return findCapturedOmpOp(
2087  *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2088  if (!sibling)
2089  return false;
2090 
2091  if (ompDialect == sibling->getDialect())
2092  return sibling->hasTrait<OpTrait::IsTerminator>();
2093 
2094  if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2096  effects;
2097  memOp.getEffects(effects);
2098  return !llvm::any_of(
2099  effects, [&](MemoryEffects::EffectInstance &effect) {
2100  return isa<MemoryEffects::Write>(effect.getEffect()) &&
2101  isa<SideEffects::AutomaticAllocationScopeResource>(
2102  effect.getResource());
2103  });
2104  }
2105  return true;
2106  });
2107 }
2108 
/// Classifies the kernel execution mode of the omp.target that contains
/// `capturedOp`, based on the loop wrappers around it.
///
/// \param capturedOp either null, or the result of calling
///        getInnermostCapturedOmpOp() on its parent omp.target (asserted).
/// \returns `generic` unless `capturedOp` is an omp.loop_nest whose wrappers
///          match one of these patterns (an innermost omp.simd is ignored):
///  - target > teams > parallel > distribute > wsloop: `spmd | trip_count`;
///  - target > teams > loop: `spmd | trip_count`;
///  - target > teams > distribute: `generic | trip_count`, plus `spmd` when
///    the op captured inside the loop nest sits under an omp.parallel that is
///    directly nested in the loop nest;
///  - target > parallel > wsloop: `spmd`.
TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
  // A non-null captured op is only valid if it resides inside of a TargetOp
  // and is the result of calling getInnermostCapturedOmpOp() on it.
  TargetOp targetOp =
      capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
  assert((!capturedOp ||
          (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
         "unexpected captured op");

  // If it's not capturing a loop, it's a default target region.
  if (!isa_and_present<LoopNestOp>(capturedOp))
    return TargetRegionFlags::generic;

  // Get the innermost non-simd loop wrapper. gatherWrappers() is expected to
  // list wrappers innermost first, which the checks below rely on.
  SmallVector<LoopWrapperInterface> loopWrappers;
  cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
  assert(!loopWrappers.empty());

  LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
  if (isa<SimdOp>(innermostWrapper))
    innermostWrapper = std::next(innermostWrapper);

  // Count the remaining (non-simd) wrappers; only 1 or 2 are recognized.
  auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
  if (numWrappers != 1 && numWrappers != 2)
    return TargetRegionFlags::generic;

  // Detect target-teams-distribute-parallel-wsloop[-simd].
  if (numWrappers == 2) {
    if (!isa<WsloopOp>(innermostWrapper))
      return TargetRegionFlags::generic;

    innermostWrapper = std::next(innermostWrapper);
    if (!isa<DistributeOp>(innermostWrapper))
      return TargetRegionFlags::generic;

    Operation *parallelOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return TargetRegionFlags::generic;

    Operation *teamsOp = parallelOp->getParentOp();
    if (!isa_and_present<TeamsOp>(teamsOp))
      return TargetRegionFlags::generic;

    if (teamsOp->getParentOp() == targetOp.getOperation())
      return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
  }
  // Detect target-teams-distribute[-simd] and target-teams-loop.
  else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
    Operation *teamsOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<TeamsOp>(teamsOp))
      return TargetRegionFlags::generic;

    if (teamsOp->getParentOp() != targetOp.getOperation())
      return TargetRegionFlags::generic;

    if (isa<LoopOp>(innermostWrapper))
      return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;

    // Find single immediately nested captured omp.parallel and add spmd flag
    // (generic-spmd case).
    //
    // TODO: This shouldn't have to be done here, as it is too easy to break.
    // The openmp-opt pass should be updated to be able to promote kernels like
    // this from "Generic" to "Generic-SPMD". However, the use of the
    // `kmpc_distribute_static_loop` family of functions produced by the
    // OMPIRBuilder for these kernels prevents that from working.
    //
    // Inside the loop nest, siblings may be any non-OpenMP op or an OpenMP
    // terminator, and the single-execution check is not applied.
    Dialect *ompDialect = targetOp->getDialect();
    Operation *nestedCapture = findCapturedOmpOp(
        capturedOp, /*checkSingleMandatoryExec=*/false,
        [&](Operation *sibling) {
          return sibling && (ompDialect != sibling->getDialect() ||
                             sibling->hasTrait<OpTrait::IsTerminator>());
        });

    TargetRegionFlags result =
        TargetRegionFlags::generic | TargetRegionFlags::trip_count;

    if (!nestedCapture)
      return result;

    // Climb to the ancestor of the nested capture that sits directly inside
    // the loop nest; only that op decides whether spmd is added.
    while (nestedCapture->getParentOp() != capturedOp)
      nestedCapture = nestedCapture->getParentOp();

    return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
                                          : result;
  }
  // Detect target-parallel-wsloop[-simd].
  else if (isa<WsloopOp>(innermostWrapper)) {
    Operation *parallelOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return TargetRegionFlags::generic;

    if (parallelOp->getParentOp() == targetOp.getOperation())
      return TargetRegionFlags::spmd;
  }

  // Any pattern that did not fully match falls back to a generic kernel.
  return TargetRegionFlags::generic;
}
2207 
2208 //===----------------------------------------------------------------------===//
2209 // ParallelOp
2210 //===----------------------------------------------------------------------===//
2211 
2212 void ParallelOp::build(OpBuilder &builder, OperationState &state,
2213  ArrayRef<NamedAttribute> attributes) {
2214  ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
2215  /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2216  /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
2217  /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2218  /*proc_bind_kind=*/nullptr,
2219  /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2220  /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2221  state.addAttributes(attributes);
2222 }
2223 
2224 void ParallelOp::build(OpBuilder &builder, OperationState &state,
2225  const ParallelOperands &clauses) {
2226  MLIRContext *ctx = builder.getContext();
2227  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2228  clauses.ifExpr, clauses.numThreads, clauses.privateVars,
2229  makeArrayAttr(ctx, clauses.privateSyms),
2230  clauses.privateNeedsBarrier, clauses.procBindKind,
2231  clauses.reductionMod, clauses.reductionVars,
2232  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2233  makeArrayAttr(ctx, clauses.reductionSyms));
2234 }
2235 
2236 template <typename OpType>
2237 static LogicalResult verifyPrivateVarList(OpType &op) {
2238  auto privateVars = op.getPrivateVars();
2239  auto privateSyms = op.getPrivateSymsAttr();
2240 
2241  if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
2242  return success();
2243 
2244  auto numPrivateVars = privateVars.size();
2245  auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
2246 
2247  if (numPrivateVars != numPrivateSyms)
2248  return op.emitError() << "inconsistent number of private variables and "
2249  "privatizer op symbols, private vars: "
2250  << numPrivateVars
2251  << " vs. privatizer op symbols: " << numPrivateSyms;
2252 
2253  for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2254  Type varType = std::get<0>(privateVarInfo).getType();
2255  SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2256  PrivateClauseOp privatizerOp =
2257  SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
2258 
2259  if (privatizerOp == nullptr)
2260  return op.emitError() << "failed to lookup privatizer op with symbol: '"
2261  << privateSym << "'";
2262 
2263  Type privatizerType = privatizerOp.getArgType();
2264 
2265  if (privatizerType && (varType != privatizerType))
2266  return op.emitError()
2267  << "type mismatch between a "
2268  << (privatizerOp.getDataSharingType() ==
2269  DataSharingClauseType::Private
2270  ? "private"
2271  : "firstprivate")
2272  << " variable and its privatizer op, var type: " << varType
2273  << " vs. privatizer op type: " << privatizerType;
2274  }
2275 
2276  return success();
2277 }
2278 
2279 LogicalResult ParallelOp::verify() {
2280  if (getAllocateVars().size() != getAllocatorVars().size())
2281  return emitError(
2282  "expected equal sizes for allocate and allocator variables");
2283 
2284  if (failed(verifyPrivateVarList(*this)))
2285  return failure();
2286 
2287  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2288  getReductionByref());
2289 }
2290 
2291 LogicalResult ParallelOp::verifyRegions() {
2292  auto distChildOps = getOps<DistributeOp>();
2293  int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2294  if (numDistChildOps > 1)
2295  return emitError()
2296  << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2297 
2298  if (numDistChildOps == 1) {
2299  if (!isComposite())
2300  return emitError()
2301  << "'omp.composite' attribute missing from composite operation";
2302 
2303  auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2304  Operation &distributeOp = **distChildOps.begin();
2305  for (Operation &childOp : getOps()) {
2306  if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2307  continue;
2308 
2309  if (!childOp.hasTrait<OpTrait::IsTerminator>())
2310  return emitError() << "unexpected OpenMP operation inside of composite "
2311  "'omp.parallel': "
2312  << childOp.getName();
2313  }
2314  } else if (isComposite()) {
2315  return emitError()
2316  << "'omp.composite' attribute present in non-composite operation";
2317  }
2318  return success();
2319 }
2320 
2321 //===----------------------------------------------------------------------===//
2322 // TeamsOp
2323 //===----------------------------------------------------------------------===//
2324 
2326  while ((op = op->getParentOp()))
2327  if (isa<OpenMPDialect>(op->getDialect()))
2328  return false;
2329  return true;
2330 }
2331 
2332 void TeamsOp::build(OpBuilder &builder, OperationState &state,
2333  const TeamsOperands &clauses) {
2334  MLIRContext *ctx = builder.getContext();
2335  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2336  TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2337  clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2338  /*private_vars=*/{}, /*private_syms=*/nullptr,
2339  /*private_needs_barrier=*/nullptr, clauses.reductionMod,
2340  clauses.reductionVars,
2341  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2342  makeArrayAttr(ctx, clauses.reductionSyms),
2343  clauses.threadLimit);
2344 }
2345 
2346 LogicalResult TeamsOp::verify() {
2347  // Check parent region
2348  // TODO If nested inside of a target region, also check that it does not
2349  // contain any statements, declarations or directives other than this
2350  // omp.teams construct. The issue is how to support the initialization of
2351  // this operation's own arguments (allow SSA values across omp.target?).
2352  Operation *op = getOperation();
2353  if (!isa<TargetOp>(op->getParentOp()) &&
2355  return emitError("expected to be nested inside of omp.target or not nested "
2356  "in any OpenMP dialect operations");
2357 
2358  // Check for num_teams clause restrictions
2359  if (auto numTeamsLowerBound = getNumTeamsLower()) {
2360  auto numTeamsUpperBound = getNumTeamsUpper();
2361  if (!numTeamsUpperBound)
2362  return emitError("expected num_teams upper bound to be defined if the "
2363  "lower bound is defined");
2364  if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2365  return emitError(
2366  "expected num_teams upper bound and lower bound to be the same type");
2367  }
2368 
2369  // Check for allocate clause restrictions
2370  if (getAllocateVars().size() != getAllocatorVars().size())
2371  return emitError(
2372  "expected equal sizes for allocate and allocator variables");
2373 
2374  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2375  getReductionByref());
2376 }
2377 
2378 //===----------------------------------------------------------------------===//
2379 // SectionOp
2380 //===----------------------------------------------------------------------===//
2381 
2382 OperandRange SectionOp::getPrivateVars() {
2383  return getParentOp().getPrivateVars();
2384 }
2385 
2386 OperandRange SectionOp::getReductionVars() {
2387  return getParentOp().getReductionVars();
2388 }
2389 
2390 //===----------------------------------------------------------------------===//
2391 // SectionsOp
2392 //===----------------------------------------------------------------------===//
2393 
2394 void SectionsOp::build(OpBuilder &builder, OperationState &state,
2395  const SectionsOperands &clauses) {
2396  MLIRContext *ctx = builder.getContext();
2397  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2398  SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2399  clauses.nowait, /*private_vars=*/{},
2400  /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2401  clauses.reductionMod, clauses.reductionVars,
2402  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2403  makeArrayAttr(ctx, clauses.reductionSyms));
2404 }
2405 
2406 LogicalResult SectionsOp::verify() {
2407  if (getAllocateVars().size() != getAllocatorVars().size())
2408  return emitError(
2409  "expected equal sizes for allocate and allocator variables");
2410 
2411  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2412  getReductionByref());
2413 }
2414 
2415 LogicalResult SectionsOp::verifyRegions() {
2416  for (auto &inst : *getRegion().begin()) {
2417  if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2418  return emitOpError()
2419  << "expected omp.section op or terminator op inside region";
2420  }
2421  }
2422 
2423  return success();
2424 }
2425 
2426 //===----------------------------------------------------------------------===//
2427 // SingleOp
2428 //===----------------------------------------------------------------------===//
2429 
2430 void SingleOp::build(OpBuilder &builder, OperationState &state,
2431  const SingleOperands &clauses) {
2432  MLIRContext *ctx = builder.getContext();
2433  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2434  SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2435  clauses.copyprivateVars,
2436  makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2437  /*private_vars=*/{}, /*private_syms=*/nullptr,
2438  /*private_needs_barrier=*/nullptr);
2439 }
2440 
2441 LogicalResult SingleOp::verify() {
2442  // Check for allocate clause restrictions
2443  if (getAllocateVars().size() != getAllocatorVars().size())
2444  return emitError(
2445  "expected equal sizes for allocate and allocator variables");
2446 
2447  return verifyCopyprivateVarList(*this, getCopyprivateVars(),
2448  getCopyprivateSyms());
2449 }
2450 
2451 //===----------------------------------------------------------------------===//
2452 // WorkshareOp
2453 //===----------------------------------------------------------------------===//
2454 
2455 void WorkshareOp::build(OpBuilder &builder, OperationState &state,
2456  const WorkshareOperands &clauses) {
2457  WorkshareOp::build(builder, state, clauses.nowait);
2458 }
2459 
2460 //===----------------------------------------------------------------------===//
2461 // WorkshareLoopWrapperOp
2462 //===----------------------------------------------------------------------===//
2463 
2464 LogicalResult WorkshareLoopWrapperOp::verify() {
2465  if (!(*this)->getParentOfType<WorkshareOp>())
2466  return emitOpError() << "must be nested in an omp.workshare";
2467  return success();
2468 }
2469 
2470 LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2471  if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2472  getNestedWrapper())
2473  return emitOpError() << "expected to be a standalone loop wrapper";
2474 
2475  return success();
2476 }
2477 
2478 //===----------------------------------------------------------------------===//
2479 // LoopWrapperInterface
2480 //===----------------------------------------------------------------------===//
2481 
2482 LogicalResult LoopWrapperInterface::verifyImpl() {
2483  Operation *op = this->getOperation();
2484  if (!op->hasTrait<OpTrait::NoTerminator>() ||
2486  return emitOpError() << "loop wrapper must also have the `NoTerminator` "
2487  "and `SingleBlock` traits";
2488 
2489  if (op->getNumRegions() != 1)
2490  return emitOpError() << "loop wrapper does not contain exactly one region";
2491 
2492  Region &region = op->getRegion(0);
2493  if (range_size(region.getOps()) != 1)
2494  return emitOpError()
2495  << "loop wrapper does not contain exactly one nested op";
2496 
2497  Operation &firstOp = *region.op_begin();
2498  if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2499  return emitOpError() << "nested in loop wrapper is not another loop "
2500  "wrapper or `omp.loop_nest`";
2501 
2502  return success();
2503 }
2504 
2505 //===----------------------------------------------------------------------===//
2506 // LoopOp
2507 //===----------------------------------------------------------------------===//
2508 
2509 void LoopOp::build(OpBuilder &builder, OperationState &state,
2510  const LoopOperands &clauses) {
2511  MLIRContext *ctx = builder.getContext();
2512 
2513  LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2514  makeArrayAttr(ctx, clauses.privateSyms),
2515  clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2516  clauses.reductionMod, clauses.reductionVars,
2517  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2518  makeArrayAttr(ctx, clauses.reductionSyms));
2519 }
2520 
2521 LogicalResult LoopOp::verify() {
2522  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2523  getReductionByref());
2524 }
2525 
2526 LogicalResult LoopOp::verifyRegions() {
2527  if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2528  getNestedWrapper())
2529  return emitOpError() << "expected to be a standalone loop wrapper";
2530 
2531  return success();
2532 }
2533 
2534 //===----------------------------------------------------------------------===//
2535 // WsloopOp
2536 //===----------------------------------------------------------------------===//
2537 
2538 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2539  ArrayRef<NamedAttribute> attributes) {
2540  build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2541  /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
2542  /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
2543  /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
2544  /*private_needs_barrier=*/false,
2545  /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
2546  /*reduction_byref=*/nullptr,
2547  /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
2548  /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
2549  /*schedule_simd=*/false);
2550  state.addAttributes(attributes);
2551 }
2552 
2553 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2554  const WsloopOperands &clauses) {
2555  MLIRContext *ctx = builder.getContext();
2556  // TODO: Store clauses in op: allocateVars, allocatorVars
2557  WsloopOp::build(
2558  builder, state,
2559  /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
2560  clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
2561  clauses.ordered, clauses.privateVars,
2562  makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2563  clauses.reductionMod, clauses.reductionVars,
2564  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2565  makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
2566  clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
2567 }
2568 
2569 LogicalResult WsloopOp::verify() {
2570  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2571  getReductionByref());
2572 }
2573 
2574 LogicalResult WsloopOp::verifyRegions() {
2575  bool isCompositeChildLeaf =
2576  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2577 
2578  if (LoopWrapperInterface nested = getNestedWrapper()) {
2579  if (!isComposite())
2580  return emitError()
2581  << "'omp.composite' attribute missing from composite wrapper";
2582 
2583  // Check for the allowed leaf constructs that may appear in a composite
2584  // construct directly after DO/FOR.
2585  if (!isa<SimdOp>(nested))
2586  return emitError() << "only supported nested wrapper is 'omp.simd'";
2587 
2588  } else if (isComposite() && !isCompositeChildLeaf) {
2589  return emitError()
2590  << "'omp.composite' attribute present in non-composite wrapper";
2591  } else if (!isComposite() && isCompositeChildLeaf) {
2592  return emitError()
2593  << "'omp.composite' attribute missing from composite wrapper";
2594  }
2595 
2596  return success();
2597 }
2598 
2599 //===----------------------------------------------------------------------===//
2600 // Simd construct [2.9.3.1]
2601 //===----------------------------------------------------------------------===//
2602 
2603 void SimdOp::build(OpBuilder &builder, OperationState &state,
2604  const SimdOperands &clauses) {
2605  MLIRContext *ctx = builder.getContext();
2606  // TODO Store clauses in op: linearVars, linearStepVars
2607  SimdOp::build(builder, state, clauses.alignedVars,
2608  makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
2609  /*linear_vars=*/{}, /*linear_step_vars=*/{},
2610  clauses.nontemporalVars, clauses.order, clauses.orderMod,
2611  clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
2612  clauses.privateNeedsBarrier, clauses.reductionMod,
2613  clauses.reductionVars,
2614  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2615  makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
2616  clauses.simdlen);
2617 }
2618 
2619 LogicalResult SimdOp::verify() {
2620  if (getSimdlen().has_value() && getSafelen().has_value() &&
2621  getSimdlen().value() > getSafelen().value())
2622  return emitOpError()
2623  << "simdlen clause and safelen clause are both present, but the "
2624  "simdlen value is not less than or equal to safelen value";
2625 
2626  if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
2627  return failure();
2628 
2629  if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
2630  return failure();
2631 
2632  bool isCompositeChildLeaf =
2633  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2634 
2635  if (!isComposite() && isCompositeChildLeaf)
2636  return emitError()
2637  << "'omp.composite' attribute missing from composite wrapper";
2638 
2639  if (isComposite() && !isCompositeChildLeaf)
2640  return emitError()
2641  << "'omp.composite' attribute present in non-composite wrapper";
2642 
2643  // Firstprivate is not allowed for SIMD in the standard. Check that none of
2644  // the private decls are for firstprivate.
2645  std::optional<ArrayAttr> privateSyms = getPrivateSyms();
2646  if (privateSyms) {
2647  for (const Attribute &sym : *privateSyms) {
2648  auto symRef = cast<SymbolRefAttr>(sym);
2649  omp::PrivateClauseOp privatizer =
2650  SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
2651  getOperation(), symRef);
2652  if (!privatizer)
2653  return emitError() << "Cannot find privatizer '" << symRef << "'";
2654  if (privatizer.getDataSharingType() ==
2655  DataSharingClauseType::FirstPrivate)
2656  return emitError() << "FIRSTPRIVATE cannot be used with SIMD";
2657  }
2658  }
2659 
2660  return success();
2661 }
2662 
2663 LogicalResult SimdOp::verifyRegions() {
2664  if (getNestedWrapper())
2665  return emitOpError() << "must wrap an 'omp.loop_nest' directly";
2666 
2667  return success();
2668 }
2669 
2670 //===----------------------------------------------------------------------===//
2671 // Distribute construct [2.9.4.1]
2672 //===----------------------------------------------------------------------===//
2673 
2674 void DistributeOp::build(OpBuilder &builder, OperationState &state,
2675  const DistributeOperands &clauses) {
2676  DistributeOp::build(builder, state, clauses.allocateVars,
2677  clauses.allocatorVars, clauses.distScheduleStatic,
2678  clauses.distScheduleChunkSize, clauses.order,
2679  clauses.orderMod, clauses.privateVars,
2680  makeArrayAttr(builder.getContext(), clauses.privateSyms),
2681  clauses.privateNeedsBarrier);
2682 }
2683 
2684 LogicalResult DistributeOp::verify() {
2685  if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2686  return emitOpError() << "chunk size set without "
2687  "dist_schedule_static being present";
2688 
2689  if (getAllocateVars().size() != getAllocatorVars().size())
2690  return emitError(
2691  "expected equal sizes for allocate and allocator variables");
2692 
2693  return success();
2694 }
2695 
2696 LogicalResult DistributeOp::verifyRegions() {
2697  if (LoopWrapperInterface nested = getNestedWrapper()) {
2698  if (!isComposite())
2699  return emitError()
2700  << "'omp.composite' attribute missing from composite wrapper";
2701  // Check for the allowed leaf constructs that may appear in a composite
2702  // construct directly after DISTRIBUTE.
2703  if (isa<WsloopOp>(nested)) {
2704  Operation *parentOp = (*this)->getParentOp();
2705  if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
2706  !cast<ComposableOpInterface>(parentOp).isComposite()) {
2707  return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
2708  "when a composite 'omp.parallel' is the direct "
2709  "parent";
2710  }
2711  } else if (!isa<SimdOp>(nested))
2712  return emitError() << "only supported nested wrappers are 'omp.simd' and "
2713  "'omp.wsloop'";
2714  } else if (isComposite()) {
2715  return emitError()
2716  << "'omp.composite' attribute present in non-composite wrapper";
2717  }
2718 
2719  return success();
2720 }
2721 
2722 //===----------------------------------------------------------------------===//
2723 // DeclareMapperOp / DeclareMapperInfoOp
2724 //===----------------------------------------------------------------------===//
2725 
2726 LogicalResult DeclareMapperInfoOp::verify() {
2727  return verifyMapClause(*this, getMapVars());
2728 }
2729 
2730 LogicalResult DeclareMapperOp::verifyRegions() {
2731  if (!llvm::isa_and_present<DeclareMapperInfoOp>(
2732  getRegion().getBlocks().front().getTerminator()))
2733  return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
2734 
2735  return success();
2736 }
2737 
2738 //===----------------------------------------------------------------------===//
2739 // DeclareReductionOp
2740 //===----------------------------------------------------------------------===//
2741 
2742 LogicalResult DeclareReductionOp::verifyRegions() {
2743  if (!getAllocRegion().empty()) {
2744  for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
2745  if (yieldOp.getResults().size() != 1 ||
2746  yieldOp.getResults().getTypes()[0] != getType())
2747  return emitOpError() << "expects alloc region to yield a value "
2748  "of the reduction type";
2749  }
2750  }
2751 
2752  if (getInitializerRegion().empty())
2753  return emitOpError() << "expects non-empty initializer region";
2754  Block &initializerEntryBlock = getInitializerRegion().front();
2755 
2756  if (initializerEntryBlock.getNumArguments() == 1) {
2757  if (!getAllocRegion().empty())
2758  return emitOpError() << "expects two arguments to the initializer region "
2759  "when an allocation region is used";
2760  } else if (initializerEntryBlock.getNumArguments() == 2) {
2761  if (getAllocRegion().empty())
2762  return emitOpError() << "expects one argument to the initializer region "
2763  "when no allocation region is used";
2764  } else {
2765  return emitOpError()
2766  << "expects one or two arguments to the initializer region";
2767  }
2768 
2769  for (mlir::Value arg : initializerEntryBlock.getArguments())
2770  if (arg.getType() != getType())
2771  return emitOpError() << "expects initializer region argument to match "
2772  "the reduction type";
2773 
2774  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
2775  if (yieldOp.getResults().size() != 1 ||
2776  yieldOp.getResults().getTypes()[0] != getType())
2777  return emitOpError() << "expects initializer region to yield a value "
2778  "of the reduction type";
2779  }
2780 
2781  if (getReductionRegion().empty())
2782  return emitOpError() << "expects non-empty reduction region";
2783  Block &reductionEntryBlock = getReductionRegion().front();
2784  if (reductionEntryBlock.getNumArguments() != 2 ||
2785  reductionEntryBlock.getArgumentTypes()[0] !=
2786  reductionEntryBlock.getArgumentTypes()[1] ||
2787  reductionEntryBlock.getArgumentTypes()[0] != getType())
2788  return emitOpError() << "expects reduction region with two arguments of "
2789  "the reduction type";
2790  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
2791  if (yieldOp.getResults().size() != 1 ||
2792  yieldOp.getResults().getTypes()[0] != getType())
2793  return emitOpError() << "expects reduction region to yield a value "
2794  "of the reduction type";
2795  }
2796 
2797  if (!getAtomicReductionRegion().empty()) {
2798  Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
2799  if (atomicReductionEntryBlock.getNumArguments() != 2 ||
2800  atomicReductionEntryBlock.getArgumentTypes()[0] !=
2801  atomicReductionEntryBlock.getArgumentTypes()[1])
2802  return emitOpError() << "expects atomic reduction region with two "
2803  "arguments of the same type";
2804  auto ptrType = llvm::dyn_cast<PointerLikeType>(
2805  atomicReductionEntryBlock.getArgumentTypes()[0]);
2806  if (!ptrType ||
2807  (ptrType.getElementType() && ptrType.getElementType() != getType()))
2808  return emitOpError() << "expects atomic reduction region arguments to "
2809  "be accumulators containing the reduction type";
2810  }
2811 
2812  if (getCleanupRegion().empty())
2813  return success();
2814  Block &cleanupEntryBlock = getCleanupRegion().front();
2815  if (cleanupEntryBlock.getNumArguments() != 1 ||
2816  cleanupEntryBlock.getArgument(0).getType() != getType())
2817  return emitOpError() << "expects cleanup region with one argument "
2818  "of the reduction type";
2819 
2820  return success();
2821 }
2822 
2823 //===----------------------------------------------------------------------===//
2824 // TaskOp
2825 //===----------------------------------------------------------------------===//
2826 
2827 void TaskOp::build(OpBuilder &builder, OperationState &state,
2828  const TaskOperands &clauses) {
2829  MLIRContext *ctx = builder.getContext();
2830  TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2831  makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2832  clauses.final, clauses.ifExpr, clauses.inReductionVars,
2833  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2834  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2835  clauses.priority, /*private_vars=*/clauses.privateVars,
2836  /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
2837  clauses.privateNeedsBarrier, clauses.untied,
2838  clauses.eventHandle);
2839 }
2840 
2841 LogicalResult TaskOp::verify() {
2842  LogicalResult verifyDependVars =
2843  verifyDependVarList(*this, getDependKinds(), getDependVars());
2844  return failed(verifyDependVars)
2845  ? verifyDependVars
2846  : verifyReductionVarList(*this, getInReductionSyms(),
2847  getInReductionVars(),
2848  getInReductionByref());
2849 }
2850 
2851 //===----------------------------------------------------------------------===//
2852 // TaskgroupOp
2853 //===----------------------------------------------------------------------===//
2854 
2855 void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
2856  const TaskgroupOperands &clauses) {
2857  MLIRContext *ctx = builder.getContext();
2858  TaskgroupOp::build(builder, state, clauses.allocateVars,
2859  clauses.allocatorVars, clauses.taskReductionVars,
2860  makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
2861  makeArrayAttr(ctx, clauses.taskReductionSyms));
2862 }
2863 
2864 LogicalResult TaskgroupOp::verify() {
2865  return verifyReductionVarList(*this, getTaskReductionSyms(),
2866  getTaskReductionVars(),
2867  getTaskReductionByref());
2868 }
2869 
2870 //===----------------------------------------------------------------------===//
2871 // TaskloopOp
2872 //===----------------------------------------------------------------------===//
2873 
2874 void TaskloopOp::build(OpBuilder &builder, OperationState &state,
2875  const TaskloopOperands &clauses) {
2876  MLIRContext *ctx = builder.getContext();
2877  TaskloopOp::build(
2878  builder, state, clauses.allocateVars, clauses.allocatorVars,
2879  clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
2880  clauses.inReductionVars,
2881  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2882  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2883  clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
2884  /*private_vars=*/clauses.privateVars,
2885  /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
2886  clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
2887  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2888  makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
2889 }
2890 
2891 LogicalResult TaskloopOp::verify() {
2892  if (getAllocateVars().size() != getAllocatorVars().size())
2893  return emitError(
2894  "expected equal sizes for allocate and allocator variables");
2895  if (failed(verifyReductionVarList(*this, getReductionSyms(),
2896  getReductionVars(), getReductionByref())) ||
2897  failed(verifyReductionVarList(*this, getInReductionSyms(),
2898  getInReductionVars(),
2899  getInReductionByref())))
2900  return failure();
2901 
2902  if (!getReductionVars().empty() && getNogroup())
2903  return emitError("if a reduction clause is present on the taskloop "
2904  "directive, the nogroup clause must not be specified");
2905  for (auto var : getReductionVars()) {
2906  if (llvm::is_contained(getInReductionVars(), var))
2907  return emitError("the same list item cannot appear in both a reduction "
2908  "and an in_reduction clause");
2909  }
2910 
2911  if (getGrainsize() && getNumTasks()) {
2912  return emitError(
2913  "the grainsize clause and num_tasks clause are mutually exclusive and "
2914  "may not appear on the same taskloop directive");
2915  }
2916 
2917  return success();
2918 }
2919 
2920 LogicalResult TaskloopOp::verifyRegions() {
2921  if (LoopWrapperInterface nested = getNestedWrapper()) {
2922  if (!isComposite())
2923  return emitError()
2924  << "'omp.composite' attribute missing from composite wrapper";
2925 
2926  // Check for the allowed leaf constructs that may appear in a composite
2927  // construct directly after TASKLOOP.
2928  if (!isa<SimdOp>(nested))
2929  return emitError() << "only supported nested wrapper is 'omp.simd'";
2930  } else if (isComposite()) {
2931  return emitError()
2932  << "'omp.composite' attribute present in non-composite wrapper";
2933  }
2934 
2935  return success();
2936 }
2937 
2938 //===----------------------------------------------------------------------===//
2939 // LoopNestOp
2940 //===----------------------------------------------------------------------===//
2941 
2942 ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
2943  // Parse an opening `(` followed by induction variables followed by `)`
2946  Type loopVarType;
2947  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
2948  parser.parseColonType(loopVarType) ||
2949  // Parse loop bounds.
2950  parser.parseEqual() ||
2951  parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
2952  parser.parseKeyword("to") ||
2953  parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
2954  return failure();
2955 
2956  for (auto &iv : ivs)
2957  iv.type = loopVarType;
2958 
2959  // Parse "inclusive" flag.
2960  if (succeeded(parser.parseOptionalKeyword("inclusive")))
2961  result.addAttribute("loop_inclusive",
2962  UnitAttr::get(parser.getBuilder().getContext()));
2963 
2964  // Parse step values.
2966  if (parser.parseKeyword("step") ||
2967  parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
2968  return failure();
2969 
2970  // Parse the body.
2971  Region *region = result.addRegion();
2972  if (parser.parseRegion(*region, ivs))
2973  return failure();
2974 
2975  // Resolve operands.
2976  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
2977  parser.resolveOperands(ubs, loopVarType, result.operands) ||
2978  parser.resolveOperands(steps, loopVarType, result.operands))
2979  return failure();
2980 
2981  // Parse the optional attribute list.
2982  return parser.parseOptionalAttrDict(result.attributes);
2983 }
2984 
2986  Region &region = getRegion();
2987  auto args = region.getArguments();
2988  p << " (" << args << ") : " << args[0].getType() << " = ("
2989  << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
2990  if (getLoopInclusive())
2991  p << "inclusive ";
2992  p << "step (" << getLoopSteps() << ") ";
2993  p.printRegion(region, /*printEntryBlockArgs=*/false);
2994 }
2995 
2996 void LoopNestOp::build(OpBuilder &builder, OperationState &state,
2997  const LoopNestOperands &clauses) {
2998  LoopNestOp::build(builder, state, clauses.loopLowerBounds,
2999  clauses.loopUpperBounds, clauses.loopSteps,
3000  clauses.loopInclusive);
3001 }
3002 
3003 LogicalResult LoopNestOp::verify() {
3004  if (getLoopLowerBounds().empty())
3005  return emitOpError() << "must represent at least one loop";
3006 
3007  if (getLoopLowerBounds().size() != getIVs().size())
3008  return emitOpError() << "number of range arguments and IVs do not match";
3009 
3010  for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3011  if (lb.getType() != iv.getType())
3012  return emitOpError()
3013  << "range argument type does not match corresponding IV type";
3014  }
3015 
3016  if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3017  return emitOpError() << "expects parent op to be a loop wrapper";
3018 
3019  return success();
3020 }
3021 
3022 void LoopNestOp::gatherWrappers(
3024  Operation *parent = (*this)->getParentOp();
3025  while (auto wrapper =
3026  llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3027  wrappers.push_back(wrapper);
3028  parent = parent->getParentOp();
3029  }
3030 }
3031 
3032 //===----------------------------------------------------------------------===//
3033 // OpenMP canonical loop handling
3034 //===----------------------------------------------------------------------===//
3035 
3036 std::tuple<NewCliOp, OpOperand *, OpOperand *>
3038 
3039  // Defining a CLI for a generated loop is optional; if there is none then
3040  // there is no followup-tranformation
3041  if (!cli)
3042  return {{}, nullptr, nullptr};
3043 
3044  assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3045  "Unexpected type of cli");
3046 
3047  NewCliOp create = cast<NewCliOp>(cli.getDefiningOp());
3048  OpOperand *gen = nullptr;
3049  OpOperand *cons = nullptr;
3050  for (OpOperand &use : cli.getUses()) {
3051  auto op = cast<LoopTransformationInterface>(use.getOwner());
3052 
3053  unsigned opnum = use.getOperandNumber();
3054  if (op.isGeneratee(opnum)) {
3055  assert(!gen && "Each CLI may have at most one def");
3056  gen = &use;
3057  } else if (op.isApplyee(opnum)) {
3058  assert(!cons && "Each CLI may have at most one consumer");
3059  cons = &use;
3060  } else {
3061  llvm_unreachable("Unexpected operand for a CLI");
3062  }
3063  }
3064 
3065  return {create, gen, cons};
3066 }
3067 
3068 void NewCliOp::build(::mlir::OpBuilder &odsBuilder,
3069  ::mlir::OperationState &odsState) {
3070  odsState.addTypes(CanonicalLoopInfoType::get(odsBuilder.getContext()));
3071 }
3072 
3073 void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3074  Value result = getResult();
3075  auto [newCli, gen, cons] = decodeCli(result);
3076 
3077  // Derive the CLI variable name from its generator:
3078  // * "canonloop" for omp.canonical_loop
3079  // * custom name for loop transformation generatees
3080  // * "cli" as fallback if no generator
3081  // * "_r<idx>" suffix for nested loops, where <idx> is the sequential order
3082  // at that level
3083  // * "_s<idx>" suffix for operations with multiple regions, where <idx> is
3084  // the index of that region
3085  std::string cliName{"cli"};
3086  if (gen) {
3087  cliName =
3088  TypeSwitch<Operation *, std::string>(gen->getOwner())
3089  .Case([&](CanonicalLoopOp op) {
3090  // Find the canonical loop nesting: For each ancestor add a
3091  // "+_r<idx>" suffix (in reverse order)
3092  SmallVector<std::string> components;
3093  Operation *o = op.getOperation();
3094  while (o) {
3096  break;
3097 
3098  Region *r = o->getParentRegion();
3099  if (!r)
3100  break;
3101 
3102  auto getSequentialIndex = [](Region *r, Operation *o) {
3103  llvm::ReversePostOrderTraversal<Block *> traversal(
3104  &r->getBlocks().front());
3105  size_t idx = 0;
3106  for (Block *b : traversal) {
3107  for (Operation &op : *b) {
3108  if (&op == o)
3109  return idx;
3110  // Only consider operations that are containers as
3111  // possible children
3112  if (!op.getRegions().empty())
3113  idx += 1;
3114  }
3115  }
3116  llvm_unreachable("Operation not part of the region");
3117  };
3118  size_t sequentialIdx = getSequentialIndex(r, o);
3119  components.push_back(("s" + Twine(sequentialIdx)).str());
3120 
3121  Operation *parent = r->getParentOp();
3122  if (!parent)
3123  break;
3124 
3125  // If the operation has more than one region, also count in
3126  // which of the regions
3127  if (parent->getRegions().size() > 1) {
3128  auto getRegionIndex = [](Operation *o, Region *r) {
3129  for (auto [idx, region] :
3130  llvm::enumerate(o->getRegions())) {
3131  if (&region == r)
3132  return idx;
3133  }
3134  llvm_unreachable("Region not child its parent operation");
3135  };
3136  size_t regionIdx = getRegionIndex(parent, r);
3137  components.push_back(("r" + Twine(regionIdx)).str());
3138  }
3139 
3140  // next parent
3141  o = parent;
3142  }
3143 
3144  SmallString<64> Name("canonloop");
3145  for (std::string s : reverse(components)) {
3146  Name += '_';
3147  Name += s;
3148  }
3149 
3150  return Name;
3151  })
3152  .Case([&](UnrollHeuristicOp op) -> std::string {
3153  llvm_unreachable("heuristic unrolling does not generate a loop");
3154  })
3155  .Default([&](Operation *op) {
3156  assert(false && "TODO: Custom name for this operation");
3157  return "transformed";
3158  });
3159  }
3160 
3161  setNameFn(result, cliName);
3162 }
3163 
3164 LogicalResult NewCliOp::verify() {
3165  Value cli = getResult();
3166 
3167  assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3168  "Unexpected type of cli");
3169 
3170  // Check that the CLI is used in at most generator and one consumer
3171  OpOperand *gen = nullptr;
3172  OpOperand *cons = nullptr;
3173  for (mlir::OpOperand &use : cli.getUses()) {
3174  auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3175 
3176  unsigned opnum = use.getOperandNumber();
3177  if (op.isGeneratee(opnum)) {
3178  if (gen) {
3179  InFlightDiagnostic error =
3180  emitOpError("CLI must have at most one generator");
3181  error.attachNote(gen->getOwner()->getLoc())
3182  .append("first generator here:");
3183  error.attachNote(use.getOwner()->getLoc())
3184  .append("second generator here:");
3185  return error;
3186  }
3187 
3188  gen = &use;
3189  } else if (op.isApplyee(opnum)) {
3190  if (cons) {
3191  InFlightDiagnostic error =
3192  emitOpError("CLI must have at most one consumer");
3193  error.attachNote(cons->getOwner()->getLoc())
3194  .append("first consumer here:")
3195  .appendOp(*cons->getOwner(),
3196  OpPrintingFlags().printGenericOpForm());
3197  error.attachNote(use.getOwner()->getLoc())
3198  .append("second consumer here:")
3199  .appendOp(*use.getOwner(), OpPrintingFlags().printGenericOpForm());
3200  return error;
3201  }
3202 
3203  cons = &use;
3204  } else {
3205  llvm_unreachable("Unexpected operand for a CLI");
3206  }
3207  }
3208 
3209  // If the CLI is source of a transformation, it must have a generator
3210  if (cons && !gen) {
3211  InFlightDiagnostic error = emitOpError("CLI has no generator");
3212  error.attachNote(cons->getOwner()->getLoc())
3213  .append("see consumer here: ")
3214  .appendOp(*cons->getOwner(), OpPrintingFlags().printGenericOpForm());
3215  return error;
3216  }
3217 
3218  return success();
3219 }
3220 
3221 void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3222  Value tripCount) {
3223  odsState.addOperands(tripCount);
3224  odsState.addOperands(Value());
3225  (void)odsState.addRegion();
3226 }
3227 
3228 void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3229  Value tripCount, ::mlir::Value cli) {
3230  odsState.addOperands(tripCount);
3231  odsState.addOperands(cli);
3232  (void)odsState.addRegion();
3233 }
3234 
3235 void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
3236  setNameFn(&getRegion().front(), "body_entry");
3237 }
3238 
3239 void CanonicalLoopOp::getAsmBlockArgumentNames(Region &region,
3240  OpAsmSetValueNameFn setNameFn) {
3241  setNameFn(region.getArgument(0), "iv");
3242 }
3243 
3245  if (getCli())
3246  p << '(' << getCli() << ')';
3247  p << ' ' << getInductionVar() << " : " << getInductionVar().getType()
3248  << " in range(" << getTripCount() << ") ";
3249 
3250  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3251  /*printBlockTerminators=*/true);
3252 
3253  p.printOptionalAttrDict((*this)->getAttrs());
3254 }
3255 
3256 mlir::ParseResult CanonicalLoopOp::parse(::mlir::OpAsmParser &parser,
3257  ::mlir::OperationState &result) {
3258  CanonicalLoopInfoType cliType =
3260 
3261  // Parse (optional) omp.cli identifier
3263  SmallVector<mlir::Value, 1> cliOperand;
3264  if (!parser.parseOptionalLParen()) {
3265  if (parser.parseOperand(cli) ||
3266  parser.resolveOperand(cli, cliType, cliOperand) || parser.parseRParen())
3267  return failure();
3268  }
3269 
3270  // We derive the type of tripCount from inductionVariable. MLIR requires the
3271  // type of tripCount to be known when calling resolveOperand so we have parse
3272  // the type before processing the inductionVariable.
3273  OpAsmParser::Argument inductionVariable;
3275  if (parser.parseArgument(inductionVariable, /*allowType*/ true) ||
3276  parser.parseKeyword("in") || parser.parseKeyword("range") ||
3277  parser.parseLParen() || parser.parseOperand(tripcount) ||
3278  parser.parseRParen() ||
3279  parser.resolveOperand(tripcount, inductionVariable.type, result.operands))
3280  return failure();
3281 
3282  // Parse the loop body.
3283  Region *region = result.addRegion();
3284  if (parser.parseRegion(*region, {inductionVariable}))
3285  return failure();
3286 
3287  // We parsed the cli operand forst, but because it is optional, it must be
3288  // last in the operand list.
3289  result.operands.append(cliOperand);
3290 
3291  // Parse the optional attribute list.
3292  if (parser.parseOptionalAttrDict(result.attributes))
3293  return failure();
3294 
3295  return mlir::success();
3296 }
3297 
3298 LogicalResult CanonicalLoopOp::verify() {
3299  // The region's entry must accept the induction variable
3300  // It can also be empty if just created
3301  if (!getRegion().empty()) {
3302  Region &region = getRegion();
3303  if (region.getNumArguments() != 1)
3304  return emitOpError(
3305  "Canonical loop region must have exactly one argument");
3306 
3307  if (getInductionVar().getType() != getTripCount().getType())
3308  return emitOpError(
3309  "Region argument must be the same type as the trip count");
3310  }
3311 
3312  return success();
3313 }
3314 
3315 Value CanonicalLoopOp::getInductionVar() { return getRegion().getArgument(0); }
3316 
3317 std::pair<unsigned, unsigned>
3318 CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3319  // No applyees
3320  return {0, 0};
3321 }
3322 
3323 std::pair<unsigned, unsigned>
3324 CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3325  return getODSOperandIndexAndLength(odsIndex_cli);
3326 }
3327 
3328 //===----------------------------------------------------------------------===//
3329 // UnrollHeuristicOp
3330 //===----------------------------------------------------------------------===//
3331 
3332 void UnrollHeuristicOp::build(::mlir::OpBuilder &odsBuilder,
3333  ::mlir::OperationState &odsState,
3334  ::mlir::Value cli) {
3335  odsState.addOperands(cli);
3336 }
3337 
3339  p << '(' << getApplyee() << ')';
3340 
3341  p.printOptionalAttrDict((*this)->getAttrs());
3342 }
3343 
3344 mlir::ParseResult UnrollHeuristicOp::parse(::mlir::OpAsmParser &parser,
3345  ::mlir::OperationState &result) {
3346  auto cliType = CanonicalLoopInfoType::get(parser.getContext());
3347 
3348  if (parser.parseLParen())
3349  return failure();
3350 
3352  if (parser.parseOperand(applyee) ||
3353  parser.resolveOperand(applyee, cliType, result.operands))
3354  return failure();
3355 
3356  if (parser.parseRParen())
3357  return failure();
3358 
3359  // Optional output loop (full unrolling has none)
3360  if (!parser.parseOptionalArrow()) {
3361  if (parser.parseLParen() || parser.parseRParen())
3362  return failure();
3363  }
3364 
3365  // Parse the optional attribute list.
3366  if (parser.parseOptionalAttrDict(result.attributes))
3367  return failure();
3368 
3369  return mlir::success();
3370 }
3371 
3372 std::pair<unsigned, unsigned>
3373 UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3374  return getODSOperandIndexAndLength(odsIndex_applyee);
3375 }
3376 
3377 std::pair<unsigned, unsigned>
3378 UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3379  return {0, 0};
3380 }
3381 
3382 //===----------------------------------------------------------------------===//
3383 // Critical construct (2.17.1)
3384 //===----------------------------------------------------------------------===//
3385 
3386 void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
3387  const CriticalDeclareOperands &clauses) {
3388  CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
3389 }
3390 
3391 LogicalResult CriticalDeclareOp::verify() {
3392  return verifySynchronizationHint(*this, getHint());
3393 }
3394 
3395 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3396  if (getNameAttr()) {
3397  SymbolRefAttr symbolRef = getNameAttr();
3398  auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
3399  *this, symbolRef);
3400  if (!decl) {
3401  return emitOpError() << "expected symbol reference " << symbolRef
3402  << " to point to a critical declaration";
3403  }
3404  }
3405 
3406  return success();
3407 }
3408 
3409 //===----------------------------------------------------------------------===//
3410 // Ordered construct
3411 //===----------------------------------------------------------------------===//
3412 
3413 static LogicalResult verifyOrderedParent(Operation &op) {
3414  bool hasRegion = op.getNumRegions() > 0;
3415  auto loopOp = op.getParentOfType<LoopNestOp>();
3416  if (!loopOp) {
3417  if (hasRegion)
3418  return success();
3419 
3420  // TODO: Consider if this needs to be the case only for the standalone
3421  // variant of the ordered construct.
3422  return op.emitOpError() << "must be nested inside of a loop";
3423  }
3424 
3425  Operation *wrapper = loopOp->getParentOp();
3426  if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
3427  IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
3428  if (!orderedAttr)
3429  return op.emitOpError() << "the enclosing worksharing-loop region must "
3430  "have an ordered clause";
3431 
3432  if (hasRegion && orderedAttr.getInt() != 0)
3433  return op.emitOpError() << "the enclosing loop's ordered clause must not "
3434  "have a parameter present";
3435 
3436  if (!hasRegion && orderedAttr.getInt() == 0)
3437  return op.emitOpError() << "the enclosing loop's ordered clause must "
3438  "have a parameter present";
3439  } else if (!isa<SimdOp>(wrapper)) {
3440  return op.emitOpError() << "must be nested inside of a worksharing, simd "
3441  "or worksharing simd loop";
3442  }
3443  return success();
3444 }
3445 
3446 void OrderedOp::build(OpBuilder &builder, OperationState &state,
3447  const OrderedOperands &clauses) {
3448  OrderedOp::build(builder, state, clauses.doacrossDependType,
3449  clauses.doacrossNumLoops, clauses.doacrossDependVars);
3450 }
3451 
3452 LogicalResult OrderedOp::verify() {
3453  if (failed(verifyOrderedParent(**this)))
3454  return failure();
3455 
3456  auto wrapper = (*this)->getParentOfType<WsloopOp>();
3457  if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
3458  return emitOpError() << "number of variables in depend clause does not "
3459  << "match number of iteration variables in the "
3460  << "doacross loop";
3461 
3462  return success();
3463 }
3464 
3465 void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
3466  const OrderedRegionOperands &clauses) {
3467  OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
3468 }
3469 
3470 LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
3471 
3472 //===----------------------------------------------------------------------===//
3473 // TaskwaitOp
3474 //===----------------------------------------------------------------------===//
3475 
3476 void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
3477  const TaskwaitOperands &clauses) {
3478  // TODO Store clauses in op: dependKinds, dependVars, nowait.
3479  TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
3480  /*depend_vars=*/{}, /*nowait=*/nullptr);
3481 }
3482 
3483 //===----------------------------------------------------------------------===//
3484 // Verifier for AtomicReadOp
3485 //===----------------------------------------------------------------------===//
3486 
3487 LogicalResult AtomicReadOp::verify() {
3488  if (verifyCommon().failed())
3489  return mlir::failure();
3490 
3491  if (auto mo = getMemoryOrder()) {
3492  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3493  *mo == ClauseMemoryOrderKind::Release) {
3494  return emitError(
3495  "memory-order must not be acq_rel or release for atomic reads");
3496  }
3497  }
3498  return verifySynchronizationHint(*this, getHint());
3499 }
3500 
3501 //===----------------------------------------------------------------------===//
3502 // Verifier for AtomicWriteOp
3503 //===----------------------------------------------------------------------===//
3504 
3505 LogicalResult AtomicWriteOp::verify() {
3506  if (verifyCommon().failed())
3507  return mlir::failure();
3508 
3509  if (auto mo = getMemoryOrder()) {
3510  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3511  *mo == ClauseMemoryOrderKind::Acquire) {
3512  return emitError(
3513  "memory-order must not be acq_rel or acquire for atomic writes");
3514  }
3515  }
3516  return verifySynchronizationHint(*this, getHint());
3517 }
3518 
3519 //===----------------------------------------------------------------------===//
3520 // Verifier for AtomicUpdateOp
3521 //===----------------------------------------------------------------------===//
3522 
3523 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3524  PatternRewriter &rewriter) {
3525  if (op.isNoOp()) {
3526  rewriter.eraseOp(op);
3527  return success();
3528  }
3529  if (Value writeVal = op.getWriteOpVal()) {
3530  rewriter.replaceOpWithNewOp<AtomicWriteOp>(
3531  op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
3532  return success();
3533  }
3534  return failure();
3535 }
3536 
3537 LogicalResult AtomicUpdateOp::verify() {
3538  if (verifyCommon().failed())
3539  return mlir::failure();
3540 
3541  if (auto mo = getMemoryOrder()) {
3542  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3543  *mo == ClauseMemoryOrderKind::Acquire) {
3544  return emitError(
3545  "memory-order must not be acq_rel or acquire for atomic updates");
3546  }
3547  }
3548 
3549  return verifySynchronizationHint(*this, getHint());
3550 }
3551 
3552 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3553 
3554 //===----------------------------------------------------------------------===//
3555 // Verifier for AtomicCaptureOp
3556 //===----------------------------------------------------------------------===//
3557 
3558 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3559  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3560  return op;
3561  return dyn_cast<AtomicReadOp>(getSecondOp());
3562 }
3563 
3564 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3565  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3566  return op;
3567  return dyn_cast<AtomicWriteOp>(getSecondOp());
3568 }
3569 
3570 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3571  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3572  return op;
3573  return dyn_cast<AtomicUpdateOp>(getSecondOp());
3574 }
3575 
3576 LogicalResult AtomicCaptureOp::verify() {
3577  return verifySynchronizationHint(*this, getHint());
3578 }
3579 
3580 LogicalResult AtomicCaptureOp::verifyRegions() {
3581  if (verifyRegionsCommon().failed())
3582  return mlir::failure();
3583 
3584  if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
3585  return emitOpError(
3586  "operations inside capture region must not have hint clause");
3587 
3588  if (getFirstOp()->getAttr("memory_order") ||
3589  getSecondOp()->getAttr("memory_order"))
3590  return emitOpError(
3591  "operations inside capture region must not have memory_order clause");
3592  return success();
3593 }
3594 
3595 //===----------------------------------------------------------------------===//
3596 // CancelOp
3597 //===----------------------------------------------------------------------===//
3598 
3599 void CancelOp::build(OpBuilder &builder, OperationState &state,
3600  const CancelOperands &clauses) {
3601  CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
3602 }
3603 
3605  Operation *parent = thisOp->getParentOp();
3606  while (parent) {
3607  if (parent->getDialect() == thisOp->getDialect())
3608  return parent;
3609  parent = parent->getParentOp();
3610  }
3611  return nullptr;
3612 }
3613 
3614 LogicalResult CancelOp::verify() {
3615  ClauseCancellationConstructType cct = getCancelDirective();
3616  // The next OpenMP operation in the chain of parents
3617  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
3618  if (!structuralParent)
3619  return emitOpError() << "Orphaned cancel construct";
3620 
3621  if ((cct == ClauseCancellationConstructType::Parallel) &&
3622  !mlir::isa<ParallelOp>(structuralParent)) {
3623  return emitOpError() << "cancel parallel must appear "
3624  << "inside a parallel region";
3625  }
3626  if (cct == ClauseCancellationConstructType::Loop) {
3627  // structural parent will be omp.loop_nest, directly nested inside
3628  // omp.wsloop
3629  auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
3630 
3631  if (!wsloopOp) {
3632  return emitOpError()
3633  << "cancel loop must appear inside a worksharing-loop region";
3634  }
3635  if (wsloopOp.getNowaitAttr()) {
3636  return emitError() << "A worksharing construct that is canceled "
3637  << "must not have a nowait clause";
3638  }
3639  if (wsloopOp.getOrderedAttr()) {
3640  return emitError() << "A worksharing construct that is canceled "
3641  << "must not have an ordered clause";
3642  }
3643 
3644  } else if (cct == ClauseCancellationConstructType::Sections) {
3645  // structural parent will be an omp.section, directly nested inside
3646  // omp.sections
3647  auto sectionsOp =
3648  mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
3649  if (!sectionsOp) {
3650  return emitOpError() << "cancel sections must appear "
3651  << "inside a sections region";
3652  }
3653  if (sectionsOp.getNowait()) {
3654  return emitError() << "A sections construct that is canceled "
3655  << "must not have a nowait clause";
3656  }
3657  }
3658  if ((cct == ClauseCancellationConstructType::Taskgroup) &&
3659  (!mlir::isa<omp::TaskOp>(structuralParent) &&
3660  !mlir::isa<omp::TaskloopOp>(structuralParent->getParentOp()))) {
3661  return emitOpError() << "cancel taskgroup must appear "
3662  << "inside a task region";
3663  }
3664  return success();
3665 }
3666 
3667 //===----------------------------------------------------------------------===//
3668 // CancellationPointOp
3669 //===----------------------------------------------------------------------===//
3670 
3671 void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
3672  const CancellationPointOperands &clauses) {
3673  CancellationPointOp::build(builder, state, clauses.cancelDirective);
3674 }
3675 
3676 LogicalResult CancellationPointOp::verify() {
3677  ClauseCancellationConstructType cct = getCancelDirective();
3678  // The next OpenMP operation in the chain of parents
3679  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
3680  if (!structuralParent)
3681  return emitOpError() << "Orphaned cancellation point";
3682 
3683  if ((cct == ClauseCancellationConstructType::Parallel) &&
3684  !mlir::isa<ParallelOp>(structuralParent)) {
3685  return emitOpError() << "cancellation point parallel must appear "
3686  << "inside a parallel region";
3687  }
3688  // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
3689  // find the wsloop
3690  if ((cct == ClauseCancellationConstructType::Loop) &&
3691  !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
3692  return emitOpError() << "cancellation point loop must appear "
3693  << "inside a worksharing-loop region";
3694  }
3695  if ((cct == ClauseCancellationConstructType::Sections) &&
3696  !mlir::isa<omp::SectionOp>(structuralParent)) {
3697  return emitOpError() << "cancellation point sections must appear "
3698  << "inside a sections region";
3699  }
3700  if ((cct == ClauseCancellationConstructType::Taskgroup) &&
3701  !mlir::isa<omp::TaskOp>(structuralParent)) {
3702  return emitOpError() << "cancellation point taskgroup must appear "
3703  << "inside a task region";
3704  }
3705  return success();
3706 }
3707 
3708 //===----------------------------------------------------------------------===//
3709 // MapBoundsOp
3710 //===----------------------------------------------------------------------===//
3711 
3712 LogicalResult MapBoundsOp::verify() {
3713  auto extent = getExtent();
3714  auto upperbound = getUpperBound();
3715  if (!extent && !upperbound)
3716  return emitError("expected extent or upperbound.");
3717  return success();
3718 }
3719 
3720 void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3721  TypeRange /*result_types*/, StringAttr symName,
3722  TypeAttr type) {
3723  PrivateClauseOp::build(
3724  odsBuilder, odsState, symName, type,
3726  DataSharingClauseType::Private));
3727 }
3728 
3729 LogicalResult PrivateClauseOp::verifyRegions() {
3730  Type argType = getArgType();
3731  auto verifyTerminator = [&](Operation *terminator,
3732  bool yieldsValue) -> LogicalResult {
3733  if (!terminator->getBlock()->getSuccessors().empty())
3734  return success();
3735 
3736  if (!llvm::isa<YieldOp>(terminator))
3737  return mlir::emitError(terminator->getLoc())
3738  << "expected exit block terminator to be an `omp.yield` op.";
3739 
3740  YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
3741  TypeRange yieldedTypes = yieldOp.getResults().getTypes();
3742 
3743  if (!yieldsValue) {
3744  if (yieldedTypes.empty())
3745  return success();
3746 
3747  return mlir::emitError(terminator->getLoc())
3748  << "Did not expect any values to be yielded.";
3749  }
3750 
3751  if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
3752  return success();
3753 
3754  auto error = mlir::emitError(yieldOp.getLoc())
3755  << "Invalid yielded value. Expected type: " << argType
3756  << ", got: ";
3757 
3758  if (yieldedTypes.empty())
3759  error << "None";
3760  else
3761  error << yieldedTypes;
3762 
3763  return error;
3764  };
3765 
3766  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
3767  StringRef regionName,
3768  bool yieldsValue) -> LogicalResult {
3769  assert(!region.empty());
3770 
3771  if (region.getNumArguments() != expectedNumArgs)
3772  return mlir::emitError(region.getLoc())
3773  << "`" << regionName << "`: "
3774  << "expected " << expectedNumArgs
3775  << " region arguments, got: " << region.getNumArguments();
3776 
3777  for (Block &block : region) {
3778  // MLIR will verify the absence of the terminator for us.
3779  if (!block.mightHaveTerminator())
3780  continue;
3781 
3782  if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
3783  return failure();
3784  }
3785 
3786  return success();
3787  };
3788 
3789  // Ensure all of the region arguments have the same type
3790  for (Region *region : getRegions())
3791  for (Type ty : region->getArgumentTypes())
3792  if (ty != argType)
3793  return emitError() << "Region argument type mismatch: got " << ty
3794  << " expected " << argType << ".";
3795 
3796  mlir::Region &initRegion = getInitRegion();
3797  if (!initRegion.empty() &&
3798  failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
3799  /*yieldsValue=*/true)))
3800  return failure();
3801 
3802  DataSharingClauseType dsType = getDataSharingType();
3803 
3804  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
3805  return emitError("`private` clauses do not require a `copy` region.");
3806 
3807  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
3808  return emitError(
3809  "`firstprivate` clauses require at least a `copy` region.");
3810 
3811  if (dsType == DataSharingClauseType::FirstPrivate &&
3812  failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
3813  /*yieldsValue=*/true)))
3814  return failure();
3815 
3816  if (!getDeallocRegion().empty() &&
3817  failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
3818  /*yieldsValue=*/false)))
3819  return failure();
3820 
3821  return success();
3822 }
3823 
3824 //===----------------------------------------------------------------------===//
3825 // Spec 5.2: Masked construct (10.5)
3826 //===----------------------------------------------------------------------===//
3827 
3828 void MaskedOp::build(OpBuilder &builder, OperationState &state,
3829  const MaskedOperands &clauses) {
3830  MaskedOp::build(builder, state, clauses.filteredThreadId);
3831 }
3832 
3833 //===----------------------------------------------------------------------===//
3834 // Spec 5.2: Scan construct (5.6)
3835 //===----------------------------------------------------------------------===//
3836 
3837 void ScanOp::build(OpBuilder &builder, OperationState &state,
3838  const ScanOperands &clauses) {
3839  ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
3840 }
3841 
3842 LogicalResult ScanOp::verify() {
3843  if (hasExclusiveVars() == hasInclusiveVars())
3844  return emitError(
3845  "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
3846  if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
3847  if (parentWsLoopOp.getReductionModAttr() &&
3848  parentWsLoopOp.getReductionModAttr().getValue() ==
3849  ReductionModifier::inscan)
3850  return success();
3851  }
3852  if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
3853  if (parentSimdOp.getReductionModAttr() &&
3854  parentSimdOp.getReductionModAttr().getValue() ==
3855  ReductionModifier::inscan)
3856  return success();
3857  }
3858  return emitError("SCAN directive needs to be enclosed within a parent "
3859  "worksharing loop construct or SIMD construct with INSCAN "
3860  "reduction modifier");
3861 }
3862 
3863 /// Verifies align clause in allocate directive
3864 
3865 LogicalResult AllocateDirOp::verify() {
3866  std::optional<uint64_t> align = this->getAlign();
3867 
3868  if (align.has_value()) {
3869  if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
3870  return emitError() << "ALIGN value : " << align.value()
3871  << " must be power of 2";
3872  }
3873 
3874  return success();
3875 }
3876 
3877 #define GET_ATTRDEF_CLASSES
3878 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
3879 
3880 #define GET_OP_CLASSES
3881 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
3882 
3883 #define GET_TYPEDEF_CLASSES
3884 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:755
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1288
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:62
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static Operation * getParentInSameDialect(Operation *thisOp)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseMembersIndex(OpAsmParser &parser, ArrayAttr &membersIdx)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
SuccessorRange getSuccessors()
Definition: Block.h:267
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
MLIRContext * getContext() const
Definition: Builders.h:55
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Definition: Diagnostics.h:228
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
Definition: Diagnostics.h:352
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:205
This class represents an operand of an operation.
Definition: Value.h:257
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:773
This class indicates that the regions associated with this op don't have terminators.
Definition: OpDefinition.h:769
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
BlockArgListType getArguments()
Definition: Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
BlockListType & getBlocks()
Definition: Region.h:45
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:108
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Definition: Dominance.cpp:306
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Runtime
Potential runtimes for AMD GPU kernels.
Definition: Runtimes.h:15
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
Definition: OpDefinition.h:881
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.