MLIR  21.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Attributes.h"
24 
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/BitVector.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/STLForwardCompat.h"
29 #include "llvm/ADT/SmallString.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Frontend/OpenMP/OMPConstants.h"
34 #include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
35 #include <cstddef>
36 #include <iterator>
37 #include <optional>
38 #include <variant>
39 
40 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
41 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
42 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
43 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
44 
45 using namespace mlir;
46 using namespace mlir::omp;
47 
48 static ArrayAttr makeArrayAttr(MLIRContext *context,
50  return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
51 }
52 
53 static DenseBoolArrayAttr
55  return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
56 }
57 
58 namespace {
59 struct MemRefPointerLikeModel
60  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
61  MemRefType> {
62  Type getElementType(Type pointer) const {
63  return llvm::cast<MemRefType>(pointer).getElementType();
64  }
65 };
66 
67 struct LLVMPointerPointerLikeModel
68  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
69  LLVM::LLVMPointerType> {
70  Type getElementType(Type pointer) const { return Type(); }
71 };
72 } // namespace
73 
74 void OpenMPDialect::initialize() {
75  addOperations<
76 #define GET_OP_LIST
77 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
78  >();
79  addAttributes<
80 #define GET_ATTRDEF_LIST
81 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
82  >();
83  addTypes<
84 #define GET_TYPEDEF_LIST
85 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
86  >();
87 
88  declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
89 
90  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
91  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
92  *getContext());
93 
94  // Attach default offload module interface to module op to access
95  // offload functionality through
96  mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
97  *getContext());
98 
99  // Attach default declare target interfaces to operations which can be marked
100  // as declare target (Global Operations and Functions/Subroutines in dialects
101  // that Fortran (or other languages that lower to MLIR) translates too
102  mlir::LLVM::GlobalOp::attachInterface<
104  *getContext());
105  mlir::LLVM::LLVMFuncOp::attachInterface<
107  *getContext());
108  mlir::func::FuncOp::attachInterface<
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // Parser and printer for Allocate Clause
114 //===----------------------------------------------------------------------===//
115 
116 /// Parse an allocate clause with allocators and a list of operands with types.
117 ///
118 /// allocate-operand-list :: = allocate-operand |
119 /// allocator-operand `,` allocate-operand-list
120 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
121 /// ssa-id-and-type ::= ssa-id `:` type
122 static ParseResult parseAllocateAndAllocator(
123  OpAsmParser &parser,
125  SmallVectorImpl<Type> &allocateTypes,
127  SmallVectorImpl<Type> &allocatorTypes) {
128 
129  return parser.parseCommaSeparatedList([&]() {
131  Type type;
132  if (parser.parseOperand(operand) || parser.parseColonType(type))
133  return failure();
134  allocatorVars.push_back(operand);
135  allocatorTypes.push_back(type);
136  if (parser.parseArrow())
137  return failure();
138  if (parser.parseOperand(operand) || parser.parseColonType(type))
139  return failure();
140 
141  allocateVars.push_back(operand);
142  allocateTypes.push_back(type);
143  return success();
144  });
145 }
146 
147 /// Print allocate clause
149  OperandRange allocateVars,
150  TypeRange allocateTypes,
151  OperandRange allocatorVars,
152  TypeRange allocatorTypes) {
153  for (unsigned i = 0; i < allocateVars.size(); ++i) {
154  std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
155  p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
156  p << allocateVars[i] << " : " << allocateTypes[i] << separator;
157  }
158 }
159 
160 //===----------------------------------------------------------------------===//
161 // Parser and printer for a clause attribute (StringEnumAttr)
162 //===----------------------------------------------------------------------===//
163 
164 template <typename ClauseAttr>
165 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
166  using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
167  StringRef enumStr;
168  SMLoc loc = parser.getCurrentLocation();
169  if (parser.parseKeyword(&enumStr))
170  return failure();
171  if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
172  attr = ClauseAttr::get(parser.getContext(), *enumValue);
173  return success();
174  }
175  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
176 }
177 
178 template <typename ClauseAttr>
179 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
180  p << stringifyEnum(attr.getValue());
181 }
182 
183 //===----------------------------------------------------------------------===//
184 // Parser and printer for Linear Clause
185 //===----------------------------------------------------------------------===//
186 
187 /// linear ::= `linear` `(` linear-list `)`
188 /// linear-list := linear-val | linear-val linear-list
189 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
190 static ParseResult parseLinearClause(
191  OpAsmParser &parser,
193  SmallVectorImpl<Type> &linearTypes,
195  return parser.parseCommaSeparatedList([&]() {
197  Type type;
199  if (parser.parseOperand(var) || parser.parseEqual() ||
200  parser.parseOperand(stepVar) || parser.parseColonType(type))
201  return failure();
202 
203  linearVars.push_back(var);
204  linearTypes.push_back(type);
205  linearStepVars.push_back(stepVar);
206  return success();
207  });
208 }
209 
210 /// Print Linear Clause
212  ValueRange linearVars, TypeRange linearTypes,
213  ValueRange linearStepVars) {
214  size_t linearVarsSize = linearVars.size();
215  for (unsigned i = 0; i < linearVarsSize; ++i) {
216  std::string separator = i == linearVarsSize - 1 ? "" : ", ";
217  p << linearVars[i];
218  if (linearStepVars.size() > i)
219  p << " = " << linearStepVars[i];
220  p << " : " << linearVars[i].getType() << separator;
221  }
222 }
223 
224 //===----------------------------------------------------------------------===//
225 // Verifier for Nontemporal Clause
226 //===----------------------------------------------------------------------===//
227 
228 static LogicalResult verifyNontemporalClause(Operation *op,
229  OperandRange nontemporalVars) {
230 
231  // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
232  DenseSet<Value> nontemporalItems;
233  for (const auto &it : nontemporalVars)
234  if (!nontemporalItems.insert(it).second)
235  return op->emitOpError() << "nontemporal variable used more than once";
236 
237  return success();
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // Parser, verifier and printer for Aligned Clause
242 //===----------------------------------------------------------------------===//
243 static LogicalResult verifyAlignedClause(Operation *op,
244  std::optional<ArrayAttr> alignments,
245  OperandRange alignedVars) {
246  // Check if number of alignment values equals to number of aligned variables
247  if (!alignedVars.empty()) {
248  if (!alignments || alignments->size() != alignedVars.size())
249  return op->emitOpError()
250  << "expected as many alignment values as aligned variables";
251  } else {
252  if (alignments)
253  return op->emitOpError() << "unexpected alignment values attribute";
254  return success();
255  }
256 
257  // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
258  DenseSet<Value> alignedItems;
259  for (auto it : alignedVars)
260  if (!alignedItems.insert(it).second)
261  return op->emitOpError() << "aligned variable used more than once";
262 
263  if (!alignments)
264  return success();
265 
266  // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
267  for (unsigned i = 0; i < (*alignments).size(); ++i) {
268  if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
269  if (intAttr.getValue().sle(0))
270  return op->emitOpError() << "alignment should be greater than 0";
271  } else {
272  return op->emitOpError() << "expected integer alignment";
273  }
274  }
275 
276  return success();
277 }
278 
279 /// aligned ::= `aligned` `(` aligned-list `)`
280 /// aligned-list := aligned-val | aligned-val aligned-list
281 /// aligned-val := ssa-id-and-type `->` alignment
282 static ParseResult
285  SmallVectorImpl<Type> &alignedTypes,
286  ArrayAttr &alignmentsAttr) {
287  SmallVector<Attribute> alignmentVec;
288  if (failed(parser.parseCommaSeparatedList([&]() {
289  if (parser.parseOperand(alignedVars.emplace_back()) ||
290  parser.parseColonType(alignedTypes.emplace_back()) ||
291  parser.parseArrow() ||
292  parser.parseAttribute(alignmentVec.emplace_back())) {
293  return failure();
294  }
295  return success();
296  })))
297  return failure();
298  SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
299  alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
300  return success();
301 }
302 
303 /// Print Aligned Clause
305  ValueRange alignedVars, TypeRange alignedTypes,
306  std::optional<ArrayAttr> alignments) {
307  for (unsigned i = 0; i < alignedVars.size(); ++i) {
308  if (i != 0)
309  p << ", ";
310  p << alignedVars[i] << " : " << alignedVars[i].getType();
311  p << " -> " << (*alignments)[i];
312  }
313 }
314 
315 //===----------------------------------------------------------------------===//
316 // Parser, printer and verifier for Schedule Clause
317 //===----------------------------------------------------------------------===//
318 
319 static ParseResult
321  SmallVectorImpl<SmallString<12>> &modifiers) {
322  if (modifiers.size() > 2)
323  return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
324  for (const auto &mod : modifiers) {
325  // Translate the string. If it has no value, then it was not a valid
326  // modifier!
327  auto symbol = symbolizeScheduleModifier(mod);
328  if (!symbol)
329  return parser.emitError(parser.getNameLoc())
330  << " unknown modifier type: " << mod;
331  }
332 
333  // If we have one modifier that is "simd", then stick a "none" modiifer in
334  // index 0.
335  if (modifiers.size() == 1) {
336  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
337  modifiers.push_back(modifiers[0]);
338  modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
339  }
340  } else if (modifiers.size() == 2) {
341  // If there are two modifier:
342  // First modifier should not be simd, second one should be simd
343  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
344  symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
345  return parser.emitError(parser.getNameLoc())
346  << " incorrect modifier order";
347  }
348  return success();
349 }
350 
351 /// schedule ::= `schedule` `(` sched-list `)`
352 /// sched-list ::= sched-val | sched-val sched-list |
353 /// sched-val `,` sched-modifier
354 /// sched-val ::= sched-with-chunk | sched-wo-chunk
355 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
356 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
357 /// sched-wo-chunk ::= `auto` | `runtime`
358 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
359 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
360 static ParseResult
361 parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
362  ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
363  std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
364  Type &chunkType) {
365  StringRef keyword;
366  if (parser.parseKeyword(&keyword))
367  return failure();
368  std::optional<mlir::omp::ClauseScheduleKind> schedule =
369  symbolizeClauseScheduleKind(keyword);
370  if (!schedule)
371  return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
372 
373  scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
374  switch (*schedule) {
375  case ClauseScheduleKind::Static:
376  case ClauseScheduleKind::Dynamic:
377  case ClauseScheduleKind::Guided:
378  if (succeeded(parser.parseOptionalEqual())) {
379  chunkSize = OpAsmParser::UnresolvedOperand{};
380  if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
381  return failure();
382  } else {
383  chunkSize = std::nullopt;
384  }
385  break;
386  case ClauseScheduleKind::Auto:
388  chunkSize = std::nullopt;
389  }
390 
391  // If there is a comma, we have one or more modifiers..
392  SmallVector<SmallString<12>> modifiers;
393  while (succeeded(parser.parseOptionalComma())) {
394  StringRef mod;
395  if (parser.parseKeyword(&mod))
396  return failure();
397  modifiers.push_back(mod);
398  }
399 
400  if (verifyScheduleModifiers(parser, modifiers))
401  return failure();
402 
403  if (!modifiers.empty()) {
404  SMLoc loc = parser.getCurrentLocation();
405  if (std::optional<ScheduleModifier> mod =
406  symbolizeScheduleModifier(modifiers[0])) {
407  scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
408  } else {
409  return parser.emitError(loc, "invalid schedule modifier");
410  }
411  // Only SIMD attribute is allowed here!
412  if (modifiers.size() > 1) {
413  assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
414  scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
415  }
416  }
417 
418  return success();
419 }
420 
421 /// Print schedule clause
423  ClauseScheduleKindAttr scheduleKind,
424  ScheduleModifierAttr scheduleMod,
425  UnitAttr scheduleSimd, Value scheduleChunk,
426  Type scheduleChunkType) {
427  p << stringifyClauseScheduleKind(scheduleKind.getValue());
428  if (scheduleChunk)
429  p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
430  if (scheduleMod)
431  p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
432  if (scheduleSimd)
433  p << ", simd";
434 }
435 
436 //===----------------------------------------------------------------------===//
437 // Parser and printer for Order Clause
438 //===----------------------------------------------------------------------===//
439 
440 // order ::= `order` `(` [order-modifier ':'] concurrent `)`
441 // order-modifier ::= reproducible | unconstrained
442 static ParseResult parseOrderClause(OpAsmParser &parser,
443  ClauseOrderKindAttr &order,
444  OrderModifierAttr &orderMod) {
445  StringRef enumStr;
446  SMLoc loc = parser.getCurrentLocation();
447  if (parser.parseKeyword(&enumStr))
448  return failure();
449  if (std::optional<OrderModifier> enumValue =
450  symbolizeOrderModifier(enumStr)) {
451  orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
452  if (parser.parseOptionalColon())
453  return failure();
454  loc = parser.getCurrentLocation();
455  if (parser.parseKeyword(&enumStr))
456  return failure();
457  }
458  if (std::optional<ClauseOrderKind> enumValue =
459  symbolizeClauseOrderKind(enumStr)) {
460  order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
461  return success();
462  }
463  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
464 }
465 
467  ClauseOrderKindAttr order,
468  OrderModifierAttr orderMod) {
469  if (orderMod)
470  p << stringifyOrderModifier(orderMod.getValue()) << ":";
471  if (order)
472  p << stringifyClauseOrderKind(order.getValue());
473 }
474 
475 template <typename ClauseTypeAttr, typename ClauseType>
476 static ParseResult
477 parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
478  std::optional<OpAsmParser::UnresolvedOperand> &operand,
479  Type &operandType,
480  std::optional<ClauseType> (*symbolizeClause)(StringRef),
481  StringRef clauseName) {
482  StringRef enumStr;
483  if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
484  if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
485  prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
486  if (parser.parseComma())
487  return failure();
488  } else {
489  return parser.emitError(parser.getCurrentLocation())
490  << "invalid " << clauseName << " modifier : '" << enumStr << "'";
491  ;
492  }
493  }
494 
496  if (succeeded(parser.parseOperand(var))) {
497  operand = var;
498  } else {
499  return parser.emitError(parser.getCurrentLocation())
500  << "expected " << clauseName << " operand";
501  }
502 
503  if (operand.has_value()) {
504  if (parser.parseColonType(operandType))
505  return failure();
506  }
507 
508  return success();
509 }
510 
511 template <typename ClauseTypeAttr, typename ClauseType>
512 static void
514  ClauseTypeAttr prescriptiveness, Value operand,
515  mlir::Type operandType,
516  StringRef (*stringifyClauseType)(ClauseType)) {
517 
518  if (prescriptiveness)
519  p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
520 
521  if (operand)
522  p << operand << ": " << operandType;
523 }
524 
525 //===----------------------------------------------------------------------===//
526 // Parser and printer for grainsize Clause
527 //===----------------------------------------------------------------------===//
528 
529 // grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
530 static ParseResult
531 parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
532  std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
533  Type &grainsizeType) {
534  return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
535  parser, grainsizeMod, grainsize, grainsizeType,
536  &symbolizeClauseGrainsizeType, "grainsize");
537 }
538 
540  ClauseGrainsizeTypeAttr grainsizeMod,
541  Value grainsize, mlir::Type grainsizeType) {
542  printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
543  p, op, grainsizeMod, grainsize, grainsizeType,
544  &stringifyClauseGrainsizeType);
545 }
546 
547 //===----------------------------------------------------------------------===//
548 // Parser and printer for num_tasks Clause
549 //===----------------------------------------------------------------------===//
550 
551 // numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
552 static ParseResult
553 parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
554  std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
555  Type &numTasksType) {
556  return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
557  parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
558  "num_tasks");
559 }
560 
562  ClauseNumTasksTypeAttr numTasksMod,
563  Value numTasks, mlir::Type numTasksType) {
564  printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
565  p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
566 }
567 
568 //===----------------------------------------------------------------------===//
569 // Parsers for operations including clauses that define entry block arguments.
570 //===----------------------------------------------------------------------===//
571 
572 namespace {
573 struct MapParseArgs {
575  SmallVectorImpl<Type> &types;
577  SmallVectorImpl<Type> &types)
578  : vars(vars), types(types) {}
579 };
580 struct PrivateParseArgs {
583  ArrayAttr &syms;
584  UnitAttr &needsBarrier;
585  DenseI64ArrayAttr *mapIndices;
587  SmallVectorImpl<Type> &types, ArrayAttr &syms,
588  UnitAttr &needsBarrier,
589  DenseI64ArrayAttr *mapIndices = nullptr)
590  : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
591  mapIndices(mapIndices) {}
592 };
593 
594 struct ReductionParseArgs {
596  SmallVectorImpl<Type> &types;
597  DenseBoolArrayAttr &byref;
598  ArrayAttr &syms;
599  ReductionModifierAttr *modifier;
600  ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
602  ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
603  : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
604 };
605 
606 struct AllRegionParseArgs {
607  std::optional<MapParseArgs> hasDeviceAddrArgs;
608  std::optional<MapParseArgs> hostEvalArgs;
609  std::optional<ReductionParseArgs> inReductionArgs;
610  std::optional<MapParseArgs> mapArgs;
611  std::optional<PrivateParseArgs> privateArgs;
612  std::optional<ReductionParseArgs> reductionArgs;
613  std::optional<ReductionParseArgs> taskReductionArgs;
614  std::optional<MapParseArgs> useDeviceAddrArgs;
615  std::optional<MapParseArgs> useDevicePtrArgs;
616 };
617 } // namespace
618 
619 static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
620  return "private_barrier";
621 }
622 
623 static ParseResult parseClauseWithRegionArgs(
624  OpAsmParser &parser,
626  SmallVectorImpl<Type> &types,
627  SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
628  ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
629  DenseBoolArrayAttr *byref = nullptr,
630  ReductionModifierAttr *modifier = nullptr,
631  UnitAttr *needsBarrier = nullptr) {
632  SmallVector<SymbolRefAttr> symbolVec;
633  SmallVector<int64_t> mapIndicesVec;
634  SmallVector<bool> isByRefVec;
635  unsigned regionArgOffset = regionPrivateArgs.size();
636 
637  if (parser.parseLParen())
638  return failure();
639 
640  if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
641  StringRef enumStr;
642  if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
643  parser.parseComma())
644  return failure();
645  std::optional<ReductionModifier> enumValue =
646  symbolizeReductionModifier(enumStr);
647  if (!enumValue.has_value())
648  return failure();
649  *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
650  if (!*modifier)
651  return failure();
652  }
653 
654  if (parser.parseCommaSeparatedList([&]() {
655  if (byref)
656  isByRefVec.push_back(
657  parser.parseOptionalKeyword("byref").succeeded());
658 
659  if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
660  return failure();
661 
662  if (parser.parseOperand(operands.emplace_back()) ||
663  parser.parseArrow() ||
664  parser.parseArgument(regionPrivateArgs.emplace_back()))
665  return failure();
666 
667  if (mapIndices) {
668  if (parser.parseOptionalLSquare().succeeded()) {
669  if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
670  parser.parseInteger(mapIndicesVec.emplace_back()) ||
671  parser.parseRSquare())
672  return failure();
673  } else {
674  mapIndicesVec.push_back(-1);
675  }
676  }
677 
678  return success();
679  }))
680  return failure();
681 
682  if (parser.parseColon())
683  return failure();
684 
685  if (parser.parseCommaSeparatedList([&]() {
686  if (parser.parseType(types.emplace_back()))
687  return failure();
688 
689  return success();
690  }))
691  return failure();
692 
693  if (operands.size() != types.size())
694  return failure();
695 
696  if (parser.parseRParen())
697  return failure();
698 
699  if (needsBarrier) {
701  .succeeded())
702  *needsBarrier = mlir::UnitAttr::get(parser.getContext());
703  }
704 
705  auto *argsBegin = regionPrivateArgs.begin();
706  MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
707  argsBegin + regionArgOffset + types.size());
708  for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
709  prv.type = type;
710  }
711 
712  if (symbols) {
713  SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
714  *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
715  }
716 
717  if (!mapIndicesVec.empty())
718  *mapIndices =
719  mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
720 
721  if (byref)
722  *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
723 
724  return success();
725 }
726 
727 static ParseResult parseBlockArgClause(
728  OpAsmParser &parser,
730  StringRef keyword, std::optional<MapParseArgs> mapArgs) {
731  if (succeeded(parser.parseOptionalKeyword(keyword))) {
732  if (!mapArgs)
733  return failure();
734 
735  if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
736  entryBlockArgs)))
737  return failure();
738  }
739  return success();
740 }
741 
742 static ParseResult parseBlockArgClause(
743  OpAsmParser &parser,
745  StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
746  if (succeeded(parser.parseOptionalKeyword(keyword))) {
747  if (!privateArgs)
748  return failure();
749 
750  if (failed(parseClauseWithRegionArgs(
751  parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
752  &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
753  /*modifier=*/nullptr, &privateArgs->needsBarrier)))
754  return failure();
755  }
756  return success();
757 }
758 
759 static ParseResult parseBlockArgClause(
760  OpAsmParser &parser,
762  StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
763  if (succeeded(parser.parseOptionalKeyword(keyword))) {
764  if (!reductionArgs)
765  return failure();
766  if (failed(parseClauseWithRegionArgs(
767  parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
768  &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
769  reductionArgs->modifier)))
770  return failure();
771  }
772  return success();
773 }
774 
775 static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
776  AllRegionParseArgs args) {
778 
779  if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",
780  args.hasDeviceAddrArgs)))
781  return parser.emitError(parser.getCurrentLocation())
782  << "invalid `has_device_addr` format";
783 
784  if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
785  args.hostEvalArgs)))
786  return parser.emitError(parser.getCurrentLocation())
787  << "invalid `host_eval` format";
788 
789  if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
790  args.inReductionArgs)))
791  return parser.emitError(parser.getCurrentLocation())
792  << "invalid `in_reduction` format";
793 
794  if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
795  args.mapArgs)))
796  return parser.emitError(parser.getCurrentLocation())
797  << "invalid `map_entries` format";
798 
799  if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
800  args.privateArgs)))
801  return parser.emitError(parser.getCurrentLocation())
802  << "invalid `private` format";
803 
804  if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
805  args.reductionArgs)))
806  return parser.emitError(parser.getCurrentLocation())
807  << "invalid `reduction` format";
808 
809  if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
810  args.taskReductionArgs)))
811  return parser.emitError(parser.getCurrentLocation())
812  << "invalid `task_reduction` format";
813 
814  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
815  args.useDeviceAddrArgs)))
816  return parser.emitError(parser.getCurrentLocation())
817  << "invalid `use_device_addr` format";
818 
819  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
820  args.useDevicePtrArgs)))
821  return parser.emitError(parser.getCurrentLocation())
822  << "invalid `use_device_addr` format";
823 
824  return parser.parseRegion(region, entryBlockArgs);
825 }
826 
827 // These parseXyz functions correspond to the custom<Xyz> definitions
828 // in the .td file(s).
829 static ParseResult parseTargetOpRegion(
830  OpAsmParser &parser, Region &region,
832  SmallVectorImpl<Type> &hasDeviceAddrTypes,
834  SmallVectorImpl<Type> &hostEvalTypes,
836  SmallVectorImpl<Type> &inReductionTypes,
837  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
839  SmallVectorImpl<Type> &mapTypes,
841  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
842  UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
843  AllRegionParseArgs args;
844  args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
845  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
846  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
847  inReductionByref, inReductionSyms);
848  args.mapArgs.emplace(mapVars, mapTypes);
849  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
850  privateNeedsBarrier, &privateMaps);
851  return parseBlockArgRegion(parser, region, args);
852 }
853 
854 static ParseResult parseInReductionPrivateRegion(
855  OpAsmParser &parser, Region &region,
857  SmallVectorImpl<Type> &inReductionTypes,
858  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
860  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
861  UnitAttr &privateNeedsBarrier) {
862  AllRegionParseArgs args;
863  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
864  inReductionByref, inReductionSyms);
865  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
866  privateNeedsBarrier);
867  return parseBlockArgRegion(parser, region, args);
868 }
869 
871  OpAsmParser &parser, Region &region,
873  SmallVectorImpl<Type> &inReductionTypes,
874  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
876  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
877  UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
879  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
880  ArrayAttr &reductionSyms) {
881  AllRegionParseArgs args;
882  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
883  inReductionByref, inReductionSyms);
884  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
885  privateNeedsBarrier);
886  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
887  reductionSyms, &reductionMod);
888  return parseBlockArgRegion(parser, region, args);
889 }
890 
891 static ParseResult parsePrivateRegion(
892  OpAsmParser &parser, Region &region,
894  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
895  UnitAttr &privateNeedsBarrier) {
896  AllRegionParseArgs args;
897  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
898  privateNeedsBarrier);
899  return parseBlockArgRegion(parser, region, args);
900 }
901 
902 static ParseResult parsePrivateReductionRegion(
903  OpAsmParser &parser, Region &region,
905  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
906  UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
908  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
909  ArrayAttr &reductionSyms) {
910  AllRegionParseArgs args;
911  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
912  privateNeedsBarrier);
913  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
914  reductionSyms, &reductionMod);
915  return parseBlockArgRegion(parser, region, args);
916 }
917 
918 static ParseResult parseTaskReductionRegion(
919  OpAsmParser &parser, Region &region,
921  SmallVectorImpl<Type> &taskReductionTypes,
922  DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
923  AllRegionParseArgs args;
924  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
925  taskReductionByref, taskReductionSyms);
926  return parseBlockArgRegion(parser, region, args);
927 }
928 
930  OpAsmParser &parser, Region &region,
932  SmallVectorImpl<Type> &useDeviceAddrTypes,
934  SmallVectorImpl<Type> &useDevicePtrTypes) {
935  AllRegionParseArgs args;
936  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
937  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
938  return parseBlockArgRegion(parser, region, args);
939 }
940 
941 //===----------------------------------------------------------------------===//
942 // Printers for operations including clauses that define entry block arguments.
943 //===----------------------------------------------------------------------===//
944 
945 namespace {
946 struct MapPrintArgs {
947  ValueRange vars;
948  TypeRange types;
949  MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
950 };
951 struct PrivatePrintArgs {
952  ValueRange vars;
953  TypeRange types;
954  ArrayAttr syms;
955  UnitAttr needsBarrier;
956  DenseI64ArrayAttr mapIndices;
957  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
958  UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
959  : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
960  mapIndices(mapIndices) {}
961 };
962 struct ReductionPrintArgs {
963  ValueRange vars;
964  TypeRange types;
965  DenseBoolArrayAttr byref;
966  ArrayAttr syms;
967  ReductionModifierAttr modifier;
968  ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
969  ArrayAttr syms, ReductionModifierAttr mod = nullptr)
970  : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
971 };
972 struct AllRegionPrintArgs {
973  std::optional<MapPrintArgs> hasDeviceAddrArgs;
974  std::optional<MapPrintArgs> hostEvalArgs;
975  std::optional<ReductionPrintArgs> inReductionArgs;
976  std::optional<MapPrintArgs> mapArgs;
977  std::optional<PrivatePrintArgs> privateArgs;
978  std::optional<ReductionPrintArgs> reductionArgs;
979  std::optional<ReductionPrintArgs> taskReductionArgs;
980  std::optional<MapPrintArgs> useDeviceAddrArgs;
981  std::optional<MapPrintArgs> useDevicePtrArgs;
982 };
983 } // namespace
984 
986  OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
987  ValueRange argsSubrange, ValueRange operands, TypeRange types,
988  ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
989  DenseBoolArrayAttr byref = nullptr,
990  ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
991  if (argsSubrange.empty())
992  return;
993 
994  p << clauseName << "(";
995 
996  if (modifier)
997  p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
998 
999  if (!symbols) {
1000  llvm::SmallVector<Attribute> values(operands.size(), nullptr);
1001  symbols = ArrayAttr::get(ctx, values);
1002  }
1003 
1004  if (!mapIndices) {
1005  llvm::SmallVector<int64_t> values(operands.size(), -1);
1006  mapIndices = DenseI64ArrayAttr::get(ctx, values);
1007  }
1008 
1009  if (!byref) {
1010  mlir::SmallVector<bool> values(operands.size(), false);
1011  byref = DenseBoolArrayAttr::get(ctx, values);
1012  }
1013 
1014  llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1015  mapIndices.asArrayRef(),
1016  byref.asArrayRef()),
1017  p, [&p](auto t) {
1018  auto [op, arg, sym, map, isByRef] = t;
1019  if (isByRef)
1020  p << "byref ";
1021  if (sym)
1022  p << sym << " ";
1023 
1024  p << op << " -> " << arg;
1025 
1026  if (map != -1)
1027  p << " [map_idx=" << map << "]";
1028  });
1029  p << " : ";
1030  llvm::interleaveComma(types, p);
1031  p << ") ";
1032 
1033  if (needsBarrier)
1034  p << getPrivateNeedsBarrierSpelling() << " ";
1035 }
1036 
1038  StringRef clauseName, ValueRange argsSubrange,
1039  std::optional<MapPrintArgs> mapArgs) {
1040  if (mapArgs)
1041  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
1042  mapArgs->types);
1043 }
1044 
1046  StringRef clauseName, ValueRange argsSubrange,
1047  std::optional<PrivatePrintArgs> privateArgs) {
1048  if (privateArgs)
1050  p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1051  privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1052  /*modifier=*/nullptr, privateArgs->needsBarrier);
1053 }
1054 
1055 static void
1056 printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1057  ValueRange argsSubrange,
1058  std::optional<ReductionPrintArgs> reductionArgs) {
1059  if (reductionArgs)
1060  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1061  reductionArgs->vars, reductionArgs->types,
1062  reductionArgs->syms, /*mapIndices=*/nullptr,
1063  reductionArgs->byref, reductionArgs->modifier);
1064 }
1065 
1066 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
1067  const AllRegionPrintArgs &args) {
1068  auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1069  MLIRContext *ctx = op->getContext();
1070 
1071  printBlockArgClause(p, ctx, "has_device_addr",
1072  iface.getHasDeviceAddrBlockArgs(),
1073  args.hasDeviceAddrArgs);
1074  printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
1075  args.hostEvalArgs);
1076  printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
1077  args.inReductionArgs);
1078  printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
1079  args.mapArgs);
1080  printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
1081  args.privateArgs);
1082  printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
1083  args.reductionArgs);
1084  printBlockArgClause(p, ctx, "task_reduction",
1085  iface.getTaskReductionBlockArgs(),
1086  args.taskReductionArgs);
1087  printBlockArgClause(p, ctx, "use_device_addr",
1088  iface.getUseDeviceAddrBlockArgs(),
1089  args.useDeviceAddrArgs);
1090  printBlockArgClause(p, ctx, "use_device_ptr",
1091  iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1092 
1093  p.printRegion(region, /*printEntryBlockArgs=*/false);
1094 }
1095 
1096 // These parseXyz functions correspond to the custom<Xyz> definitions
1097 // in the .td file(s).
1099  OpAsmPrinter &p, Operation *op, Region &region,
1100  ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
1101  ValueRange hostEvalVars, TypeRange hostEvalTypes,
1102  ValueRange inReductionVars, TypeRange inReductionTypes,
1103  DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
1104  ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
1105  TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1106  DenseI64ArrayAttr privateMaps) {
1107  AllRegionPrintArgs args;
1108  args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1109  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1110  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1111  inReductionByref, inReductionSyms);
1112  args.mapArgs.emplace(mapVars, mapTypes);
1113  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1114  privateNeedsBarrier, privateMaps);
1115  printBlockArgRegion(p, op, region, args);
1116 }
1117 
1119  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1120  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1121  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1122  ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1123  AllRegionPrintArgs args;
1124  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1125  inReductionByref, inReductionSyms);
1126  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1127  privateNeedsBarrier,
1128  /*mapIndices=*/nullptr);
1129  printBlockArgRegion(p, op, region, args);
1130 }
1131 
1133  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1134  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1135  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1136  ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1137  ReductionModifierAttr reductionMod, ValueRange reductionVars,
1138  TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1139  ArrayAttr reductionSyms) {
1140  AllRegionPrintArgs args;
1141  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1142  inReductionByref, inReductionSyms);
1143  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1144  privateNeedsBarrier,
1145  /*mapIndices=*/nullptr);
1146  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1147  reductionSyms, reductionMod);
1148  printBlockArgRegion(p, op, region, args);
1149 }
1150 
1151 static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
1152  ValueRange privateVars, TypeRange privateTypes,
1153  ArrayAttr privateSyms,
1154  UnitAttr privateNeedsBarrier) {
1155  AllRegionPrintArgs args;
1156  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1157  privateNeedsBarrier,
1158  /*mapIndices=*/nullptr);
1159  printBlockArgRegion(p, op, region, args);
1160 }
1161 
1163  OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1164  TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1165  ReductionModifierAttr reductionMod, ValueRange reductionVars,
1166  TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1167  ArrayAttr reductionSyms) {
1168  AllRegionPrintArgs args;
1169  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1170  privateNeedsBarrier,
1171  /*mapIndices=*/nullptr);
1172  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1173  reductionSyms, reductionMod);
1174  printBlockArgRegion(p, op, region, args);
1175 }
1176 
1178  Region &region,
1179  ValueRange taskReductionVars,
1180  TypeRange taskReductionTypes,
1181  DenseBoolArrayAttr taskReductionByref,
1182  ArrayAttr taskReductionSyms) {
1183  AllRegionPrintArgs args;
1184  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1185  taskReductionByref, taskReductionSyms);
1186  printBlockArgRegion(p, op, region, args);
1187 }
1188 
1190  Region &region,
1191  ValueRange useDeviceAddrVars,
1192  TypeRange useDeviceAddrTypes,
1193  ValueRange useDevicePtrVars,
1194  TypeRange useDevicePtrTypes) {
1195  AllRegionPrintArgs args;
1196  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1197  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1198  printBlockArgRegion(p, op, region, args);
1199 }
1200 
1201 /// Verifies Reduction Clause
1202 static LogicalResult
1203 verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1204  OperandRange reductionVars,
1205  std::optional<ArrayRef<bool>> reductionByref) {
1206  if (!reductionVars.empty()) {
1207  if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1208  return op->emitOpError()
1209  << "expected as many reduction symbol references "
1210  "as reduction variables";
1211  if (reductionByref && reductionByref->size() != reductionVars.size())
1212  return op->emitError() << "expected as many reduction variable by "
1213  "reference attributes as reduction variables";
1214  } else {
1215  if (reductionSyms)
1216  return op->emitOpError() << "unexpected reduction symbol references";
1217  return success();
1218  }
1219 
1220  // TODO: The followings should be done in
1221  // SymbolUserOpInterface::verifySymbolUses.
1222  DenseSet<Value> accumulators;
1223  for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1224  Value accum = std::get<0>(args);
1225 
1226  if (!accumulators.insert(accum).second)
1227  return op->emitOpError() << "accumulator variable used more than once";
1228 
1229  Type varType = accum.getType();
1230  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1231  auto decl =
1232  SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1233  if (!decl)
1234  return op->emitOpError() << "expected symbol reference " << symbolRef
1235  << " to point to a reduction declaration";
1236 
1237  if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1238  return op->emitOpError()
1239  << "expected accumulator (" << varType
1240  << ") to be the same type as reduction declaration ("
1241  << decl.getAccumulatorType() << ")";
1242  }
1243 
1244  return success();
1245 }
1246 
1247 //===----------------------------------------------------------------------===//
1248 // Parser, printer and verifier for Copyprivate
1249 //===----------------------------------------------------------------------===//
1250 
1251 /// copyprivate-entry-list ::= copyprivate-entry
1252 /// | copyprivate-entry-list `,` copyprivate-entry
1253 /// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1254 static ParseResult parseCopyprivate(
1255  OpAsmParser &parser,
1257  SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1259  if (failed(parser.parseCommaSeparatedList([&]() {
1260  if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1261  parser.parseArrow() ||
1262  parser.parseAttribute(symsVec.emplace_back()) ||
1263  parser.parseColonType(copyprivateTypes.emplace_back()))
1264  return failure();
1265  return success();
1266  })))
1267  return failure();
1268  SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1269  copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1270  return success();
1271 }
1272 
1273 /// Print Copyprivate clause
1275  OperandRange copyprivateVars,
1276  TypeRange copyprivateTypes,
1277  std::optional<ArrayAttr> copyprivateSyms) {
1278  if (!copyprivateSyms.has_value())
1279  return;
1280  llvm::interleaveComma(
1281  llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1282  [&](const auto &args) {
1283  p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1284  << std::get<2>(args);
1285  });
1286 }
1287 
1288 /// Verifies CopyPrivate Clause
1289 static LogicalResult
1291  std::optional<ArrayAttr> copyprivateSyms) {
1292  size_t copyprivateSymsSize =
1293  copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1294  if (copyprivateSymsSize != copyprivateVars.size())
1295  return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1296  << copyprivateVars.size()
1297  << ") and functions (= " << copyprivateSymsSize
1298  << "), both must be equal";
1299  if (!copyprivateSyms.has_value())
1300  return success();
1301 
1302  for (auto copyprivateVarAndSym :
1303  llvm::zip(copyprivateVars, *copyprivateSyms)) {
1304  auto symbolRef =
1305  llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1306  std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1307  funcOp;
1308  if (mlir::func::FuncOp mlirFuncOp =
1309  SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1310  symbolRef))
1311  funcOp = mlirFuncOp;
1312  else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1313  SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1314  op, symbolRef))
1315  funcOp = llvmFuncOp;
1316 
1317  auto getNumArguments = [&] {
1318  return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1319  };
1320 
1321  auto getArgumentType = [&](unsigned i) {
1322  return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1323  *funcOp);
1324  };
1325 
1326  if (!funcOp)
1327  return op->emitOpError() << "expected symbol reference " << symbolRef
1328  << " to point to a copy function";
1329 
1330  if (getNumArguments() != 2)
1331  return op->emitOpError()
1332  << "expected copy function " << symbolRef << " to have 2 operands";
1333 
1334  Type argTy = getArgumentType(0);
1335  if (argTy != getArgumentType(1))
1336  return op->emitOpError() << "expected copy function " << symbolRef
1337  << " arguments to have the same type";
1338 
1339  Type varType = std::get<0>(copyprivateVarAndSym).getType();
1340  if (argTy != varType)
1341  return op->emitOpError()
1342  << "expected copy function arguments' type (" << argTy
1343  << ") to be the same as copyprivate variable's type (" << varType
1344  << ")";
1345  }
1346 
1347  return success();
1348 }
1349 
1350 //===----------------------------------------------------------------------===//
1351 // Parser, printer and verifier for DependVarList
1352 //===----------------------------------------------------------------------===//
1353 
1354 /// depend-entry-list ::= depend-entry
1355 /// | depend-entry-list `,` depend-entry
1356 /// depend-entry ::= depend-kind `->` ssa-id `:` type
1357 static ParseResult
1360  SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
1362  if (failed(parser.parseCommaSeparatedList([&]() {
1363  StringRef keyword;
1364  if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1365  parser.parseOperand(dependVars.emplace_back()) ||
1366  parser.parseColonType(dependTypes.emplace_back()))
1367  return failure();
1368  if (std::optional<ClauseTaskDepend> keywordDepend =
1369  (symbolizeClauseTaskDepend(keyword)))
1370  kindsVec.emplace_back(
1371  ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1372  else
1373  return failure();
1374  return success();
1375  })))
1376  return failure();
1377  SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1378  dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1379  return success();
1380 }
1381 
1382 /// Print Depend clause
1384  OperandRange dependVars, TypeRange dependTypes,
1385  std::optional<ArrayAttr> dependKinds) {
1386 
1387  for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1388  if (i != 0)
1389  p << ", ";
1390  p << stringifyClauseTaskDepend(
1391  llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1392  .getValue())
1393  << " -> " << dependVars[i] << " : " << dependTypes[i];
1394  }
1395 }
1396 
1397 /// Verifies Depend clause
1398 static LogicalResult verifyDependVarList(Operation *op,
1399  std::optional<ArrayAttr> dependKinds,
1400  OperandRange dependVars) {
1401  if (!dependVars.empty()) {
1402  if (!dependKinds || dependKinds->size() != dependVars.size())
1403  return op->emitOpError() << "expected as many depend values"
1404  " as depend variables";
1405  } else {
1406  if (dependKinds && !dependKinds->empty())
1407  return op->emitOpError() << "unexpected depend values";
1408  return success();
1409  }
1410 
1411  return success();
1412 }
1413 
1414 //===----------------------------------------------------------------------===//
1415 // Parser, printer and verifier for Synchronization Hint (2.17.12)
1416 //===----------------------------------------------------------------------===//
1417 
1418 /// Parses a Synchronization Hint clause. The value of hint is an integer
1419 /// which is a combination of different hints from `omp_sync_hint_t`.
1420 ///
1421 /// hint-clause = `hint` `(` hint-value `)`
1422 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1423  IntegerAttr &hintAttr) {
1424  StringRef hintKeyword;
1425  int64_t hint = 0;
1426  if (succeeded(parser.parseOptionalKeyword("none"))) {
1427  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1428  return success();
1429  }
1430  auto parseKeyword = [&]() -> ParseResult {
1431  if (failed(parser.parseKeyword(&hintKeyword)))
1432  return failure();
1433  if (hintKeyword == "uncontended")
1434  hint |= 1;
1435  else if (hintKeyword == "contended")
1436  hint |= 2;
1437  else if (hintKeyword == "nonspeculative")
1438  hint |= 4;
1439  else if (hintKeyword == "speculative")
1440  hint |= 8;
1441  else
1442  return parser.emitError(parser.getCurrentLocation())
1443  << hintKeyword << " is not a valid hint";
1444  return success();
1445  };
1446  if (parser.parseCommaSeparatedList(parseKeyword))
1447  return failure();
1448  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1449  return success();
1450 }
1451 
1452 /// Prints a Synchronization Hint clause
1454  IntegerAttr hintAttr) {
1455  int64_t hint = hintAttr.getInt();
1456 
1457  if (hint == 0) {
1458  p << "none";
1459  return;
1460  }
1461 
1462  // Helper function to get n-th bit from the right end of `value`
1463  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1464 
1465  bool uncontended = bitn(hint, 0);
1466  bool contended = bitn(hint, 1);
1467  bool nonspeculative = bitn(hint, 2);
1468  bool speculative = bitn(hint, 3);
1469 
1470  SmallVector<StringRef> hints;
1471  if (uncontended)
1472  hints.push_back("uncontended");
1473  if (contended)
1474  hints.push_back("contended");
1475  if (nonspeculative)
1476  hints.push_back("nonspeculative");
1477  if (speculative)
1478  hints.push_back("speculative");
1479 
1480  llvm::interleaveComma(hints, p);
1481 }
1482 
1483 /// Verifies a synchronization hint clause
1484 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1485 
1486  // Helper function to get n-th bit from the right end of `value`
1487  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1488 
1489  bool uncontended = bitn(hint, 0);
1490  bool contended = bitn(hint, 1);
1491  bool nonspeculative = bitn(hint, 2);
1492  bool speculative = bitn(hint, 3);
1493 
1494  if (uncontended && contended)
1495  return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1496  "omp_sync_hint_contended cannot be combined";
1497  if (nonspeculative && speculative)
1498  return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1499  "omp_sync_hint_speculative cannot be combined.";
1500  return success();
1501 }
1502 
1503 //===----------------------------------------------------------------------===//
1504 // Parser, printer and verifier for Target
1505 //===----------------------------------------------------------------------===//
1506 
1507 // Helper function to get bitwise AND of `value` and 'flag'
1508 uint64_t mapTypeToBitFlag(uint64_t value,
1509  llvm::omp::OpenMPOffloadMappingFlags flag) {
1510  return value & llvm::to_underlying(flag);
1511 }
1512 
1513 /// Parses a map_entries map type from a string format back into its numeric
1514 /// value.
1515 ///
1516 /// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
1517 /// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
1518 static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
1519  llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1520  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1521 
1522  // This simply verifies the correct keyword is read in, the
1523  // keyword itself is stored inside of the operation
1524  auto parseTypeAndMod = [&]() -> ParseResult {
1525  StringRef mapTypeMod;
1526  if (parser.parseKeyword(&mapTypeMod))
1527  return failure();
1528 
1529  if (mapTypeMod == "always")
1530  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1531 
1532  if (mapTypeMod == "implicit")
1533  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1534 
1535  if (mapTypeMod == "ompx_hold")
1536  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
1537 
1538  if (mapTypeMod == "close")
1539  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1540 
1541  if (mapTypeMod == "present")
1542  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1543 
1544  if (mapTypeMod == "to")
1545  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1546 
1547  if (mapTypeMod == "from")
1548  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1549 
1550  if (mapTypeMod == "tofrom")
1551  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1552  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1553 
1554  if (mapTypeMod == "delete")
1555  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1556 
1557  if (mapTypeMod == "return_param")
1558  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
1559 
1560  return success();
1561  };
1562 
1563  if (parser.parseCommaSeparatedList(parseTypeAndMod))
1564  return failure();
1565 
1566  mapType = parser.getBuilder().getIntegerAttr(
1567  parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
1568  llvm::to_underlying(mapTypeBits));
1569 
1570  return success();
1571 }
1572 
1573 /// Prints a map_entries map type from its numeric value out into its string
1574 /// format.
1576  IntegerAttr mapType) {
1577  uint64_t mapTypeBits = mapType.getUInt();
1578 
1579  bool emitAllocRelease = true;
1581 
1582  // handling of always, close, present placed at the beginning of the string
1583  // to aid readability
1584  if (mapTypeToBitFlag(mapTypeBits,
1585  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1586  mapTypeStrs.push_back("always");
1587  if (mapTypeToBitFlag(mapTypeBits,
1588  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1589  mapTypeStrs.push_back("implicit");
1590  if (mapTypeToBitFlag(mapTypeBits,
1591  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD))
1592  mapTypeStrs.push_back("ompx_hold");
1593  if (mapTypeToBitFlag(mapTypeBits,
1594  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1595  mapTypeStrs.push_back("close");
1596  if (mapTypeToBitFlag(mapTypeBits,
1597  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1598  mapTypeStrs.push_back("present");
1599 
1600  // special handling of to/from/tofrom/delete and release/alloc, release +
1601  // alloc are the abscense of one of the other flags, whereas tofrom requires
1602  // both the to and from flag to be set.
1603  bool to = mapTypeToBitFlag(mapTypeBits,
1604  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1605  bool from = mapTypeToBitFlag(
1606  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1607  if (to && from) {
1608  emitAllocRelease = false;
1609  mapTypeStrs.push_back("tofrom");
1610  } else if (from) {
1611  emitAllocRelease = false;
1612  mapTypeStrs.push_back("from");
1613  } else if (to) {
1614  emitAllocRelease = false;
1615  mapTypeStrs.push_back("to");
1616  }
1617  if (mapTypeToBitFlag(mapTypeBits,
1618  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1619  emitAllocRelease = false;
1620  mapTypeStrs.push_back("delete");
1621  }
1622  if (mapTypeToBitFlag(
1623  mapTypeBits,
1624  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
1625  emitAllocRelease = false;
1626  mapTypeStrs.push_back("return_param");
1627  }
1628  if (emitAllocRelease)
1629  mapTypeStrs.push_back("exit_release_or_enter_alloc");
1630 
1631  for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1632  p << mapTypeStrs[i];
1633  if (i + 1 < mapTypeStrs.size()) {
1634  p << ", ";
1635  }
1636  }
1637 }
1638 
1639 static ParseResult parseMembersIndex(OpAsmParser &parser,
1640  ArrayAttr &membersIdx) {
1641  SmallVector<Attribute> values, memberIdxs;
1642 
1643  auto parseIndices = [&]() -> ParseResult {
1644  int64_t value;
1645  if (parser.parseInteger(value))
1646  return failure();
1647  values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1648  APInt(64, value, /*isSigned=*/false)));
1649  return success();
1650  };
1651 
1652  do {
1653  if (failed(parser.parseLSquare()))
1654  return failure();
1655 
1656  if (parser.parseCommaSeparatedList(parseIndices))
1657  return failure();
1658 
1659  if (failed(parser.parseRSquare()))
1660  return failure();
1661 
1662  memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1663  values.clear();
1664  } while (succeeded(parser.parseOptionalComma()));
1665 
1666  if (!memberIdxs.empty())
1667  membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
1668 
1669  return success();
1670 }
1671 
1672 static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1673  ArrayAttr membersIdx) {
1674  if (!membersIdx)
1675  return;
1676 
1677  llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
1678  p << "[";
1679  auto memberIdx = cast<ArrayAttr>(v);
1680  llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
1681  p << cast<IntegerAttr>(v2).getInt();
1682  });
1683  p << "]";
1684  });
1685 }
1686 
1688  VariableCaptureKindAttr mapCaptureType) {
1689  std::string typeCapStr;
1690  llvm::raw_string_ostream typeCap(typeCapStr);
1691  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1692  typeCap << "ByRef";
1693  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1694  typeCap << "ByCopy";
1695  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1696  typeCap << "VLAType";
1697  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1698  typeCap << "This";
1699  p << typeCapStr;
1700 }
1701 
1702 static ParseResult parseCaptureType(OpAsmParser &parser,
1703  VariableCaptureKindAttr &mapCaptureType) {
1704  StringRef mapCaptureKey;
1705  if (parser.parseKeyword(&mapCaptureKey))
1706  return failure();
1707 
1708  if (mapCaptureKey == "This")
1709  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1710  parser.getContext(), mlir::omp::VariableCaptureKind::This);
1711  if (mapCaptureKey == "ByRef")
1712  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1713  parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1714  if (mapCaptureKey == "ByCopy")
1715  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1716  parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1717  if (mapCaptureKey == "VLAType")
1718  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1719  parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1720 
1721  return success();
1722 }
1723 
1724 static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
1727 
1728  for (auto mapOp : mapVars) {
1729  if (!mapOp.getDefiningOp())
1730  return emitError(op->getLoc(), "missing map operation");
1731 
1732  if (auto mapInfoOp =
1733  mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
1734  uint64_t mapTypeBits = mapInfoOp.getMapType();
1735 
1736  bool to = mapTypeToBitFlag(
1737  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1738  bool from = mapTypeToBitFlag(
1739  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1740  bool del = mapTypeToBitFlag(
1741  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1742 
1743  bool always = mapTypeToBitFlag(
1744  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1745  bool close = mapTypeToBitFlag(
1746  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1747  bool implicit = mapTypeToBitFlag(
1748  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1749 
1750  if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1751  return emitError(op->getLoc(),
1752  "to, from, tofrom and alloc map types are permitted");
1753 
1754  if (isa<TargetEnterDataOp>(op) && (from || del))
1755  return emitError(op->getLoc(), "to and alloc map types are permitted");
1756 
1757  if (isa<TargetExitDataOp>(op) && to)
1758  return emitError(op->getLoc(),
1759  "from, release and delete map types are permitted");
1760 
1761  if (isa<TargetUpdateOp>(op)) {
1762  if (del) {
1763  return emitError(op->getLoc(),
1764  "at least one of to or from map types must be "
1765  "specified, other map types are not permitted");
1766  }
1767 
1768  if (!to && !from) {
1769  return emitError(op->getLoc(),
1770  "at least one of to or from map types must be "
1771  "specified, other map types are not permitted");
1772  }
1773 
1774  auto updateVar = mapInfoOp.getVarPtr();
1775 
1776  if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1777  (from && updateToVars.contains(updateVar))) {
1778  return emitError(
1779  op->getLoc(),
1780  "either to or from map types can be specified, not both");
1781  }
1782 
1783  if (always || close || implicit) {
1784  return emitError(
1785  op->getLoc(),
1786  "present, mapper and iterator map type modifiers are permitted");
1787  }
1788 
1789  to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1790  }
1791  } else if (!isa<DeclareMapperInfoOp>(op)) {
1792  return emitError(op->getLoc(),
1793  "map argument is not a map entry operation");
1794  }
1795  }
1796 
1797  return success();
1798 }
1799 
1800 static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
1801  std::optional<DenseI64ArrayAttr> privateMapIndices =
1802  targetOp.getPrivateMapsAttr();
1803 
1804  // None of the private operands are mapped.
1805  if (!privateMapIndices.has_value() || !privateMapIndices.value())
1806  return success();
1807 
1808  OperandRange privateVars = targetOp.getPrivateVars();
1809 
1810  if (privateMapIndices.value().size() !=
1811  static_cast<int64_t>(privateVars.size()))
1812  return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
1813  "`private_maps` attribute mismatch");
1814 
1815  return success();
1816 }
1817 
1818 //===----------------------------------------------------------------------===//
1819 // MapInfoOp
1820 //===----------------------------------------------------------------------===//
1821 
1822 static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
1823  StringRef clauseName,
1824  OperandRange vars) {
1825  for (Value var : vars)
1826  if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
1827  return op->emitOpError()
1828  << "'" << clauseName
1829  << "' arguments must be defined by 'omp.map.info' ops";
1830  return success();
1831 }
1832 
1833 LogicalResult MapInfoOp::verify() {
1834  if (getMapperId() &&
1835  !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
1836  *this, getMapperIdAttr())) {
1837  return emitError("invalid mapper id");
1838  }
1839 
1840  if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
1841  return failure();
1842 
1843  return success();
1844 }
1845 
1846 //===----------------------------------------------------------------------===//
1847 // TargetDataOp
1848 //===----------------------------------------------------------------------===//
1849 
1850 void TargetDataOp::build(OpBuilder &builder, OperationState &state,
1851  const TargetDataOperands &clauses) {
1852  TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
1853  clauses.mapVars, clauses.useDeviceAddrVars,
1854  clauses.useDevicePtrVars);
1855 }
1856 
1857 LogicalResult TargetDataOp::verify() {
1858  if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
1859  getUseDeviceAddrVars().empty()) {
1860  return ::emitError(this->getLoc(),
1861  "At least one of map, use_device_ptr_vars, or "
1862  "use_device_addr_vars operand must be present");
1863  }
1864 
1865  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
1866  getUseDevicePtrVars())))
1867  return failure();
1868 
1869  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
1870  getUseDeviceAddrVars())))
1871  return failure();
1872 
1873  return verifyMapClause(*this, getMapVars());
1874 }
1875 
1876 //===----------------------------------------------------------------------===//
1877 // TargetEnterDataOp
1878 //===----------------------------------------------------------------------===//
1879 
1880 void TargetEnterDataOp::build(
1881  OpBuilder &builder, OperationState &state,
1882  const TargetEnterExitUpdateDataOperands &clauses) {
1883  MLIRContext *ctx = builder.getContext();
1884  TargetEnterDataOp::build(builder, state,
1885  makeArrayAttr(ctx, clauses.dependKinds),
1886  clauses.dependVars, clauses.device, clauses.ifExpr,
1887  clauses.mapVars, clauses.nowait);
1888 }
1889 
1890 LogicalResult TargetEnterDataOp::verify() {
1891  LogicalResult verifyDependVars =
1892  verifyDependVarList(*this, getDependKinds(), getDependVars());
1893  return failed(verifyDependVars) ? verifyDependVars
1894  : verifyMapClause(*this, getMapVars());
1895 }
1896 
1897 //===----------------------------------------------------------------------===//
1898 // TargetExitDataOp
1899 //===----------------------------------------------------------------------===//
1900 
1901 void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
1902  const TargetEnterExitUpdateDataOperands &clauses) {
1903  MLIRContext *ctx = builder.getContext();
1904  TargetExitDataOp::build(builder, state,
1905  makeArrayAttr(ctx, clauses.dependKinds),
1906  clauses.dependVars, clauses.device, clauses.ifExpr,
1907  clauses.mapVars, clauses.nowait);
1908 }
1909 
1910 LogicalResult TargetExitDataOp::verify() {
1911  LogicalResult verifyDependVars =
1912  verifyDependVarList(*this, getDependKinds(), getDependVars());
1913  return failed(verifyDependVars) ? verifyDependVars
1914  : verifyMapClause(*this, getMapVars());
1915 }
1916 
1917 //===----------------------------------------------------------------------===//
1918 // TargetUpdateOp
1919 //===----------------------------------------------------------------------===//
1920 
1921 void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
1922  const TargetEnterExitUpdateDataOperands &clauses) {
1923  MLIRContext *ctx = builder.getContext();
1924  TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
1925  clauses.dependVars, clauses.device, clauses.ifExpr,
1926  clauses.mapVars, clauses.nowait);
1927 }
1928 
1929 LogicalResult TargetUpdateOp::verify() {
1930  LogicalResult verifyDependVars =
1931  verifyDependVarList(*this, getDependKinds(), getDependVars());
1932  return failed(verifyDependVars) ? verifyDependVars
1933  : verifyMapClause(*this, getMapVars());
1934 }
1935 
1936 //===----------------------------------------------------------------------===//
1937 // TargetOp
1938 //===----------------------------------------------------------------------===//
1939 
1940 void TargetOp::build(OpBuilder &builder, OperationState &state,
1941  const TargetOperands &clauses) {
1942  MLIRContext *ctx = builder.getContext();
1943  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
1944  // inReductionByref, inReductionSyms.
1945  TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
1946  clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
1947  clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
1948  clauses.hostEvalVars, clauses.ifExpr,
1949  /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
1950  /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
1951  clauses.mapVars, clauses.nowait, clauses.privateVars,
1952  makeArrayAttr(ctx, clauses.privateSyms),
1953  clauses.privateNeedsBarrier, clauses.threadLimit,
1954  /*private_maps=*/nullptr);
1955 }
1956 
1957 LogicalResult TargetOp::verify() {
1958  if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
1959  return failure();
1960 
1961  if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
1962  getHasDeviceAddrVars())))
1963  return failure();
1964 
1965  if (failed(verifyMapClause(*this, getMapVars())))
1966  return failure();
1967 
1968  return verifyPrivateVarsMapping(*this);
1969 }
1970 
1971 LogicalResult TargetOp::verifyRegions() {
1972  auto teamsOps = getOps<TeamsOp>();
1973  if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
1974  return emitError("target containing multiple 'omp.teams' nested ops");
1975 
1976  // Check that host_eval values are only used in legal ways.
1977  Operation *capturedOp = getInnermostCapturedOmpOp();
1978  TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
1979  for (Value hostEvalArg :
1980  cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
1981  for (Operation *user : hostEvalArg.getUsers()) {
1982  if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
1983  if (llvm::is_contained({teamsOp.getNumTeamsLower(),
1984  teamsOp.getNumTeamsUpper(),
1985  teamsOp.getThreadLimit()},
1986  hostEvalArg))
1987  continue;
1988 
1989  return emitOpError() << "host_eval argument only legal as 'num_teams' "
1990  "and 'thread_limit' in 'omp.teams'";
1991  }
1992  if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1993  if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
1994  parallelOp->isAncestor(capturedOp) &&
1995  hostEvalArg == parallelOp.getNumThreads())
1996  continue;
1997 
1998  return emitOpError()
1999  << "host_eval argument only legal as 'num_threads' in "
2000  "'omp.parallel' when representing target SPMD";
2001  }
2002  if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2003  if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2004  loopNestOp.getOperation() == capturedOp &&
2005  (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2006  llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2007  llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2008  continue;
2009 
2010  return emitOpError() << "host_eval argument only legal as loop bounds "
2011  "and steps in 'omp.loop_nest' when trip count "
2012  "must be evaluated in the host";
2013  }
2014 
2015  return emitOpError() << "host_eval argument illegal use in '"
2016  << user->getName() << "' operation";
2017  }
2018  }
2019  return success();
2020 }
2021 
2022 static Operation *
2023 findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
2024  llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
2025  assert(rootOp && "expected valid operation");
2026 
2027  Dialect *ompDialect = rootOp->getDialect();
2028  Operation *capturedOp = nullptr;
2029  DominanceInfo domInfo;
2030 
2031  // Process in pre-order to check operations from outermost to innermost,
2032  // ensuring we only enter the region of an operation if it meets the criteria
2033  // for being captured. We stop the exploration of nested operations as soon as
2034  // we process a region holding no operations to be captured.
2035  rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
2036  if (op == rootOp)
2037  return WalkResult::advance();
2038 
2039  // Ignore operations of other dialects or omp operations with no regions,
2040  // because these will only be checked if they are siblings of an omp
2041  // operation that can potentially be captured.
2042  bool isOmpDialect = op->getDialect() == ompDialect;
2043  bool hasRegions = op->getNumRegions() > 0;
2044  if (!isOmpDialect || !hasRegions)
2045  return WalkResult::skip();
2046 
2047  // This operation cannot be captured if it can be executed more than once
2048  // (i.e. its block's successors can reach it) or if it's not guaranteed to
2049  // be executed before all exits of the region (i.e. it doesn't dominate all
2050  // blocks with no successors reachable from the entry block).
2051  if (checkSingleMandatoryExec) {
2052  Region *parentRegion = op->getParentRegion();
2053  Block *parentBlock = op->getBlock();
2054 
2055  for (Block *successor : parentBlock->getSuccessors())
2056  if (successor->isReachable(parentBlock))
2057  return WalkResult::interrupt();
2058 
2059  for (Block &block : *parentRegion)
2060  if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2061  !domInfo.dominates(parentBlock, &block))
2062  return WalkResult::interrupt();
2063  }
2064 
2065  // Don't capture this op if it has a not-allowed sibling, and stop recursing
2066  // into nested operations.
2067  for (Operation &sibling : op->getParentRegion()->getOps())
2068  if (&sibling != op && !siblingAllowedFn(&sibling))
2069  return WalkResult::interrupt();
2070 
2071  // Don't continue capturing nested operations if we reach an omp.loop_nest.
2072  // Otherwise, process the contents of this operation.
2073  capturedOp = op;
2074  return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2075  : WalkResult::advance();
2076  });
2077 
2078  return capturedOp;
2079 }
2080 
2081 Operation *TargetOp::getInnermostCapturedOmpOp() {
2082  auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
2083 
2084  // Only allow OpenMP terminators and non-OpenMP ops that have known memory
2085  // effects, but don't include a memory write effect.
2086  return findCapturedOmpOp(
2087  *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2088  if (!sibling)
2089  return false;
2090 
2091  if (ompDialect == sibling->getDialect())
2092  return sibling->hasTrait<OpTrait::IsTerminator>();
2093 
2094  if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2096  effects;
2097  memOp.getEffects(effects);
2098  return !llvm::any_of(
2099  effects, [&](MemoryEffects::EffectInstance &effect) {
2100  return isa<MemoryEffects::Write>(effect.getEffect()) &&
2101  isa<SideEffects::AutomaticAllocationScopeResource>(
2102  effect.getResource());
2103  });
2104  }
2105  return true;
2106  });
2107 }
2108 
2109 TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2110  // A non-null captured op is only valid if it resides inside of a TargetOp
2111  // and is the result of calling getInnermostCapturedOmpOp() on it.
2112  TargetOp targetOp =
2113  capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
2114  assert((!capturedOp ||
2115  (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2116  "unexpected captured op");
2117 
2118  // If it's not capturing a loop, it's a default target region.
2119  if (!isa_and_present<LoopNestOp>(capturedOp))
2120  return TargetRegionFlags::generic;
2121 
2122  // Get the innermost non-simd loop wrapper.
2123  SmallVector<LoopWrapperInterface> loopWrappers;
2124  cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2125  assert(!loopWrappers.empty());
2126 
2127  LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2128  if (isa<SimdOp>(innermostWrapper))
2129  innermostWrapper = std::next(innermostWrapper);
2130 
2131  auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2132  if (numWrappers != 1 && numWrappers != 2)
2133  return TargetRegionFlags::generic;
2134 
2135  // Detect target-teams-distribute-parallel-wsloop[-simd].
2136  if (numWrappers == 2) {
2137  if (!isa<WsloopOp>(innermostWrapper))
2138  return TargetRegionFlags::generic;
2139 
2140  innermostWrapper = std::next(innermostWrapper);
2141  if (!isa<DistributeOp>(innermostWrapper))
2142  return TargetRegionFlags::generic;
2143 
2144  Operation *parallelOp = (*innermostWrapper)->getParentOp();
2145  if (!isa_and_present<ParallelOp>(parallelOp))
2146  return TargetRegionFlags::generic;
2147 
2148  Operation *teamsOp = parallelOp->getParentOp();
2149  if (!isa_and_present<TeamsOp>(teamsOp))
2150  return TargetRegionFlags::generic;
2151 
2152  if (teamsOp->getParentOp() == targetOp.getOperation())
2153  return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2154  }
2155  // Detect target-teams-distribute[-simd] and target-teams-loop.
2156  else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2157  Operation *teamsOp = (*innermostWrapper)->getParentOp();
2158  if (!isa_and_present<TeamsOp>(teamsOp))
2159  return TargetRegionFlags::generic;
2160 
2161  if (teamsOp->getParentOp() != targetOp.getOperation())
2162  return TargetRegionFlags::generic;
2163 
2164  if (isa<LoopOp>(innermostWrapper))
2165  return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2166 
2167  // Find single immediately nested captured omp.parallel and add spmd flag
2168  // (generic-spmd case).
2169  //
2170  // TODO: This shouldn't have to be done here, as it is too easy to break.
2171  // The openmp-opt pass should be updated to be able to promote kernels like
2172  // this from "Generic" to "Generic-SPMD". However, the use of the
2173  // `kmpc_distribute_static_loop` family of functions produced by the
2174  // OMPIRBuilder for these kernels prevents that from working.
2175  Dialect *ompDialect = targetOp->getDialect();
2176  Operation *nestedCapture = findCapturedOmpOp(
2177  capturedOp, /*checkSingleMandatoryExec=*/false,
2178  [&](Operation *sibling) {
2179  return sibling && (ompDialect != sibling->getDialect() ||
2180  sibling->hasTrait<OpTrait::IsTerminator>());
2181  });
2182 
2183  TargetRegionFlags result =
2184  TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2185 
2186  if (!nestedCapture)
2187  return result;
2188 
2189  while (nestedCapture->getParentOp() != capturedOp)
2190  nestedCapture = nestedCapture->getParentOp();
2191 
2192  return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2193  : result;
2194  }
2195  // Detect target-parallel-wsloop[-simd].
2196  else if (isa<WsloopOp>(innermostWrapper)) {
2197  Operation *parallelOp = (*innermostWrapper)->getParentOp();
2198  if (!isa_and_present<ParallelOp>(parallelOp))
2199  return TargetRegionFlags::generic;
2200 
2201  if (parallelOp->getParentOp() == targetOp.getOperation())
2202  return TargetRegionFlags::spmd;
2203  }
2204 
2205  return TargetRegionFlags::generic;
2206 }
2207 
2208 //===----------------------------------------------------------------------===//
2209 // ParallelOp
2210 //===----------------------------------------------------------------------===//
2211 
2212 void ParallelOp::build(OpBuilder &builder, OperationState &state,
2213  ArrayRef<NamedAttribute> attributes) {
2214  ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
2215  /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2216  /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
2217  /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2218  /*proc_bind_kind=*/nullptr,
2219  /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2220  /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2221  state.addAttributes(attributes);
2222 }
2223 
2224 void ParallelOp::build(OpBuilder &builder, OperationState &state,
2225  const ParallelOperands &clauses) {
2226  MLIRContext *ctx = builder.getContext();
2227  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2228  clauses.ifExpr, clauses.numThreads, clauses.privateVars,
2229  makeArrayAttr(ctx, clauses.privateSyms),
2230  clauses.privateNeedsBarrier, clauses.procBindKind,
2231  clauses.reductionMod, clauses.reductionVars,
2232  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2233  makeArrayAttr(ctx, clauses.reductionSyms));
2234 }
2235 
2236 template <typename OpType>
2237 static LogicalResult verifyPrivateVarList(OpType &op) {
2238  auto privateVars = op.getPrivateVars();
2239  auto privateSyms = op.getPrivateSymsAttr();
2240 
2241  if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
2242  return success();
2243 
2244  auto numPrivateVars = privateVars.size();
2245  auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
2246 
2247  if (numPrivateVars != numPrivateSyms)
2248  return op.emitError() << "inconsistent number of private variables and "
2249  "privatizer op symbols, private vars: "
2250  << numPrivateVars
2251  << " vs. privatizer op symbols: " << numPrivateSyms;
2252 
2253  for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2254  Type varType = std::get<0>(privateVarInfo).getType();
2255  SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2256  PrivateClauseOp privatizerOp =
2257  SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
2258 
2259  if (privatizerOp == nullptr)
2260  return op.emitError() << "failed to lookup privatizer op with symbol: '"
2261  << privateSym << "'";
2262 
2263  Type privatizerType = privatizerOp.getArgType();
2264 
2265  if (privatizerType && (varType != privatizerType))
2266  return op.emitError()
2267  << "type mismatch between a "
2268  << (privatizerOp.getDataSharingType() ==
2269  DataSharingClauseType::Private
2270  ? "private"
2271  : "firstprivate")
2272  << " variable and its privatizer op, var type: " << varType
2273  << " vs. privatizer op type: " << privatizerType;
2274  }
2275 
2276  return success();
2277 }
2278 
2279 LogicalResult ParallelOp::verify() {
2280  if (getAllocateVars().size() != getAllocatorVars().size())
2281  return emitError(
2282  "expected equal sizes for allocate and allocator variables");
2283 
2284  if (failed(verifyPrivateVarList(*this)))
2285  return failure();
2286 
2287  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2288  getReductionByref());
2289 }
2290 
2291 LogicalResult ParallelOp::verifyRegions() {
2292  auto distChildOps = getOps<DistributeOp>();
2293  int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2294  if (numDistChildOps > 1)
2295  return emitError()
2296  << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2297 
2298  if (numDistChildOps == 1) {
2299  if (!isComposite())
2300  return emitError()
2301  << "'omp.composite' attribute missing from composite operation";
2302 
2303  auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2304  Operation &distributeOp = **distChildOps.begin();
2305  for (Operation &childOp : getOps()) {
2306  if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2307  continue;
2308 
2309  if (!childOp.hasTrait<OpTrait::IsTerminator>())
2310  return emitError() << "unexpected OpenMP operation inside of composite "
2311  "'omp.parallel': "
2312  << childOp.getName();
2313  }
2314  } else if (isComposite()) {
2315  return emitError()
2316  << "'omp.composite' attribute present in non-composite operation";
2317  }
2318  return success();
2319 }
2320 
2321 //===----------------------------------------------------------------------===//
2322 // TeamsOp
2323 //===----------------------------------------------------------------------===//
2324 
2326  while ((op = op->getParentOp()))
2327  if (isa<OpenMPDialect>(op->getDialect()))
2328  return false;
2329  return true;
2330 }
2331 
2332 void TeamsOp::build(OpBuilder &builder, OperationState &state,
2333  const TeamsOperands &clauses) {
2334  MLIRContext *ctx = builder.getContext();
2335  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2336  TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2337  clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2338  /*private_vars=*/{}, /*private_syms=*/nullptr,
2339  /*private_needs_barrier=*/nullptr, clauses.reductionMod,
2340  clauses.reductionVars,
2341  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2342  makeArrayAttr(ctx, clauses.reductionSyms),
2343  clauses.threadLimit);
2344 }
2345 
2346 LogicalResult TeamsOp::verify() {
2347  // Check parent region
2348  // TODO If nested inside of a target region, also check that it does not
2349  // contain any statements, declarations or directives other than this
2350  // omp.teams construct. The issue is how to support the initialization of
2351  // this operation's own arguments (allow SSA values across omp.target?).
2352  Operation *op = getOperation();
2353  if (!isa<TargetOp>(op->getParentOp()) &&
2355  return emitError("expected to be nested inside of omp.target or not nested "
2356  "in any OpenMP dialect operations");
2357 
2358  // Check for num_teams clause restrictions
2359  if (auto numTeamsLowerBound = getNumTeamsLower()) {
2360  auto numTeamsUpperBound = getNumTeamsUpper();
2361  if (!numTeamsUpperBound)
2362  return emitError("expected num_teams upper bound to be defined if the "
2363  "lower bound is defined");
2364  if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2365  return emitError(
2366  "expected num_teams upper bound and lower bound to be the same type");
2367  }
2368 
2369  // Check for allocate clause restrictions
2370  if (getAllocateVars().size() != getAllocatorVars().size())
2371  return emitError(
2372  "expected equal sizes for allocate and allocator variables");
2373 
2374  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2375  getReductionByref());
2376 }
2377 
2378 //===----------------------------------------------------------------------===//
2379 // SectionOp
2380 //===----------------------------------------------------------------------===//
2381 
2382 OperandRange SectionOp::getPrivateVars() {
2383  return getParentOp().getPrivateVars();
2384 }
2385 
2386 OperandRange SectionOp::getReductionVars() {
2387  return getParentOp().getReductionVars();
2388 }
2389 
2390 //===----------------------------------------------------------------------===//
2391 // SectionsOp
2392 //===----------------------------------------------------------------------===//
2393 
2394 void SectionsOp::build(OpBuilder &builder, OperationState &state,
2395  const SectionsOperands &clauses) {
2396  MLIRContext *ctx = builder.getContext();
2397  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2398  SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2399  clauses.nowait, /*private_vars=*/{},
2400  /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2401  clauses.reductionMod, clauses.reductionVars,
2402  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2403  makeArrayAttr(ctx, clauses.reductionSyms));
2404 }
2405 
2406 LogicalResult SectionsOp::verify() {
2407  if (getAllocateVars().size() != getAllocatorVars().size())
2408  return emitError(
2409  "expected equal sizes for allocate and allocator variables");
2410 
2411  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2412  getReductionByref());
2413 }
2414 
2415 LogicalResult SectionsOp::verifyRegions() {
2416  for (auto &inst : *getRegion().begin()) {
2417  if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2418  return emitOpError()
2419  << "expected omp.section op or terminator op inside region";
2420  }
2421  }
2422 
2423  return success();
2424 }
2425 
2426 //===----------------------------------------------------------------------===//
2427 // SingleOp
2428 //===----------------------------------------------------------------------===//
2429 
2430 void SingleOp::build(OpBuilder &builder, OperationState &state,
2431  const SingleOperands &clauses) {
2432  MLIRContext *ctx = builder.getContext();
2433  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2434  SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2435  clauses.copyprivateVars,
2436  makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2437  /*private_vars=*/{}, /*private_syms=*/nullptr,
2438  /*private_needs_barrier=*/nullptr);
2439 }
2440 
2441 LogicalResult SingleOp::verify() {
2442  // Check for allocate clause restrictions
2443  if (getAllocateVars().size() != getAllocatorVars().size())
2444  return emitError(
2445  "expected equal sizes for allocate and allocator variables");
2446 
2447  return verifyCopyprivateVarList(*this, getCopyprivateVars(),
2448  getCopyprivateSyms());
2449 }
2450 
2451 //===----------------------------------------------------------------------===//
2452 // WorkshareOp
2453 //===----------------------------------------------------------------------===//
2454 
2455 void WorkshareOp::build(OpBuilder &builder, OperationState &state,
2456  const WorkshareOperands &clauses) {
2457  WorkshareOp::build(builder, state, clauses.nowait);
2458 }
2459 
2460 //===----------------------------------------------------------------------===//
2461 // WorkshareLoopWrapperOp
2462 //===----------------------------------------------------------------------===//
2463 
2464 LogicalResult WorkshareLoopWrapperOp::verify() {
2465  if (!(*this)->getParentOfType<WorkshareOp>())
2466  return emitOpError() << "must be nested in an omp.workshare";
2467  return success();
2468 }
2469 
2470 LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2471  if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2472  getNestedWrapper())
2473  return emitOpError() << "expected to be a standalone loop wrapper";
2474 
2475  return success();
2476 }
2477 
2478 //===----------------------------------------------------------------------===//
2479 // LoopWrapperInterface
2480 //===----------------------------------------------------------------------===//
2481 
2482 LogicalResult LoopWrapperInterface::verifyImpl() {
2483  Operation *op = this->getOperation();
2484  if (!op->hasTrait<OpTrait::NoTerminator>() ||
2486  return emitOpError() << "loop wrapper must also have the `NoTerminator` "
2487  "and `SingleBlock` traits";
2488 
2489  if (op->getNumRegions() != 1)
2490  return emitOpError() << "loop wrapper does not contain exactly one region";
2491 
2492  Region &region = op->getRegion(0);
2493  if (range_size(region.getOps()) != 1)
2494  return emitOpError()
2495  << "loop wrapper does not contain exactly one nested op";
2496 
2497  Operation &firstOp = *region.op_begin();
2498  if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2499  return emitOpError() << "nested in loop wrapper is not another loop "
2500  "wrapper or `omp.loop_nest`";
2501 
2502  return success();
2503 }
2504 
2505 //===----------------------------------------------------------------------===//
2506 // LoopOp
2507 //===----------------------------------------------------------------------===//
2508 
2509 void LoopOp::build(OpBuilder &builder, OperationState &state,
2510  const LoopOperands &clauses) {
2511  MLIRContext *ctx = builder.getContext();
2512 
2513  LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2514  makeArrayAttr(ctx, clauses.privateSyms),
2515  clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2516  clauses.reductionMod, clauses.reductionVars,
2517  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2518  makeArrayAttr(ctx, clauses.reductionSyms));
2519 }
2520 
2521 LogicalResult LoopOp::verify() {
2522  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2523  getReductionByref());
2524 }
2525 
2526 LogicalResult LoopOp::verifyRegions() {
2527  if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2528  getNestedWrapper())
2529  return emitOpError() << "expected to be a standalone loop wrapper";
2530 
2531  return success();
2532 }
2533 
2534 //===----------------------------------------------------------------------===//
2535 // WsloopOp
2536 //===----------------------------------------------------------------------===//
2537 
2538 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2539  ArrayRef<NamedAttribute> attributes) {
2540  build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2541  /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
2542  /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
2543  /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
2544  /*private_needs_barrier=*/false,
2545  /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
2546  /*reduction_byref=*/nullptr,
2547  /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
2548  /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
2549  /*schedule_simd=*/false);
2550  state.addAttributes(attributes);
2551 }
2552 
2553 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2554  const WsloopOperands &clauses) {
2555  MLIRContext *ctx = builder.getContext();
2556  // TODO: Store clauses in op: allocateVars, allocatorVars
2557  WsloopOp::build(
2558  builder, state,
2559  /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
2560  clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
2561  clauses.ordered, clauses.privateVars,
2562  makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2563  clauses.reductionMod, clauses.reductionVars,
2564  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2565  makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
2566  clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
2567 }
2568 
2569 LogicalResult WsloopOp::verify() {
2570  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2571  getReductionByref());
2572 }
2573 
2574 LogicalResult WsloopOp::verifyRegions() {
2575  bool isCompositeChildLeaf =
2576  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2577 
2578  if (LoopWrapperInterface nested = getNestedWrapper()) {
2579  if (!isComposite())
2580  return emitError()
2581  << "'omp.composite' attribute missing from composite wrapper";
2582 
2583  // Check for the allowed leaf constructs that may appear in a composite
2584  // construct directly after DO/FOR.
2585  if (!isa<SimdOp>(nested))
2586  return emitError() << "only supported nested wrapper is 'omp.simd'";
2587 
2588  } else if (isComposite() && !isCompositeChildLeaf) {
2589  return emitError()
2590  << "'omp.composite' attribute present in non-composite wrapper";
2591  } else if (!isComposite() && isCompositeChildLeaf) {
2592  return emitError()
2593  << "'omp.composite' attribute missing from composite wrapper";
2594  }
2595 
2596  return success();
2597 }
2598 
2599 //===----------------------------------------------------------------------===//
2600 // Simd construct [2.9.3.1]
2601 //===----------------------------------------------------------------------===//
2602 
2603 void SimdOp::build(OpBuilder &builder, OperationState &state,
2604  const SimdOperands &clauses) {
2605  MLIRContext *ctx = builder.getContext();
2606  // TODO Store clauses in op: linearVars, linearStepVars
2607  SimdOp::build(builder, state, clauses.alignedVars,
2608  makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
2609  /*linear_vars=*/{}, /*linear_step_vars=*/{},
2610  clauses.nontemporalVars, clauses.order, clauses.orderMod,
2611  clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
2612  clauses.privateNeedsBarrier, clauses.reductionMod,
2613  clauses.reductionVars,
2614  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2615  makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
2616  clauses.simdlen);
2617 }
2618 
2619 LogicalResult SimdOp::verify() {
2620  if (getSimdlen().has_value() && getSafelen().has_value() &&
2621  getSimdlen().value() > getSafelen().value())
2622  return emitOpError()
2623  << "simdlen clause and safelen clause are both present, but the "
2624  "simdlen value is not less than or equal to safelen value";
2625 
2626  if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
2627  return failure();
2628 
2629  if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
2630  return failure();
2631 
2632  bool isCompositeChildLeaf =
2633  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2634 
2635  if (!isComposite() && isCompositeChildLeaf)
2636  return emitError()
2637  << "'omp.composite' attribute missing from composite wrapper";
2638 
2639  if (isComposite() && !isCompositeChildLeaf)
2640  return emitError()
2641  << "'omp.composite' attribute present in non-composite wrapper";
2642 
2643  return success();
2644 }
2645 
2646 LogicalResult SimdOp::verifyRegions() {
2647  if (getNestedWrapper())
2648  return emitOpError() << "must wrap an 'omp.loop_nest' directly";
2649 
2650  return success();
2651 }
2652 
2653 //===----------------------------------------------------------------------===//
2654 // Distribute construct [2.9.4.1]
2655 //===----------------------------------------------------------------------===//
2656 
2657 void DistributeOp::build(OpBuilder &builder, OperationState &state,
2658  const DistributeOperands &clauses) {
2659  DistributeOp::build(builder, state, clauses.allocateVars,
2660  clauses.allocatorVars, clauses.distScheduleStatic,
2661  clauses.distScheduleChunkSize, clauses.order,
2662  clauses.orderMod, clauses.privateVars,
2663  makeArrayAttr(builder.getContext(), clauses.privateSyms),
2664  clauses.privateNeedsBarrier);
2665 }
2666 
2667 LogicalResult DistributeOp::verify() {
2668  if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2669  return emitOpError() << "chunk size set without "
2670  "dist_schedule_static being present";
2671 
2672  if (getAllocateVars().size() != getAllocatorVars().size())
2673  return emitError(
2674  "expected equal sizes for allocate and allocator variables");
2675 
2676  return success();
2677 }
2678 
2679 LogicalResult DistributeOp::verifyRegions() {
2680  if (LoopWrapperInterface nested = getNestedWrapper()) {
2681  if (!isComposite())
2682  return emitError()
2683  << "'omp.composite' attribute missing from composite wrapper";
2684  // Check for the allowed leaf constructs that may appear in a composite
2685  // construct directly after DISTRIBUTE.
2686  if (isa<WsloopOp>(nested)) {
2687  Operation *parentOp = (*this)->getParentOp();
2688  if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
2689  !cast<ComposableOpInterface>(parentOp).isComposite()) {
2690  return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
2691  "when a composite 'omp.parallel' is the direct "
2692  "parent";
2693  }
2694  } else if (!isa<SimdOp>(nested))
2695  return emitError() << "only supported nested wrappers are 'omp.simd' and "
2696  "'omp.wsloop'";
2697  } else if (isComposite()) {
2698  return emitError()
2699  << "'omp.composite' attribute present in non-composite wrapper";
2700  }
2701 
2702  return success();
2703 }
2704 
2705 //===----------------------------------------------------------------------===//
2706 // DeclareMapperOp / DeclareMapperInfoOp
2707 //===----------------------------------------------------------------------===//
2708 
2709 LogicalResult DeclareMapperInfoOp::verify() {
2710  return verifyMapClause(*this, getMapVars());
2711 }
2712 
2713 LogicalResult DeclareMapperOp::verifyRegions() {
2714  if (!llvm::isa_and_present<DeclareMapperInfoOp>(
2715  getRegion().getBlocks().front().getTerminator()))
2716  return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
2717 
2718  return success();
2719 }
2720 
2721 //===----------------------------------------------------------------------===//
2722 // DeclareReductionOp
2723 //===----------------------------------------------------------------------===//
2724 
2725 LogicalResult DeclareReductionOp::verifyRegions() {
2726  if (!getAllocRegion().empty()) {
2727  for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
2728  if (yieldOp.getResults().size() != 1 ||
2729  yieldOp.getResults().getTypes()[0] != getType())
2730  return emitOpError() << "expects alloc region to yield a value "
2731  "of the reduction type";
2732  }
2733  }
2734 
2735  if (getInitializerRegion().empty())
2736  return emitOpError() << "expects non-empty initializer region";
2737  Block &initializerEntryBlock = getInitializerRegion().front();
2738 
2739  if (initializerEntryBlock.getNumArguments() == 1) {
2740  if (!getAllocRegion().empty())
2741  return emitOpError() << "expects two arguments to the initializer region "
2742  "when an allocation region is used";
2743  } else if (initializerEntryBlock.getNumArguments() == 2) {
2744  if (getAllocRegion().empty())
2745  return emitOpError() << "expects one argument to the initializer region "
2746  "when no allocation region is used";
2747  } else {
2748  return emitOpError()
2749  << "expects one or two arguments to the initializer region";
2750  }
2751 
2752  for (mlir::Value arg : initializerEntryBlock.getArguments())
2753  if (arg.getType() != getType())
2754  return emitOpError() << "expects initializer region argument to match "
2755  "the reduction type";
2756 
2757  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
2758  if (yieldOp.getResults().size() != 1 ||
2759  yieldOp.getResults().getTypes()[0] != getType())
2760  return emitOpError() << "expects initializer region to yield a value "
2761  "of the reduction type";
2762  }
2763 
2764  if (getReductionRegion().empty())
2765  return emitOpError() << "expects non-empty reduction region";
2766  Block &reductionEntryBlock = getReductionRegion().front();
2767  if (reductionEntryBlock.getNumArguments() != 2 ||
2768  reductionEntryBlock.getArgumentTypes()[0] !=
2769  reductionEntryBlock.getArgumentTypes()[1] ||
2770  reductionEntryBlock.getArgumentTypes()[0] != getType())
2771  return emitOpError() << "expects reduction region with two arguments of "
2772  "the reduction type";
2773  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
2774  if (yieldOp.getResults().size() != 1 ||
2775  yieldOp.getResults().getTypes()[0] != getType())
2776  return emitOpError() << "expects reduction region to yield a value "
2777  "of the reduction type";
2778  }
2779 
2780  if (!getAtomicReductionRegion().empty()) {
2781  Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
2782  if (atomicReductionEntryBlock.getNumArguments() != 2 ||
2783  atomicReductionEntryBlock.getArgumentTypes()[0] !=
2784  atomicReductionEntryBlock.getArgumentTypes()[1])
2785  return emitOpError() << "expects atomic reduction region with two "
2786  "arguments of the same type";
2787  auto ptrType = llvm::dyn_cast<PointerLikeType>(
2788  atomicReductionEntryBlock.getArgumentTypes()[0]);
2789  if (!ptrType ||
2790  (ptrType.getElementType() && ptrType.getElementType() != getType()))
2791  return emitOpError() << "expects atomic reduction region arguments to "
2792  "be accumulators containing the reduction type";
2793  }
2794 
2795  if (getCleanupRegion().empty())
2796  return success();
2797  Block &cleanupEntryBlock = getCleanupRegion().front();
2798  if (cleanupEntryBlock.getNumArguments() != 1 ||
2799  cleanupEntryBlock.getArgument(0).getType() != getType())
2800  return emitOpError() << "expects cleanup region with one argument "
2801  "of the reduction type";
2802 
2803  return success();
2804 }
2805 
2806 //===----------------------------------------------------------------------===//
2807 // TaskOp
2808 //===----------------------------------------------------------------------===//
2809 
2810 void TaskOp::build(OpBuilder &builder, OperationState &state,
2811  const TaskOperands &clauses) {
2812  MLIRContext *ctx = builder.getContext();
2813  TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2814  makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2815  clauses.final, clauses.ifExpr, clauses.inReductionVars,
2816  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2817  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2818  clauses.priority, /*private_vars=*/clauses.privateVars,
2819  /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
2820  clauses.privateNeedsBarrier, clauses.untied,
2821  clauses.eventHandle);
2822 }
2823 
2824 LogicalResult TaskOp::verify() {
2825  LogicalResult verifyDependVars =
2826  verifyDependVarList(*this, getDependKinds(), getDependVars());
2827  return failed(verifyDependVars)
2828  ? verifyDependVars
2829  : verifyReductionVarList(*this, getInReductionSyms(),
2830  getInReductionVars(),
2831  getInReductionByref());
2832 }
2833 
2834 //===----------------------------------------------------------------------===//
2835 // TaskgroupOp
2836 //===----------------------------------------------------------------------===//
2837 
2838 void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
2839  const TaskgroupOperands &clauses) {
2840  MLIRContext *ctx = builder.getContext();
2841  TaskgroupOp::build(builder, state, clauses.allocateVars,
2842  clauses.allocatorVars, clauses.taskReductionVars,
2843  makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
2844  makeArrayAttr(ctx, clauses.taskReductionSyms));
2845 }
2846 
2847 LogicalResult TaskgroupOp::verify() {
2848  return verifyReductionVarList(*this, getTaskReductionSyms(),
2849  getTaskReductionVars(),
2850  getTaskReductionByref());
2851 }
2852 
2853 //===----------------------------------------------------------------------===//
2854 // TaskloopOp
2855 //===----------------------------------------------------------------------===//
2856 
2857 void TaskloopOp::build(OpBuilder &builder, OperationState &state,
2858  const TaskloopOperands &clauses) {
2859  MLIRContext *ctx = builder.getContext();
2860  TaskloopOp::build(
2861  builder, state, clauses.allocateVars, clauses.allocatorVars,
2862  clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
2863  clauses.inReductionVars,
2864  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2865  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2866  clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
2867  /*private_vars=*/clauses.privateVars,
2868  /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
2869  clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
2870  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2871  makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
2872 }
2873 
2874 LogicalResult TaskloopOp::verify() {
2875  if (getAllocateVars().size() != getAllocatorVars().size())
2876  return emitError(
2877  "expected equal sizes for allocate and allocator variables");
2878  if (failed(verifyReductionVarList(*this, getReductionSyms(),
2879  getReductionVars(), getReductionByref())) ||
2880  failed(verifyReductionVarList(*this, getInReductionSyms(),
2881  getInReductionVars(),
2882  getInReductionByref())))
2883  return failure();
2884 
2885  if (!getReductionVars().empty() && getNogroup())
2886  return emitError("if a reduction clause is present on the taskloop "
2887  "directive, the nogroup clause must not be specified");
2888  for (auto var : getReductionVars()) {
2889  if (llvm::is_contained(getInReductionVars(), var))
2890  return emitError("the same list item cannot appear in both a reduction "
2891  "and an in_reduction clause");
2892  }
2893 
2894  if (getGrainsize() && getNumTasks()) {
2895  return emitError(
2896  "the grainsize clause and num_tasks clause are mutually exclusive and "
2897  "may not appear on the same taskloop directive");
2898  }
2899 
2900  return success();
2901 }
2902 
2903 LogicalResult TaskloopOp::verifyRegions() {
2904  if (LoopWrapperInterface nested = getNestedWrapper()) {
2905  if (!isComposite())
2906  return emitError()
2907  << "'omp.composite' attribute missing from composite wrapper";
2908 
2909  // Check for the allowed leaf constructs that may appear in a composite
2910  // construct directly after TASKLOOP.
2911  if (!isa<SimdOp>(nested))
2912  return emitError() << "only supported nested wrapper is 'omp.simd'";
2913  } else if (isComposite()) {
2914  return emitError()
2915  << "'omp.composite' attribute present in non-composite wrapper";
2916  }
2917 
2918  return success();
2919 }
2920 
2921 //===----------------------------------------------------------------------===//
2922 // LoopNestOp
2923 //===----------------------------------------------------------------------===//
2924 
2925 ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
2926  // Parse an opening `(` followed by induction variables followed by `)`
2929  Type loopVarType;
2930  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
2931  parser.parseColonType(loopVarType) ||
2932  // Parse loop bounds.
2933  parser.parseEqual() ||
2934  parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
2935  parser.parseKeyword("to") ||
2936  parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
2937  return failure();
2938 
2939  for (auto &iv : ivs)
2940  iv.type = loopVarType;
2941 
2942  // Parse "inclusive" flag.
2943  if (succeeded(parser.parseOptionalKeyword("inclusive")))
2944  result.addAttribute("loop_inclusive",
2945  UnitAttr::get(parser.getBuilder().getContext()));
2946 
2947  // Parse step values.
2949  if (parser.parseKeyword("step") ||
2950  parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
2951  return failure();
2952 
2953  // Parse the body.
2954  Region *region = result.addRegion();
2955  if (parser.parseRegion(*region, ivs))
2956  return failure();
2957 
2958  // Resolve operands.
2959  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
2960  parser.resolveOperands(ubs, loopVarType, result.operands) ||
2961  parser.resolveOperands(steps, loopVarType, result.operands))
2962  return failure();
2963 
2964  // Parse the optional attribute list.
2965  return parser.parseOptionalAttrDict(result.attributes);
2966 }
2967 
2969  Region &region = getRegion();
2970  auto args = region.getArguments();
2971  p << " (" << args << ") : " << args[0].getType() << " = ("
2972  << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
2973  if (getLoopInclusive())
2974  p << "inclusive ";
2975  p << "step (" << getLoopSteps() << ") ";
2976  p.printRegion(region, /*printEntryBlockArgs=*/false);
2977 }
2978 
2979 void LoopNestOp::build(OpBuilder &builder, OperationState &state,
2980  const LoopNestOperands &clauses) {
2981  LoopNestOp::build(builder, state, clauses.loopLowerBounds,
2982  clauses.loopUpperBounds, clauses.loopSteps,
2983  clauses.loopInclusive);
2984 }
2985 
2986 LogicalResult LoopNestOp::verify() {
2987  if (getLoopLowerBounds().empty())
2988  return emitOpError() << "must represent at least one loop";
2989 
2990  if (getLoopLowerBounds().size() != getIVs().size())
2991  return emitOpError() << "number of range arguments and IVs do not match";
2992 
2993  for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
2994  if (lb.getType() != iv.getType())
2995  return emitOpError()
2996  << "range argument type does not match corresponding IV type";
2997  }
2998 
2999  if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3000  return emitOpError() << "expects parent op to be a loop wrapper";
3001 
3002  return success();
3003 }
3004 
3005 void LoopNestOp::gatherWrappers(
3007  Operation *parent = (*this)->getParentOp();
3008  while (auto wrapper =
3009  llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3010  wrappers.push_back(wrapper);
3011  parent = parent->getParentOp();
3012  }
3013 }
3014 
3015 //===----------------------------------------------------------------------===//
3016 // Critical construct (2.17.1)
3017 //===----------------------------------------------------------------------===//
3018 
3019 void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
3020  const CriticalDeclareOperands &clauses) {
3021  CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
3022 }
3023 
3024 LogicalResult CriticalDeclareOp::verify() {
3025  return verifySynchronizationHint(*this, getHint());
3026 }
3027 
3028 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3029  if (getNameAttr()) {
3030  SymbolRefAttr symbolRef = getNameAttr();
3031  auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
3032  *this, symbolRef);
3033  if (!decl) {
3034  return emitOpError() << "expected symbol reference " << symbolRef
3035  << " to point to a critical declaration";
3036  }
3037  }
3038 
3039  return success();
3040 }
3041 
3042 //===----------------------------------------------------------------------===//
3043 // Ordered construct
3044 //===----------------------------------------------------------------------===//
3045 
3046 static LogicalResult verifyOrderedParent(Operation &op) {
3047  bool hasRegion = op.getNumRegions() > 0;
3048  auto loopOp = op.getParentOfType<LoopNestOp>();
3049  if (!loopOp) {
3050  if (hasRegion)
3051  return success();
3052 
3053  // TODO: Consider if this needs to be the case only for the standalone
3054  // variant of the ordered construct.
3055  return op.emitOpError() << "must be nested inside of a loop";
3056  }
3057 
3058  Operation *wrapper = loopOp->getParentOp();
3059  if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
3060  IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
3061  if (!orderedAttr)
3062  return op.emitOpError() << "the enclosing worksharing-loop region must "
3063  "have an ordered clause";
3064 
3065  if (hasRegion && orderedAttr.getInt() != 0)
3066  return op.emitOpError() << "the enclosing loop's ordered clause must not "
3067  "have a parameter present";
3068 
3069  if (!hasRegion && orderedAttr.getInt() == 0)
3070  return op.emitOpError() << "the enclosing loop's ordered clause must "
3071  "have a parameter present";
3072  } else if (!isa<SimdOp>(wrapper)) {
3073  return op.emitOpError() << "must be nested inside of a worksharing, simd "
3074  "or worksharing simd loop";
3075  }
3076  return success();
3077 }
3078 
3079 void OrderedOp::build(OpBuilder &builder, OperationState &state,
3080  const OrderedOperands &clauses) {
3081  OrderedOp::build(builder, state, clauses.doacrossDependType,
3082  clauses.doacrossNumLoops, clauses.doacrossDependVars);
3083 }
3084 
3085 LogicalResult OrderedOp::verify() {
3086  if (failed(verifyOrderedParent(**this)))
3087  return failure();
3088 
3089  auto wrapper = (*this)->getParentOfType<WsloopOp>();
3090  if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
3091  return emitOpError() << "number of variables in depend clause does not "
3092  << "match number of iteration variables in the "
3093  << "doacross loop";
3094 
3095  return success();
3096 }
3097 
3098 void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
3099  const OrderedRegionOperands &clauses) {
3100  OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
3101 }
3102 
3103 LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
3104 
3105 //===----------------------------------------------------------------------===//
3106 // TaskwaitOp
3107 //===----------------------------------------------------------------------===//
3108 
3109 void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
3110  const TaskwaitOperands &clauses) {
3111  // TODO Store clauses in op: dependKinds, dependVars, nowait.
3112  TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
3113  /*depend_vars=*/{}, /*nowait=*/nullptr);
3114 }
3115 
3116 //===----------------------------------------------------------------------===//
3117 // Verifier for AtomicReadOp
3118 //===----------------------------------------------------------------------===//
3119 
3120 LogicalResult AtomicReadOp::verify() {
3121  if (verifyCommon().failed())
3122  return mlir::failure();
3123 
3124  if (auto mo = getMemoryOrder()) {
3125  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3126  *mo == ClauseMemoryOrderKind::Release) {
3127  return emitError(
3128  "memory-order must not be acq_rel or release for atomic reads");
3129  }
3130  }
3131  return verifySynchronizationHint(*this, getHint());
3132 }
3133 
3134 //===----------------------------------------------------------------------===//
3135 // Verifier for AtomicWriteOp
3136 //===----------------------------------------------------------------------===//
3137 
3138 LogicalResult AtomicWriteOp::verify() {
3139  if (verifyCommon().failed())
3140  return mlir::failure();
3141 
3142  if (auto mo = getMemoryOrder()) {
3143  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3144  *mo == ClauseMemoryOrderKind::Acquire) {
3145  return emitError(
3146  "memory-order must not be acq_rel or acquire for atomic writes");
3147  }
3148  }
3149  return verifySynchronizationHint(*this, getHint());
3150 }
3151 
3152 //===----------------------------------------------------------------------===//
3153 // Verifier for AtomicUpdateOp
3154 //===----------------------------------------------------------------------===//
3155 
3156 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3157  PatternRewriter &rewriter) {
3158  if (op.isNoOp()) {
3159  rewriter.eraseOp(op);
3160  return success();
3161  }
3162  if (Value writeVal = op.getWriteOpVal()) {
3163  rewriter.replaceOpWithNewOp<AtomicWriteOp>(
3164  op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
3165  return success();
3166  }
3167  return failure();
3168 }
3169 
3170 LogicalResult AtomicUpdateOp::verify() {
3171  if (verifyCommon().failed())
3172  return mlir::failure();
3173 
3174  if (auto mo = getMemoryOrder()) {
3175  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3176  *mo == ClauseMemoryOrderKind::Acquire) {
3177  return emitError(
3178  "memory-order must not be acq_rel or acquire for atomic updates");
3179  }
3180  }
3181 
3182  return verifySynchronizationHint(*this, getHint());
3183 }
3184 
3185 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3186 
3187 //===----------------------------------------------------------------------===//
3188 // Verifier for AtomicCaptureOp
3189 //===----------------------------------------------------------------------===//
3190 
3191 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3192  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3193  return op;
3194  return dyn_cast<AtomicReadOp>(getSecondOp());
3195 }
3196 
3197 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3198  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3199  return op;
3200  return dyn_cast<AtomicWriteOp>(getSecondOp());
3201 }
3202 
3203 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3204  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3205  return op;
3206  return dyn_cast<AtomicUpdateOp>(getSecondOp());
3207 }
3208 
3209 LogicalResult AtomicCaptureOp::verify() {
3210  return verifySynchronizationHint(*this, getHint());
3211 }
3212 
3213 LogicalResult AtomicCaptureOp::verifyRegions() {
3214  if (verifyRegionsCommon().failed())
3215  return mlir::failure();
3216 
3217  if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
3218  return emitOpError(
3219  "operations inside capture region must not have hint clause");
3220 
3221  if (getFirstOp()->getAttr("memory_order") ||
3222  getSecondOp()->getAttr("memory_order"))
3223  return emitOpError(
3224  "operations inside capture region must not have memory_order clause");
3225  return success();
3226 }
3227 
3228 //===----------------------------------------------------------------------===//
3229 // CancelOp
3230 //===----------------------------------------------------------------------===//
3231 
3232 void CancelOp::build(OpBuilder &builder, OperationState &state,
3233  const CancelOperands &clauses) {
3234  CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
3235 }
3236 
3238  Operation *parent = thisOp->getParentOp();
3239  while (parent) {
3240  if (parent->getDialect() == thisOp->getDialect())
3241  return parent;
3242  parent = parent->getParentOp();
3243  }
3244  return nullptr;
3245 }
3246 
3247 LogicalResult CancelOp::verify() {
3248  ClauseCancellationConstructType cct = getCancelDirective();
3249  // The next OpenMP operation in the chain of parents
3250  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
3251  if (!structuralParent)
3252  return emitOpError() << "Orphaned cancel construct";
3253 
3254  if ((cct == ClauseCancellationConstructType::Parallel) &&
3255  !mlir::isa<ParallelOp>(structuralParent)) {
3256  return emitOpError() << "cancel parallel must appear "
3257  << "inside a parallel region";
3258  }
3259  if (cct == ClauseCancellationConstructType::Loop) {
3260  // structural parent will be omp.loop_nest, directly nested inside
3261  // omp.wsloop
3262  auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
3263 
3264  if (!wsloopOp) {
3265  return emitOpError()
3266  << "cancel loop must appear inside a worksharing-loop region";
3267  }
3268  if (wsloopOp.getNowaitAttr()) {
3269  return emitError() << "A worksharing construct that is canceled "
3270  << "must not have a nowait clause";
3271  }
3272  if (wsloopOp.getOrderedAttr()) {
3273  return emitError() << "A worksharing construct that is canceled "
3274  << "must not have an ordered clause";
3275  }
3276 
3277  } else if (cct == ClauseCancellationConstructType::Sections) {
3278  // structural parent will be an omp.section, directly nested inside
3279  // omp.sections
3280  auto sectionsOp =
3281  mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
3282  if (!sectionsOp) {
3283  return emitOpError() << "cancel sections must appear "
3284  << "inside a sections region";
3285  }
3286  if (sectionsOp.getNowait()) {
3287  return emitError() << "A sections construct that is canceled "
3288  << "must not have a nowait clause";
3289  }
3290  }
3291  if ((cct == ClauseCancellationConstructType::Taskgroup) &&
3292  (!mlir::isa<omp::TaskOp>(structuralParent) &&
3293  !mlir::isa<omp::TaskloopOp>(structuralParent->getParentOp()))) {
3294  return emitOpError() << "cancel taskgroup must appear "
3295  << "inside a task region";
3296  }
3297  return success();
3298 }
3299 
3300 //===----------------------------------------------------------------------===//
3301 // CancellationPointOp
3302 //===----------------------------------------------------------------------===//
3303 
3304 void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
3305  const CancellationPointOperands &clauses) {
3306  CancellationPointOp::build(builder, state, clauses.cancelDirective);
3307 }
3308 
3309 LogicalResult CancellationPointOp::verify() {
3310  ClauseCancellationConstructType cct = getCancelDirective();
3311  // The next OpenMP operation in the chain of parents
3312  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
3313  if (!structuralParent)
3314  return emitOpError() << "Orphaned cancellation point";
3315 
3316  if ((cct == ClauseCancellationConstructType::Parallel) &&
3317  !mlir::isa<ParallelOp>(structuralParent)) {
3318  return emitOpError() << "cancellation point parallel must appear "
3319  << "inside a parallel region";
3320  }
3321  // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
3322  // find the wsloop
3323  if ((cct == ClauseCancellationConstructType::Loop) &&
3324  !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
3325  return emitOpError() << "cancellation point loop must appear "
3326  << "inside a worksharing-loop region";
3327  }
3328  if ((cct == ClauseCancellationConstructType::Sections) &&
3329  !mlir::isa<omp::SectionOp>(structuralParent)) {
3330  return emitOpError() << "cancellation point sections must appear "
3331  << "inside a sections region";
3332  }
3333  if ((cct == ClauseCancellationConstructType::Taskgroup) &&
3334  !mlir::isa<omp::TaskOp>(structuralParent)) {
3335  return emitOpError() << "cancellation point taskgroup must appear "
3336  << "inside a task region";
3337  }
3338  return success();
3339 }
3340 
3341 //===----------------------------------------------------------------------===//
3342 // MapBoundsOp
3343 //===----------------------------------------------------------------------===//
3344 
3345 LogicalResult MapBoundsOp::verify() {
3346  auto extent = getExtent();
3347  auto upperbound = getUpperBound();
3348  if (!extent && !upperbound)
3349  return emitError("expected extent or upperbound.");
3350  return success();
3351 }
3352 
3353 void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3354  TypeRange /*result_types*/, StringAttr symName,
3355  TypeAttr type) {
3356  PrivateClauseOp::build(
3357  odsBuilder, odsState, symName, type,
3359  DataSharingClauseType::Private));
3360 }
3361 
3362 LogicalResult PrivateClauseOp::verifyRegions() {
3363  Type argType = getArgType();
3364  auto verifyTerminator = [&](Operation *terminator,
3365  bool yieldsValue) -> LogicalResult {
3366  if (!terminator->getBlock()->getSuccessors().empty())
3367  return success();
3368 
3369  if (!llvm::isa<YieldOp>(terminator))
3370  return mlir::emitError(terminator->getLoc())
3371  << "expected exit block terminator to be an `omp.yield` op.";
3372 
3373  YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
3374  TypeRange yieldedTypes = yieldOp.getResults().getTypes();
3375 
3376  if (!yieldsValue) {
3377  if (yieldedTypes.empty())
3378  return success();
3379 
3380  return mlir::emitError(terminator->getLoc())
3381  << "Did not expect any values to be yielded.";
3382  }
3383 
3384  if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
3385  return success();
3386 
3387  auto error = mlir::emitError(yieldOp.getLoc())
3388  << "Invalid yielded value. Expected type: " << argType
3389  << ", got: ";
3390 
3391  if (yieldedTypes.empty())
3392  error << "None";
3393  else
3394  error << yieldedTypes;
3395 
3396  return error;
3397  };
3398 
3399  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
3400  StringRef regionName,
3401  bool yieldsValue) -> LogicalResult {
3402  assert(!region.empty());
3403 
3404  if (region.getNumArguments() != expectedNumArgs)
3405  return mlir::emitError(region.getLoc())
3406  << "`" << regionName << "`: "
3407  << "expected " << expectedNumArgs
3408  << " region arguments, got: " << region.getNumArguments();
3409 
3410  for (Block &block : region) {
3411  // MLIR will verify the absence of the terminator for us.
3412  if (!block.mightHaveTerminator())
3413  continue;
3414 
3415  if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
3416  return failure();
3417  }
3418 
3419  return success();
3420  };
3421 
3422  // Ensure all of the region arguments have the same type
3423  for (Region *region : getRegions())
3424  for (Type ty : region->getArgumentTypes())
3425  if (ty != argType)
3426  return emitError() << "Region argument type mismatch: got " << ty
3427  << " expected " << argType << ".";
3428 
3429  mlir::Region &initRegion = getInitRegion();
3430  if (!initRegion.empty() &&
3431  failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
3432  /*yieldsValue=*/true)))
3433  return failure();
3434 
3435  DataSharingClauseType dsType = getDataSharingType();
3436 
3437  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
3438  return emitError("`private` clauses do not require a `copy` region.");
3439 
3440  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
3441  return emitError(
3442  "`firstprivate` clauses require at least a `copy` region.");
3443 
3444  if (dsType == DataSharingClauseType::FirstPrivate &&
3445  failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
3446  /*yieldsValue=*/true)))
3447  return failure();
3448 
3449  if (!getDeallocRegion().empty() &&
3450  failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
3451  /*yieldsValue=*/false)))
3452  return failure();
3453 
3454  return success();
3455 }
3456 
3457 //===----------------------------------------------------------------------===//
3458 // Spec 5.2: Masked construct (10.5)
3459 //===----------------------------------------------------------------------===//
3460 
3461 void MaskedOp::build(OpBuilder &builder, OperationState &state,
3462  const MaskedOperands &clauses) {
3463  MaskedOp::build(builder, state, clauses.filteredThreadId);
3464 }
3465 
3466 //===----------------------------------------------------------------------===//
3467 // Spec 5.2: Scan construct (5.6)
3468 //===----------------------------------------------------------------------===//
3469 
3470 void ScanOp::build(OpBuilder &builder, OperationState &state,
3471  const ScanOperands &clauses) {
3472  ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
3473 }
3474 
3475 LogicalResult ScanOp::verify() {
3476  if (hasExclusiveVars() == hasInclusiveVars())
3477  return emitError(
3478  "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
3479  if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
3480  if (parentWsLoopOp.getReductionModAttr() &&
3481  parentWsLoopOp.getReductionModAttr().getValue() ==
3482  ReductionModifier::inscan)
3483  return success();
3484  }
3485  if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
3486  if (parentSimdOp.getReductionModAttr() &&
3487  parentSimdOp.getReductionModAttr().getValue() ==
3488  ReductionModifier::inscan)
3489  return success();
3490  }
3491  return emitError("SCAN directive needs to be enclosed within a parent "
3492  "worksharing loop construct or SIMD construct with INSCAN "
3493  "reduction modifier");
3494 }
3495 
3496 #define GET_ATTRDEF_CLASSES
3497 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
3498 
3499 #define GET_OP_CLASSES
3500 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
3501 
3502 #define GET_TYPEDEF_CLASSES
3503 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:753
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1293
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:63
static MLIRContext * getContext(OpFoldResult val)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static Operation * getParentInSameDialect(Operation *thisOp)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseMembersIndex(OpAsmParser &parser, ArrayAttr &membersIdx)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
SuccessorRange getSuccessors()
Definition: Block.h:267
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:226
IntegerType getI64Type()
Definition: Builders.cpp:67
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:69
MLIRContext * getContext() const
Definition: Builders.h:55
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:205
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:772
This class indicates that the regions associated with this op don't have terminators.
Definition: OpDefinition.h:768
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
BlockArgListType getArguments()
Definition: Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Definition: Dominance.cpp:307
Runtime
Potential runtimes for AMD GPU kernels.
Definition: Runtimes.h:15
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:22
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
Definition: OpDefinition.h:880
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.