MLIR  20.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Attributes.h"
24 
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/BitVector.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/STLForwardCompat.h"
29 #include "llvm/ADT/SmallString.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Frontend/OpenMP/OMPConstants.h"
34 #include <cstddef>
35 #include <iterator>
36 #include <optional>
37 #include <variant>
38 
39 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
40 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
41 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
42 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
43 
44 using namespace mlir;
45 using namespace mlir::omp;
46 
/// Return an ArrayAttr holding `attrs`, or a null ArrayAttr when `attrs` is
/// empty.
static ArrayAttr makeArrayAttr(MLIRContext *context,
  return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
}
51 
/// Return a DenseBoolArrayAttr holding `boolArray`, or a null attribute when
/// `boolArray` is empty. The clause parser uses this to materialize per-operand
/// `byref` flags.
static DenseBoolArrayAttr
  return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
}
56 
57 namespace {
58 struct MemRefPointerLikeModel
59  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
60  MemRefType> {
61  Type getElementType(Type pointer) const {
62  return llvm::cast<MemRefType>(pointer).getElementType();
63  }
64 };
65 
66 struct LLVMPointerPointerLikeModel
67  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
68  LLVM::LLVMPointerType> {
69  Type getElementType(Type pointer) const { return Type(); }
70 };
71 } // namespace
72 
/// Register the OpenMP dialect's operations, attributes and types. Also attach
/// the external interface models this file provides to types and operations
/// owned by other dialects.
void OpenMPDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
      >();

  // Declare ConvertToLLVMPatternInterface as a promised interface of this
  // dialect. NOTE(review): the implementation is presumably registered
  // elsewhere (e.g. by a conversion extension) — confirm.
  declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();

  // Let memref and LLVM pointer types be used where an OpenMP
  // PointerLikeType is expected.
  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
      *getContext());

  // Attach the default offload module interface to the module op so that
  // offload functionality can be accessed through it.
  mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
      *getContext());

  // Attach default declare target interfaces to operations which can be marked
  // as declare target (global operations and functions/subroutines in the
  // dialects that Fortran, or other languages that lower to MLIR, translate
  // to).
  mlir::LLVM::GlobalOp::attachInterface<
      *getContext());
  mlir::LLVM::LLVMFuncOp::attachInterface<
      *getContext());
  mlir::func::FuncOp::attachInterface<
}
110 
111 //===----------------------------------------------------------------------===//
112 // Parser and printer for Allocate Clause
113 //===----------------------------------------------------------------------===//
114 
/// Parse an allocate clause with allocators and a list of operands with types.
///
/// allocate-operand-list ::= allocate-operand |
///                           allocate-operand `,` allocate-operand-list
/// allocate-operand ::= ssa-id-and-type `->` ssa-id-and-type
/// ssa-id-and-type ::= ssa-id `:` type
///
/// In each allocate-operand, the value before `->` is appended to
/// `allocatorVars`/`allocatorTypes` and the value after it is appended to
/// `allocateVars`/`allocateTypes`.
static ParseResult parseAllocateAndAllocator(
    OpAsmParser &parser,
    SmallVectorImpl<Type> &allocateTypes,
    SmallVectorImpl<Type> &allocatorTypes) {

  return parser.parseCommaSeparatedList([&]() {
    Type type;
    // Allocator: `%allocator : type`.
    if (parser.parseOperand(operand) || parser.parseColonType(type))
      return failure();
    allocatorVars.push_back(operand);
    allocatorTypes.push_back(type);
    if (parser.parseArrow())
      return failure();
    // Allocated variable: `%var : type`.
    if (parser.parseOperand(operand) || parser.parseColonType(type))
      return failure();

    allocateVars.push_back(operand);
    allocateTypes.push_back(type);
    return success();
  });
}
145 
/// Print allocate clause as comma-separated `allocator : type -> var : type`
/// entries, mirroring parseAllocateAndAllocator.
///
/// NOTE(review): `allocatorVars`/`allocatorTypes` are indexed with the indices
/// of `allocateVars`, so they must be at least as long. Nothing here checks
/// that.
                                      OperandRange allocateVars,
                                      TypeRange allocateTypes,
                                      OperandRange allocatorVars,
                                      TypeRange allocatorTypes) {
  for (unsigned i = 0; i < allocateVars.size(); ++i) {
    // No separator after the last entry.
    std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
    p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
    p << allocateVars[i] << " : " << allocateTypes[i] << separator;
  }
}
158 
159 //===----------------------------------------------------------------------===//
160 // Parser and printer for a clause attribute (StringEnumAttr)
161 //===----------------------------------------------------------------------===//
162 
163 template <typename ClauseAttr>
164 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
165  using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
166  StringRef enumStr;
167  SMLoc loc = parser.getCurrentLocation();
168  if (parser.parseKeyword(&enumStr))
169  return failure();
170  if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
171  attr = ClauseAttr::get(parser.getContext(), *enumValue);
172  return success();
173  }
174  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
175 }
176 
177 template <typename ClauseAttr>
178 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
179  p << stringifyEnum(attr.getValue());
180 }
181 
182 //===----------------------------------------------------------------------===//
183 // Parser and printer for Linear Clause
184 //===----------------------------------------------------------------------===//
185 
/// linear ::= `linear` `(` linear-list `)`
/// linear-list := linear-val | linear-val `,` linear-list
/// linear-val := ssa-id `=` ssa-id `:` type
///
/// For each entry, the operand before `=` is appended to `linearVars`, the
/// operand after `=` to `linearStepVars`, and the type after `:` to
/// `linearTypes`.
static ParseResult parseLinearClause(
    OpAsmParser &parser,
    SmallVectorImpl<Type> &linearTypes,
  return parser.parseCommaSeparatedList([&]() {
    Type type;
    if (parser.parseOperand(var) || parser.parseEqual() ||
        parser.parseOperand(stepVar) || parser.parseColonType(type))
      return failure();

    linearVars.push_back(var);
    linearTypes.push_back(type);
    linearStepVars.push_back(stepVar);
    return success();
  });
}
208 
/// Print Linear Clause as comma-separated `var = step : type` entries.
///
/// The printed type is taken from each variable; `linearTypes` is not read.
/// A step is printed only when `linearStepVars` has an entry at that index.
/// NOTE(review): parseLinearClause always requires `= step`, so an entry
/// printed without a step would not parse back.
                              ValueRange linearVars, TypeRange linearTypes,
                              ValueRange linearStepVars) {
  size_t linearVarsSize = linearVars.size();
  for (unsigned i = 0; i < linearVarsSize; ++i) {
    std::string separator = i == linearVarsSize - 1 ? "" : ", ";
    p << linearVars[i];
    if (linearStepVars.size() > i)
      p << " = " << linearStepVars[i];
    p << " : " << linearVars[i].getType() << separator;
  }
}
222 
223 //===----------------------------------------------------------------------===//
224 // Verifier for Nontemporal Clause
225 //===----------------------------------------------------------------------===//
226 
227 static LogicalResult verifyNontemporalClause(Operation *op,
228  OperandRange nontemporalVars) {
229 
230  // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
231  DenseSet<Value> nontemporalItems;
232  for (const auto &it : nontemporalVars)
233  if (!nontemporalItems.insert(it).second)
234  return op->emitOpError() << "nontemporal variable used more than once";
235 
236  return success();
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // Parser, verifier and printer for Aligned Clause
241 //===----------------------------------------------------------------------===//
242 static LogicalResult verifyAlignedClause(Operation *op,
243  std::optional<ArrayAttr> alignments,
244  OperandRange alignedVars) {
245  // Check if number of alignment values equals to number of aligned variables
246  if (!alignedVars.empty()) {
247  if (!alignments || alignments->size() != alignedVars.size())
248  return op->emitOpError()
249  << "expected as many alignment values as aligned variables";
250  } else {
251  if (alignments)
252  return op->emitOpError() << "unexpected alignment values attribute";
253  return success();
254  }
255 
256  // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
257  DenseSet<Value> alignedItems;
258  for (auto it : alignedVars)
259  if (!alignedItems.insert(it).second)
260  return op->emitOpError() << "aligned variable used more than once";
261 
262  if (!alignments)
263  return success();
264 
265  // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
266  for (unsigned i = 0; i < (*alignments).size(); ++i) {
267  if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
268  if (intAttr.getValue().sle(0))
269  return op->emitOpError() << "alignment should be greater than 0";
270  } else {
271  return op->emitOpError() << "expected integer alignment";
272  }
273  }
274 
275  return success();
276 }
277 
/// aligned ::= `aligned` `(` aligned-list `)`
/// aligned-list := aligned-val | aligned-val `,` aligned-list
/// aligned-val := ssa-id-and-type `->` alignment
///
/// Operands and types are appended to `alignedVars`/`alignedTypes`, and the
/// alignment attributes are collected into `alignmentsAttr`. Alignment values
/// are not checked here; verifyAlignedClause implements those checks.
static ParseResult
                   SmallVectorImpl<Type> &alignedTypes,
                   ArrayAttr &alignmentsAttr) {
  SmallVector<Attribute> alignmentVec;
  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseOperand(alignedVars.emplace_back()) ||
            parser.parseColonType(alignedTypes.emplace_back()) ||
            parser.parseArrow() ||
            parser.parseAttribute(alignmentVec.emplace_back())) {
          return failure();
        }
        return success();
      })))
    return failure();
  SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
  alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
  return success();
}
301 
/// Print Aligned Clause as comma-separated `var : type -> alignment` entries.
///
/// The printed type is taken from each variable; `alignedTypes` is not read.
/// NOTE(review): `alignments` is dereferenced unconditionally. This assumes an
/// alignment list with one entry per variable whenever `alignedVars` is
/// non-empty, which is the invariant verifyAlignedClause checks.
                               ValueRange alignedVars, TypeRange alignedTypes,
                               std::optional<ArrayAttr> alignments) {
  for (unsigned i = 0; i < alignedVars.size(); ++i) {
    if (i != 0)
      p << ", ";
    p << alignedVars[i] << " : " << alignedVars[i].getType();
    p << " -> " << (*alignments)[i];
  }
}
313 
314 //===----------------------------------------------------------------------===//
315 // Parser, printer and verifier for Schedule Clause
316 //===----------------------------------------------------------------------===//
317 
/// Validate the schedule modifiers that follow the schedule kind, and
/// normalize them in place.
///
/// At most two modifiers are accepted, and each must name a ScheduleModifier.
/// A single `simd` modifier is rewritten to [`none`, `simd`]. When two
/// modifiers are given, the first must not be `simd` and the second must be
/// `simd`. Diagnostics are reported at the parser's name location.
static ParseResult
                        SmallVectorImpl<SmallString<12>> &modifiers) {
  if (modifiers.size() > 2)
    return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
  for (const auto &mod : modifiers) {
    // Translate the string. If it has no value, then it was not a valid
    // modifier!
    auto symbol = symbolizeScheduleModifier(mod);
    if (!symbol)
      return parser.emitError(parser.getNameLoc())
             << " unknown modifier type: " << mod;
  }

  // If we have one modifier that is "simd", then stick a "none" modifier in
  // index 0 so that index 0 never holds "simd".
  if (modifiers.size() == 1) {
    if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
      modifiers.push_back(modifiers[0]);
      modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
    }
  } else if (modifiers.size() == 2) {
    // If there are two modifiers:
    // the first should not be simd and the second should be simd.
    if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
        symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
      return parser.emitError(parser.getNameLoc())
             << " incorrect modifier order";
  }
  return success();
}
349 
/// schedule ::= `schedule` `(` sched-list `)`
/// sched-list ::= sched-val | sched-val sched-list |
///                sched-val `,` sched-modifier
/// sched-val ::= sched-with-chunk | sched-wo-chunk
/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
/// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
/// sched-wo-chunk ::= `auto` | `runtime`
/// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
/// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
///
/// On success:
/// - `scheduleAttr` holds the schedule kind.
/// - `chunkSize`/`chunkType` hold the optional chunk operand; `chunkType` is
///   written only when a chunk is present.
/// - `scheduleMod` holds the first modifier after normalization by
///   verifyScheduleModifiers.
/// - `scheduleSimd` is set when a second (`simd`) modifier remains.
static ParseResult
parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
                    ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
                    std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
                    Type &chunkType) {
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return failure();
  std::optional<mlir::omp::ClauseScheduleKind> schedule =
      symbolizeClauseScheduleKind(keyword);
  if (!schedule)
    return parser.emitError(parser.getNameLoc()) << " expected schedule kind";

  scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
  // Only the chunked kinds may be followed by `= %chunk : type`.
  switch (*schedule) {
  case ClauseScheduleKind::Static:
  case ClauseScheduleKind::Dynamic:
  case ClauseScheduleKind::Guided:
    if (succeeded(parser.parseOptionalEqual())) {
      chunkSize = OpAsmParser::UnresolvedOperand{};
      if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
        return failure();
    } else {
      chunkSize = std::nullopt;
    }
    break;
  case ClauseScheduleKind::Auto:
    chunkSize = std::nullopt;
  }

  // If there is a comma, we have one or more modifiers.
  SmallVector<SmallString<12>> modifiers;
  while (succeeded(parser.parseOptionalComma())) {
    StringRef mod;
    if (parser.parseKeyword(&mod))
      return failure();
    modifiers.push_back(mod);
  }

  // Reject malformed modifier lists and normalize a lone `simd` to
  // [`none`, `simd`].
  if (verifyScheduleModifiers(parser, modifiers))
    return failure();

  if (!modifiers.empty()) {
    SMLoc loc = parser.getCurrentLocation();
    // verifyScheduleModifiers already rejected unknown names, so the error
    // branch below is not expected to be reached.
    if (std::optional<ScheduleModifier> mod =
            symbolizeScheduleModifier(modifiers[0])) {
      scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
    } else {
      return parser.emitError(loc, "invalid schedule modifier");
    }
    // Only SIMD attribute is allowed here!
    if (modifiers.size() > 1) {
      assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
      scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
    }
  }

  return success();
}
419 
/// Print schedule clause in this order:
/// - the schedule kind;
/// - ` = chunk : type` when a chunk is present;
/// - `, modifier` when a modifier is set;
/// - `, simd` when the simd flag is set.
/// The chunk type is taken from `scheduleChunk`; `scheduleChunkType` is not
/// read.
                                ClauseScheduleKindAttr scheduleKind,
                                ScheduleModifierAttr scheduleMod,
                                UnitAttr scheduleSimd, Value scheduleChunk,
                                Type scheduleChunkType) {
  p << stringifyClauseScheduleKind(scheduleKind.getValue());
  if (scheduleChunk)
    p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
  if (scheduleMod)
    p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
  if (scheduleSimd)
    p << ", simd";
}
434 
435 //===----------------------------------------------------------------------===//
436 // Parser and printer for Order Clause
437 //===----------------------------------------------------------------------===//
438 
439 // order ::= `order` `(` [order-modifier ':'] concurrent `)`
440 // order-modifier ::= reproducible | unconstrained
441 static ParseResult parseOrderClause(OpAsmParser &parser,
442  ClauseOrderKindAttr &order,
443  OrderModifierAttr &orderMod) {
444  StringRef enumStr;
445  SMLoc loc = parser.getCurrentLocation();
446  if (parser.parseKeyword(&enumStr))
447  return failure();
448  if (std::optional<OrderModifier> enumValue =
449  symbolizeOrderModifier(enumStr)) {
450  orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
451  if (parser.parseOptionalColon())
452  return failure();
453  loc = parser.getCurrentLocation();
454  if (parser.parseKeyword(&enumStr))
455  return failure();
456  }
457  if (std::optional<ClauseOrderKind> enumValue =
458  symbolizeClauseOrderKind(enumStr)) {
459  order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
460  return success();
461  }
462  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
463 }
464 
                             ClauseOrderKindAttr order,
                             OrderModifierAttr orderMod) {
  // Print `modifier:` when a modifier is set, then the order kind; this is the
  // syntax accepted by parseOrderClause.
  if (orderMod)
    p << stringifyOrderModifier(orderMod.getValue()) << ":";
  if (order)
    p << stringifyClauseOrderKind(order.getValue());
}
473 
474 //===----------------------------------------------------------------------===//
475 // Parsers for operations including clauses that define entry block arguments.
476 //===----------------------------------------------------------------------===//
477 
namespace {
/// References to the parser outputs of a clause that only introduces operands
/// and their types (used for `map_entries`, `use_device_addr` and
/// `use_device_ptr`).
struct MapParseArgs {
  SmallVectorImpl<Type> &types;
               SmallVectorImpl<Type> &types)
      : vars(vars), types(types) {}
};
/// References to the parser outputs of a `private` clause: operands, their
/// types and the symbol references that accompany them.
struct PrivateParseArgs {
  ArrayAttr &syms;
                   SmallVectorImpl<Type> &types, ArrayAttr &syms)
      : vars(vars), types(types), syms(syms) {}
};
/// References to the parser outputs of a reduction-style clause
/// (`in_reduction`, `reduction`, `task_reduction`): operands, types,
/// per-operand `byref` flags and symbol references.
struct ReductionParseArgs {
  SmallVectorImpl<Type> &types;
  DenseBoolArrayAttr &byref;
  ArrayAttr &syms;
  ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
                     ArrayAttr &syms)
      : vars(vars), types(types), byref(byref), syms(syms) {}
};
/// Block-argument clauses handed to parseBlockArgRegion. A clause left unset
/// is rejected if its keyword appears in the input.
struct AllRegionParseArgs {
  std::optional<ReductionParseArgs> inReductionArgs;
  std::optional<MapParseArgs> mapArgs;
  std::optional<PrivateParseArgs> privateArgs;
  std::optional<ReductionParseArgs> reductionArgs;
  std::optional<ReductionParseArgs> taskReductionArgs;
  std::optional<MapParseArgs> useDeviceAddrArgs;
  std::optional<MapParseArgs> useDevicePtrArgs;
};
} // namespace
514 
/// Parse the parenthesized body of a clause that introduces entry block
/// arguments:
///
///   `(` entry (`,` entry)* `:` type (`,` type)* `)`
///   entry ::= [`byref`] [symbol] operand `->` block-argument
///
/// The `byref` keyword is recognized only when `byref` is non-null. A symbol
/// attribute is parsed (and required) only when `symbols` is non-null.
/// Operands and types are appended to `operands` and `types`. The new block
/// arguments are appended to `regionPrivateArgs` and given the parsed types.
/// On success, `*symbols` and `*byref` (when provided) receive the collected
/// values.
///
/// NOTE(review): the block-argument sub-range below is sized by
/// `types.size()`, which assumes `operands`/`types` are empty on entry.
static ParseResult parseClauseWithRegionArgs(
    OpAsmParser &parser,
    SmallVectorImpl<Type> &types,
    SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
    ArrayAttr *symbols = nullptr, DenseBoolArrayAttr *byref = nullptr) {
  SmallVector<SymbolRefAttr> symbolVec;
  SmallVector<bool> isByRefVec;
  // Block arguments of earlier clauses may already be in `regionPrivateArgs`;
  // remember where this clause's arguments start.
  unsigned regionArgOffset = regionPrivateArgs.size();

  if (parser.parseLParen())
    return failure();

  // Entries: [byref] [symbol] operand -> block-argument
  if (parser.parseCommaSeparatedList([&]() {
        if (byref)
          isByRefVec.push_back(
              parser.parseOptionalKeyword("byref").succeeded());

        if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
          return failure();

        if (parser.parseOperand(operands.emplace_back()) ||
            parser.parseArrow() ||
            parser.parseArgument(regionPrivateArgs.emplace_back()))
          return failure();

        return success();
      }))
    return failure();

  if (parser.parseColon())
    return failure();

  // Types: one per operand, in the same order.
  if (parser.parseCommaSeparatedList([&]() {
        if (parser.parseType(types.emplace_back()))
          return failure();

        return success();
      }))
    return failure();

  // A count mismatch fails without a diagnostic of its own; callers such as
  // parseBlockArgRegion report the error.
  if (operands.size() != types.size())
    return failure();

  if (parser.parseRParen())
    return failure();

  // Attach the parsed types to this clause's block arguments.
  auto *argsBegin = regionPrivateArgs.begin();
  MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
                               argsBegin + regionArgOffset + types.size());
  for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
    prv.type = type;
  }

  if (symbols) {
    SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
    *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
  }

  // An empty flag list is stored as a null attribute.
  if (byref)
    *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);

  return success();
}
579 
/// If the next token is `keyword`, parse an operand-only clause into the
/// vectors referenced by `mapArgs` and append its block arguments to
/// `entryBlockArgs`. Fails when the keyword is present but `mapArgs` is unset.
/// Succeeds without consuming anything when the keyword is absent.
static ParseResult parseBlockArgClause(
    OpAsmParser &parser,
    StringRef keyword, std::optional<MapParseArgs> mapArgs) {
  if (succeeded(parser.parseOptionalKeyword(keyword))) {
    // The clause is not accepted by this operation.
    if (!mapArgs)
      return failure();

    if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
                                         entryBlockArgs)))
      return failure();
  }
  return success();
}
594 
/// If the next token is `keyword`, parse a `private`-style clause (operands,
/// types and symbols) into the outputs referenced by the argument. Its block
/// arguments are appended to `entryBlockArgs`. Fails when the keyword is
/// present but no arguments were provided. Succeeds without consuming
/// anything when the keyword is absent.
/// Note: despite its name, `reductionArgs` here holds private-clause outputs.
static ParseResult parseBlockArgClause(
    OpAsmParser &parser,
    StringRef keyword, std::optional<PrivateParseArgs> reductionArgs) {
  if (succeeded(parser.parseOptionalKeyword(keyword))) {
    // The clause is not accepted by this operation.
    if (!reductionArgs)
      return failure();

    if (failed(parseClauseWithRegionArgs(parser, reductionArgs->vars,
                                         reductionArgs->types, entryBlockArgs,
                                         &reductionArgs->syms)))
      return failure();
  }
  return success();
}
610 
/// If the next token is `keyword`, parse a reduction-style clause (operands,
/// types, symbols and optional `byref` markers) into the outputs referenced by
/// `reductionArgs`. Its block arguments are appended to `entryBlockArgs`.
/// Fails when the keyword is present but `reductionArgs` is unset. Succeeds
/// without consuming anything when the keyword is absent.
static ParseResult parseBlockArgClause(
    OpAsmParser &parser,
    StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
  if (succeeded(parser.parseOptionalKeyword(keyword))) {
    // The clause is not accepted by this operation.
    if (!reductionArgs)
      return failure();

    if (failed(parseClauseWithRegionArgs(
            parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
            &reductionArgs->syms, &reductionArgs->byref)))
      return failure();
  }
  return success();
}
626 
627 static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
628  AllRegionParseArgs args) {
630 
631  if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
632  args.inReductionArgs)))
633  return parser.emitError(parser.getCurrentLocation())
634  << "invalid `in_reduction` format";
635 
636  if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
637  args.mapArgs)))
638  return parser.emitError(parser.getCurrentLocation())
639  << "invalid `map_entries` format";
640 
641  if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
642  args.privateArgs)))
643  return parser.emitError(parser.getCurrentLocation())
644  << "invalid `private` format";
645 
646  if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
647  args.reductionArgs)))
648  return parser.emitError(parser.getCurrentLocation())
649  << "invalid `reduction` format";
650 
651  if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
652  args.taskReductionArgs)))
653  return parser.emitError(parser.getCurrentLocation())
654  << "invalid `task_reduction` format";
655 
656  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
657  args.useDeviceAddrArgs)))
658  return parser.emitError(parser.getCurrentLocation())
659  << "invalid `use_device_addr` format";
660 
661  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
662  args.useDevicePtrArgs)))
663  return parser.emitError(parser.getCurrentLocation())
664  << "invalid `use_device_addr` format";
665 
666  return parser.parseRegion(region, entryBlockArgs);
667 }
668 
    OpAsmParser &parser, Region &region,
    SmallVectorImpl<Type> &inReductionTypes,
    DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
    SmallVectorImpl<Type> &mapTypes,
    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
  // Accept the `in_reduction`, `map_entries` and `private` clauses (in that
  // order) before the region; any other block-argument clause is rejected.
  AllRegionParseArgs args;
  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                               inReductionByref, inReductionSyms);
  args.mapArgs.emplace(mapVars, mapTypes);
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
  return parseBlockArgRegion(parser, region, args);
}
685 
/// Parse the `in_reduction` and `private` block-argument clauses (in that
/// order) followed by the region. Any other block-argument clause is rejected.
static ParseResult parseInReductionPrivateRegion(
    OpAsmParser &parser, Region &region,
    SmallVectorImpl<Type> &inReductionTypes,
    DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
  AllRegionParseArgs args;
  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                               inReductionByref, inReductionSyms);
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
  return parseBlockArgRegion(parser, region, args);
}
699 
    OpAsmParser &parser, Region &region,
    SmallVectorImpl<Type> &inReductionTypes,
    DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
    SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
    ArrayAttr &reductionSyms) {
  // Accept the `in_reduction`, `private` and `reduction` clauses (in that
  // order) before the region; any other block-argument clause is rejected.
  AllRegionParseArgs args;
  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                               inReductionByref, inReductionSyms);
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                             reductionSyms);
  return parseBlockArgRegion(parser, region, args);
}
718 
/// Parse an optional `private` block-argument clause followed by the region.
/// Any other block-argument clause is rejected.
static ParseResult parsePrivateRegion(
    OpAsmParser &parser, Region &region,
    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
  AllRegionParseArgs args;
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
  return parseBlockArgRegion(parser, region, args);
}
727 
/// Parse the `private` and `reduction` block-argument clauses (in that order)
/// followed by the region. Any other block-argument clause is rejected.
static ParseResult parsePrivateReductionRegion(
    OpAsmParser &parser, Region &region,
    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
    SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
    ArrayAttr &reductionSyms) {
  AllRegionParseArgs args;
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                             reductionSyms);
  return parseBlockArgRegion(parser, region, args);
}
741 
/// Parse an optional `task_reduction` block-argument clause followed by the
/// region. Any other block-argument clause is rejected.
static ParseResult parseTaskReductionRegion(
    OpAsmParser &parser, Region &region,
    SmallVectorImpl<Type> &taskReductionTypes,
    DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
  AllRegionParseArgs args;
  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
                                 taskReductionByref, taskReductionSyms);
  return parseBlockArgRegion(parser, region, args);
}
752 
    OpAsmParser &parser, Region &region,
    SmallVectorImpl<Type> &useDeviceAddrTypes,
    SmallVectorImpl<Type> &useDevicePtrTypes) {
  // Accept the `use_device_addr` and `use_device_ptr` clauses (in that order)
  // before the region; any other block-argument clause is rejected.
  AllRegionParseArgs args;
  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
  return parseBlockArgRegion(parser, region, args);
}
764 
765 //===----------------------------------------------------------------------===//
766 // Printers for operations including clauses that define entry block arguments.
767 //===----------------------------------------------------------------------===//
768 
769 namespace {
770 struct MapPrintArgs {
771  ValueRange vars;
772  TypeRange types;
773  MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
774 };
775 struct PrivatePrintArgs {
776  ValueRange vars;
777  TypeRange types;
778  ArrayAttr syms;
779  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms)
780  : vars(vars), types(types), syms(syms) {}
781 };
782 struct ReductionPrintArgs {
783  ValueRange vars;
784  TypeRange types;
785  DenseBoolArrayAttr byref;
786  ArrayAttr syms;
787  ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
788  ArrayAttr syms)
789  : vars(vars), types(types), byref(byref), syms(syms) {}
790 };
791 struct AllRegionPrintArgs {
792  std::optional<ReductionPrintArgs> inReductionArgs;
793  std::optional<MapPrintArgs> mapArgs;
794  std::optional<PrivatePrintArgs> privateArgs;
795  std::optional<ReductionPrintArgs> reductionArgs;
796  std::optional<ReductionPrintArgs> taskReductionArgs;
797  std::optional<MapPrintArgs> useDeviceAddrArgs;
798  std::optional<MapPrintArgs> useDevicePtrArgs;
799 };
800 } // namespace
801 
                                      StringRef clauseName,
                                      ValueRange argsSubrange,
                                      ValueRange operands, TypeRange types,
                                      ArrayAttr symbols = nullptr,
                                      DenseBoolArrayAttr byref = nullptr) {
  // Print
  //   clauseName `(` [`byref `] [symbol ` `] operand ` -> ` block-arg, ...
  //   ` : ` types `) `
  // Nothing is printed when the clause introduced no block arguments.
  if (argsSubrange.empty())
    return;

  p << clauseName << "(";

  // When symbols / byref flags were not supplied, substitute all-null symbols
  // and all-false flags so the four sequences can be zipped uniformly below.
  // NOTE(review): this creates attributes in `ctx` purely for printing.
  if (!symbols) {
    llvm::SmallVector<Attribute> values(operands.size(), nullptr);
    symbols = ArrayAttr::get(ctx, values);
  }

  if (!byref) {
    mlir::SmallVector<bool> values(operands.size(), false);
    byref = DenseBoolArrayAttr::get(ctx, values);
  }

  // zip_equal requires operands, block arguments, symbols and flags to have
  // the same length.
  llvm::interleaveComma(
      llvm::zip_equal(operands, argsSubrange, symbols, byref.asArrayRef()), p,
      [&p](auto t) {
        auto [op, arg, sym, isByRef] = t;
        if (isByRef)
          p << "byref ";
        if (sym)
          p << sym << " ";
        p << op << " -> " << arg;
      });
  p << " : ";
  llvm::interleaveComma(types, p);
  p << ") ";
}
837 
                                StringRef clauseName, ValueRange argsSubrange,
                                std::optional<MapPrintArgs> mapArgs) {
  // Print an operand-only clause when it was provided; an unset `mapArgs`
  // prints nothing.
  if (mapArgs)
    printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
                              mapArgs->types);
}
845 
                                StringRef clauseName, ValueRange argsSubrange,
                                std::optional<PrivatePrintArgs> privateArgs) {
  // Print a `private`-style clause (operands, types and symbols) when it was
  // provided; an unset `privateArgs` prints nothing.
  if (privateArgs)
    printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
                              privateArgs->vars, privateArgs->types,
                              privateArgs->syms);
}
854 
855 static void
856 printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
857  ValueRange argsSubrange,
858  std::optional<ReductionPrintArgs> reductionArgs) {
859  if (reductionArgs)
860  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
861  reductionArgs->vars, reductionArgs->types,
862  reductionArgs->syms, reductionArgs->byref);
863 }
864 
865 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
866  const AllRegionPrintArgs &args) {
867  auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
868  MLIRContext *ctx = op->getContext();
869 
870  printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
871  args.inReductionArgs);
872  printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
873  args.mapArgs);
874  printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
875  args.privateArgs);
876  printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
877  args.reductionArgs);
878  printBlockArgClause(p, ctx, "task_reduction",
879  iface.getTaskReductionBlockArgs(),
880  args.taskReductionArgs);
881  printBlockArgClause(p, ctx, "use_device_addr",
882  iface.getUseDeviceAddrBlockArgs(),
883  args.useDeviceAddrArgs);
884  printBlockArgClause(p, ctx, "use_device_ptr",
885  iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
886 
887  p.printRegion(region, /*printEntryBlockArgs=*/false);
888 }
889 
    OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
    TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
    ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
    ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
  // Print the `in_reduction`, `map_entries` and `private` clauses (in that
  // order), then the region.
  AllRegionPrintArgs args;
  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                               inReductionByref, inReductionSyms);
  args.mapArgs.emplace(mapVars, mapTypes);
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
  printBlockArgRegion(p, op, region, args);
}
902 
    OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
    TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
    ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
    ArrayAttr privateSyms) {
  // Print the `in_reduction` and `private` clauses (in that order), then the
  // region.
  AllRegionPrintArgs args;
  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                               inReductionByref, inReductionSyms);
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
  printBlockArgRegion(p, op, region, args);
}
914 
    OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
    TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
    ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
    ArrayAttr privateSyms, ValueRange reductionVars, TypeRange reductionTypes,
    DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
  // Print the `in_reduction`, `private` and `reduction` clauses (in that
  // order), then the region.
  AllRegionPrintArgs args;
  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                               inReductionByref, inReductionSyms);
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                             reductionSyms);
  printBlockArgRegion(p, op, region, args);
}
929 
930 static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
931  ValueRange privateVars, TypeRange privateTypes,
932  ArrayAttr privateSyms) {
933  AllRegionPrintArgs args;
934  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
935  printBlockArgRegion(p, op, region, args);
936 }
937 
939  OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
940  TypeRange privateTypes, ArrayAttr privateSyms, ValueRange reductionVars,
941  TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
942  ArrayAttr reductionSyms) {
943  AllRegionPrintArgs args;
944  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
945  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
946  reductionSyms);
947  printBlockArgRegion(p, op, region, args);
948 }
949 
951  Region &region,
952  ValueRange taskReductionVars,
953  TypeRange taskReductionTypes,
954  DenseBoolArrayAttr taskReductionByref,
955  ArrayAttr taskReductionSyms) {
956  AllRegionPrintArgs args;
957  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
958  taskReductionByref, taskReductionSyms);
959  printBlockArgRegion(p, op, region, args);
960 }
961 
963  Region &region,
964  ValueRange useDeviceAddrVars,
965  TypeRange useDeviceAddrTypes,
966  ValueRange useDevicePtrVars,
967  TypeRange useDevicePtrTypes) {
968  AllRegionPrintArgs args;
969  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
970  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
971  printBlockArgRegion(p, op, region, args);
972 }
973 
974 /// Verifies Reduction Clause
975 static LogicalResult
976 verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
977  OperandRange reductionVars,
978  std::optional<ArrayRef<bool>> reductionByref) {
979  if (!reductionVars.empty()) {
980  if (!reductionSyms || reductionSyms->size() != reductionVars.size())
981  return op->emitOpError()
982  << "expected as many reduction symbol references "
983  "as reduction variables";
984  if (reductionByref && reductionByref->size() != reductionVars.size())
985  return op->emitError() << "expected as many reduction variable by "
986  "reference attributes as reduction variables";
987  } else {
988  if (reductionSyms)
989  return op->emitOpError() << "unexpected reduction symbol references";
990  return success();
991  }
992 
993  // TODO: The followings should be done in
994  // SymbolUserOpInterface::verifySymbolUses.
995  DenseSet<Value> accumulators;
996  for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
997  Value accum = std::get<0>(args);
998 
999  if (!accumulators.insert(accum).second)
1000  return op->emitOpError() << "accumulator variable used more than once";
1001 
1002  Type varType = accum.getType();
1003  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1004  auto decl =
1005  SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1006  if (!decl)
1007  return op->emitOpError() << "expected symbol reference " << symbolRef
1008  << " to point to a reduction declaration";
1009 
1010  if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1011  return op->emitOpError()
1012  << "expected accumulator (" << varType
1013  << ") to be the same type as reduction declaration ("
1014  << decl.getAccumulatorType() << ")";
1015  }
1016 
1017  return success();
1018 }
1019 
1020 //===----------------------------------------------------------------------===//
1021 // Parser, printer and verifier for Copyprivate
1022 //===----------------------------------------------------------------------===//
1023 
1024 /// copyprivate-entry-list ::= copyprivate-entry
1025 /// | copyprivate-entry-list `,` copyprivate-entry
1026 /// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1027 static ParseResult parseCopyprivate(
1028  OpAsmParser &parser,
1030  SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1032  if (failed(parser.parseCommaSeparatedList([&]() {
1033  if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1034  parser.parseArrow() ||
1035  parser.parseAttribute(symsVec.emplace_back()) ||
1036  parser.parseColonType(copyprivateTypes.emplace_back()))
1037  return failure();
1038  return success();
1039  })))
1040  return failure();
1041  SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1042  copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1043  return success();
1044 }
1045 
1046 /// Print Copyprivate clause
1048  OperandRange copyprivateVars,
1049  TypeRange copyprivateTypes,
1050  std::optional<ArrayAttr> copyprivateSyms) {
1051  if (!copyprivateSyms.has_value())
1052  return;
1053  llvm::interleaveComma(
1054  llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1055  [&](const auto &args) {
1056  p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1057  << std::get<2>(args);
1058  });
1059 }
1060 
1061 /// Verifies CopyPrivate Clause
1062 static LogicalResult
1064  std::optional<ArrayAttr> copyprivateSyms) {
1065  size_t copyprivateSymsSize =
1066  copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1067  if (copyprivateSymsSize != copyprivateVars.size())
1068  return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1069  << copyprivateVars.size()
1070  << ") and functions (= " << copyprivateSymsSize
1071  << "), both must be equal";
1072  if (!copyprivateSyms.has_value())
1073  return success();
1074 
1075  for (auto copyprivateVarAndSym :
1076  llvm::zip(copyprivateVars, *copyprivateSyms)) {
1077  auto symbolRef =
1078  llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1079  std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1080  funcOp;
1081  if (mlir::func::FuncOp mlirFuncOp =
1082  SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1083  symbolRef))
1084  funcOp = mlirFuncOp;
1085  else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1086  SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1087  op, symbolRef))
1088  funcOp = llvmFuncOp;
1089 
1090  auto getNumArguments = [&] {
1091  return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1092  };
1093 
1094  auto getArgumentType = [&](unsigned i) {
1095  return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1096  *funcOp);
1097  };
1098 
1099  if (!funcOp)
1100  return op->emitOpError() << "expected symbol reference " << symbolRef
1101  << " to point to a copy function";
1102 
1103  if (getNumArguments() != 2)
1104  return op->emitOpError()
1105  << "expected copy function " << symbolRef << " to have 2 operands";
1106 
1107  Type argTy = getArgumentType(0);
1108  if (argTy != getArgumentType(1))
1109  return op->emitOpError() << "expected copy function " << symbolRef
1110  << " arguments to have the same type";
1111 
1112  Type varType = std::get<0>(copyprivateVarAndSym).getType();
1113  if (argTy != varType)
1114  return op->emitOpError()
1115  << "expected copy function arguments' type (" << argTy
1116  << ") to be the same as copyprivate variable's type (" << varType
1117  << ")";
1118  }
1119 
1120  return success();
1121 }
1122 
1123 //===----------------------------------------------------------------------===//
1124 // Parser, printer and verifier for DependVarList
1125 //===----------------------------------------------------------------------===//
1126 
1127 /// depend-entry-list ::= depend-entry
1128 /// | depend-entry-list `,` depend-entry
1129 /// depend-entry ::= depend-kind `->` ssa-id `:` type
1130 static ParseResult
1133  SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
1135  if (failed(parser.parseCommaSeparatedList([&]() {
1136  StringRef keyword;
1137  if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1138  parser.parseOperand(dependVars.emplace_back()) ||
1139  parser.parseColonType(dependTypes.emplace_back()))
1140  return failure();
1141  if (std::optional<ClauseTaskDepend> keywordDepend =
1142  (symbolizeClauseTaskDepend(keyword)))
1143  kindsVec.emplace_back(
1144  ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1145  else
1146  return failure();
1147  return success();
1148  })))
1149  return failure();
1150  SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1151  dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1152  return success();
1153 }
1154 
1155 /// Print Depend clause
1157  OperandRange dependVars, TypeRange dependTypes,
1158  std::optional<ArrayAttr> dependKinds) {
1159 
1160  for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1161  if (i != 0)
1162  p << ", ";
1163  p << stringifyClauseTaskDepend(
1164  llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1165  .getValue())
1166  << " -> " << dependVars[i] << " : " << dependTypes[i];
1167  }
1168 }
1169 
1170 /// Verifies Depend clause
1171 static LogicalResult verifyDependVarList(Operation *op,
1172  std::optional<ArrayAttr> dependKinds,
1173  OperandRange dependVars) {
1174  if (!dependVars.empty()) {
1175  if (!dependKinds || dependKinds->size() != dependVars.size())
1176  return op->emitOpError() << "expected as many depend values"
1177  " as depend variables";
1178  } else {
1179  if (dependKinds && !dependKinds->empty())
1180  return op->emitOpError() << "unexpected depend values";
1181  return success();
1182  }
1183 
1184  return success();
1185 }
1186 
1187 //===----------------------------------------------------------------------===//
1188 // Parser, printer and verifier for Synchronization Hint (2.17.12)
1189 //===----------------------------------------------------------------------===//
1190 
1191 /// Parses a Synchronization Hint clause. The value of hint is an integer
1192 /// which is a combination of different hints from `omp_sync_hint_t`.
1193 ///
1194 /// hint-clause = `hint` `(` hint-value `)`
1195 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1196  IntegerAttr &hintAttr) {
1197  StringRef hintKeyword;
1198  int64_t hint = 0;
1199  if (succeeded(parser.parseOptionalKeyword("none"))) {
1200  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1201  return success();
1202  }
1203  auto parseKeyword = [&]() -> ParseResult {
1204  if (failed(parser.parseKeyword(&hintKeyword)))
1205  return failure();
1206  if (hintKeyword == "uncontended")
1207  hint |= 1;
1208  else if (hintKeyword == "contended")
1209  hint |= 2;
1210  else if (hintKeyword == "nonspeculative")
1211  hint |= 4;
1212  else if (hintKeyword == "speculative")
1213  hint |= 8;
1214  else
1215  return parser.emitError(parser.getCurrentLocation())
1216  << hintKeyword << " is not a valid hint";
1217  return success();
1218  };
1219  if (parser.parseCommaSeparatedList(parseKeyword))
1220  return failure();
1221  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1222  return success();
1223 }
1224 
1225 /// Prints a Synchronization Hint clause
1227  IntegerAttr hintAttr) {
1228  int64_t hint = hintAttr.getInt();
1229 
1230  if (hint == 0) {
1231  p << "none";
1232  return;
1233  }
1234 
1235  // Helper function to get n-th bit from the right end of `value`
1236  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1237 
1238  bool uncontended = bitn(hint, 0);
1239  bool contended = bitn(hint, 1);
1240  bool nonspeculative = bitn(hint, 2);
1241  bool speculative = bitn(hint, 3);
1242 
1243  SmallVector<StringRef> hints;
1244  if (uncontended)
1245  hints.push_back("uncontended");
1246  if (contended)
1247  hints.push_back("contended");
1248  if (nonspeculative)
1249  hints.push_back("nonspeculative");
1250  if (speculative)
1251  hints.push_back("speculative");
1252 
1253  llvm::interleaveComma(hints, p);
1254 }
1255 
1256 /// Verifies a synchronization hint clause
1257 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1258 
1259  // Helper function to get n-th bit from the right end of `value`
1260  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1261 
1262  bool uncontended = bitn(hint, 0);
1263  bool contended = bitn(hint, 1);
1264  bool nonspeculative = bitn(hint, 2);
1265  bool speculative = bitn(hint, 3);
1266 
1267  if (uncontended && contended)
1268  return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1269  "omp_sync_hint_contended cannot be combined";
1270  if (nonspeculative && speculative)
1271  return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1272  "omp_sync_hint_speculative cannot be combined.";
1273  return success();
1274 }
1275 
1276 //===----------------------------------------------------------------------===//
1277 // Parser, printer and verifier for Target
1278 //===----------------------------------------------------------------------===//
1279 
1280 // Helper function to get bitwise AND of `value` and 'flag'
1281 uint64_t mapTypeToBitFlag(uint64_t value,
1282  llvm::omp::OpenMPOffloadMappingFlags flag) {
1283  return value & llvm::to_underlying(flag);
1284 }
1285 
1286 /// Parses a map_entries map type from a string format back into its numeric
1287 /// value.
1288 ///
1289 /// map-clause = `map_clauses ( ( `(` `always, `? `close, `? `present, `? (
1290 /// `to` | `from` | `delete` `)` )+ `)` )
1291 static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
1292  llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1293  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1294 
1295  // This simply verifies the correct keyword is read in, the
1296  // keyword itself is stored inside of the operation
1297  auto parseTypeAndMod = [&]() -> ParseResult {
1298  StringRef mapTypeMod;
1299  if (parser.parseKeyword(&mapTypeMod))
1300  return failure();
1301 
1302  if (mapTypeMod == "always")
1303  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1304 
1305  if (mapTypeMod == "implicit")
1306  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1307 
1308  if (mapTypeMod == "close")
1309  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1310 
1311  if (mapTypeMod == "present")
1312  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1313 
1314  if (mapTypeMod == "to")
1315  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1316 
1317  if (mapTypeMod == "from")
1318  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1319 
1320  if (mapTypeMod == "tofrom")
1321  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1322  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1323 
1324  if (mapTypeMod == "delete")
1325  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1326 
1327  return success();
1328  };
1329 
1330  if (parser.parseCommaSeparatedList(parseTypeAndMod))
1331  return failure();
1332 
1333  mapType = parser.getBuilder().getIntegerAttr(
1334  parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
1335  llvm::to_underlying(mapTypeBits));
1336 
1337  return success();
1338 }
1339 
1340 /// Prints a map_entries map type from its numeric value out into its string
1341 /// format.
1343  IntegerAttr mapType) {
1344  uint64_t mapTypeBits = mapType.getUInt();
1345 
1346  bool emitAllocRelease = true;
1348 
1349  // handling of always, close, present placed at the beginning of the string
1350  // to aid readability
1351  if (mapTypeToBitFlag(mapTypeBits,
1352  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1353  mapTypeStrs.push_back("always");
1354  if (mapTypeToBitFlag(mapTypeBits,
1355  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1356  mapTypeStrs.push_back("implicit");
1357  if (mapTypeToBitFlag(mapTypeBits,
1358  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1359  mapTypeStrs.push_back("close");
1360  if (mapTypeToBitFlag(mapTypeBits,
1361  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1362  mapTypeStrs.push_back("present");
1363 
1364  // special handling of to/from/tofrom/delete and release/alloc, release +
1365  // alloc are the abscense of one of the other flags, whereas tofrom requires
1366  // both the to and from flag to be set.
1367  bool to = mapTypeToBitFlag(mapTypeBits,
1368  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1369  bool from = mapTypeToBitFlag(
1370  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1371  if (to && from) {
1372  emitAllocRelease = false;
1373  mapTypeStrs.push_back("tofrom");
1374  } else if (from) {
1375  emitAllocRelease = false;
1376  mapTypeStrs.push_back("from");
1377  } else if (to) {
1378  emitAllocRelease = false;
1379  mapTypeStrs.push_back("to");
1380  }
1381  if (mapTypeToBitFlag(mapTypeBits,
1382  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1383  emitAllocRelease = false;
1384  mapTypeStrs.push_back("delete");
1385  }
1386  if (emitAllocRelease)
1387  mapTypeStrs.push_back("exit_release_or_enter_alloc");
1388 
1389  for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1390  p << mapTypeStrs[i];
1391  if (i + 1 < mapTypeStrs.size()) {
1392  p << ", ";
1393  }
1394  }
1395 }
1396 
1397 static ParseResult parseMembersIndex(OpAsmParser &parser,
1398  ArrayAttr &membersIdx) {
1399  SmallVector<Attribute> values, memberIdxs;
1400 
1401  auto parseIndices = [&]() -> ParseResult {
1402  int64_t value;
1403  if (parser.parseInteger(value))
1404  return failure();
1405  values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1406  APInt(64, value, /*isSigned=*/false)));
1407  return success();
1408  };
1409 
1410  do {
1411  if (failed(parser.parseLSquare()))
1412  return failure();
1413 
1414  if (parser.parseCommaSeparatedList(parseIndices))
1415  return failure();
1416 
1417  if (failed(parser.parseRSquare()))
1418  return failure();
1419 
1420  memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1421  values.clear();
1422  } while (succeeded(parser.parseOptionalComma()));
1423 
1424  if (!memberIdxs.empty())
1425  membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
1426 
1427  return success();
1428 }
1429 
1430 static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1431  ArrayAttr membersIdx) {
1432  if (!membersIdx)
1433  return;
1434 
1435  llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
1436  p << "[";
1437  auto memberIdx = cast<ArrayAttr>(v);
1438  llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
1439  p << cast<IntegerAttr>(v2).getInt();
1440  });
1441  p << "]";
1442  });
1443 }
1444 
1446  VariableCaptureKindAttr mapCaptureType) {
1447  std::string typeCapStr;
1448  llvm::raw_string_ostream typeCap(typeCapStr);
1449  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1450  typeCap << "ByRef";
1451  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1452  typeCap << "ByCopy";
1453  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1454  typeCap << "VLAType";
1455  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1456  typeCap << "This";
1457  p << typeCapStr;
1458 }
1459 
1460 static ParseResult parseCaptureType(OpAsmParser &parser,
1461  VariableCaptureKindAttr &mapCaptureType) {
1462  StringRef mapCaptureKey;
1463  if (parser.parseKeyword(&mapCaptureKey))
1464  return failure();
1465 
1466  if (mapCaptureKey == "This")
1467  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1468  parser.getContext(), mlir::omp::VariableCaptureKind::This);
1469  if (mapCaptureKey == "ByRef")
1470  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1471  parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1472  if (mapCaptureKey == "ByCopy")
1473  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1474  parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1475  if (mapCaptureKey == "VLAType")
1476  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1477  parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1478 
1479  return success();
1480 }
1481 
1482 static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
1485 
1486  for (auto mapOp : mapVars) {
1487  if (!mapOp.getDefiningOp())
1488  emitError(op->getLoc(), "missing map operation");
1489 
1490  if (auto mapInfoOp =
1491  mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
1492  if (!mapInfoOp.getMapType().has_value())
1493  emitError(op->getLoc(), "missing map type for map operand");
1494 
1495  if (!mapInfoOp.getMapCaptureType().has_value())
1496  emitError(op->getLoc(), "missing map capture type for map operand");
1497 
1498  uint64_t mapTypeBits = mapInfoOp.getMapType().value();
1499 
1500  bool to = mapTypeToBitFlag(
1501  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1502  bool from = mapTypeToBitFlag(
1503  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1504  bool del = mapTypeToBitFlag(
1505  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1506 
1507  bool always = mapTypeToBitFlag(
1508  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1509  bool close = mapTypeToBitFlag(
1510  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1511  bool implicit = mapTypeToBitFlag(
1512  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1513 
1514  if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1515  return emitError(op->getLoc(),
1516  "to, from, tofrom and alloc map types are permitted");
1517 
1518  if (isa<TargetEnterDataOp>(op) && (from || del))
1519  return emitError(op->getLoc(), "to and alloc map types are permitted");
1520 
1521  if (isa<TargetExitDataOp>(op) && to)
1522  return emitError(op->getLoc(),
1523  "from, release and delete map types are permitted");
1524 
1525  if (isa<TargetUpdateOp>(op)) {
1526  if (del) {
1527  return emitError(op->getLoc(),
1528  "at least one of to or from map types must be "
1529  "specified, other map types are not permitted");
1530  }
1531 
1532  if (!to && !from) {
1533  return emitError(op->getLoc(),
1534  "at least one of to or from map types must be "
1535  "specified, other map types are not permitted");
1536  }
1537 
1538  auto updateVar = mapInfoOp.getVarPtr();
1539 
1540  if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1541  (from && updateToVars.contains(updateVar))) {
1542  return emitError(
1543  op->getLoc(),
1544  "either to or from map types can be specified, not both");
1545  }
1546 
1547  if (always || close || implicit) {
1548  return emitError(
1549  op->getLoc(),
1550  "present, mapper and iterator map type modifiers are permitted");
1551  }
1552 
1553  to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1554  }
1555  } else {
1556  emitError(op->getLoc(), "map argument is not a map entry operation");
1557  }
1558  }
1559 
1560  return success();
1561 }
1562 
1563 //===----------------------------------------------------------------------===//
1564 // TargetDataOp
1565 //===----------------------------------------------------------------------===//
1566 
1567 void TargetDataOp::build(OpBuilder &builder, OperationState &state,
1568  const TargetDataOperands &clauses) {
1569  TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
1570  clauses.mapVars, clauses.useDeviceAddrVars,
1571  clauses.useDevicePtrVars);
1572 }
1573 
1574 LogicalResult TargetDataOp::verify() {
1575  if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
1576  getUseDeviceAddrVars().empty()) {
1577  return ::emitError(this->getLoc(),
1578  "At least one of map, use_device_ptr_vars, or "
1579  "use_device_addr_vars operand must be present");
1580  }
1581  return verifyMapClause(*this, getMapVars());
1582 }
1583 
1584 //===----------------------------------------------------------------------===//
1585 // TargetEnterDataOp
1586 //===----------------------------------------------------------------------===//
1587 
1588 void TargetEnterDataOp::build(
1589  OpBuilder &builder, OperationState &state,
1590  const TargetEnterExitUpdateDataOperands &clauses) {
1591  MLIRContext *ctx = builder.getContext();
1592  TargetEnterDataOp::build(builder, state,
1593  makeArrayAttr(ctx, clauses.dependKinds),
1594  clauses.dependVars, clauses.device, clauses.ifExpr,
1595  clauses.mapVars, clauses.nowait);
1596 }
1597 
1598 LogicalResult TargetEnterDataOp::verify() {
1599  LogicalResult verifyDependVars =
1600  verifyDependVarList(*this, getDependKinds(), getDependVars());
1601  return failed(verifyDependVars) ? verifyDependVars
1602  : verifyMapClause(*this, getMapVars());
1603 }
1604 
1605 //===----------------------------------------------------------------------===//
1606 // TargetExitDataOp
1607 //===----------------------------------------------------------------------===//
1608 
1609 void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
1610  const TargetEnterExitUpdateDataOperands &clauses) {
1611  MLIRContext *ctx = builder.getContext();
1612  TargetExitDataOp::build(builder, state,
1613  makeArrayAttr(ctx, clauses.dependKinds),
1614  clauses.dependVars, clauses.device, clauses.ifExpr,
1615  clauses.mapVars, clauses.nowait);
1616 }
1617 
1618 LogicalResult TargetExitDataOp::verify() {
1619  LogicalResult verifyDependVars =
1620  verifyDependVarList(*this, getDependKinds(), getDependVars());
1621  return failed(verifyDependVars) ? verifyDependVars
1622  : verifyMapClause(*this, getMapVars());
1623 }
1624 
1625 //===----------------------------------------------------------------------===//
1626 // TargetUpdateOp
1627 //===----------------------------------------------------------------------===//
1628 
1629 void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
1630  const TargetEnterExitUpdateDataOperands &clauses) {
1631  MLIRContext *ctx = builder.getContext();
1632  TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
1633  clauses.dependVars, clauses.device, clauses.ifExpr,
1634  clauses.mapVars, clauses.nowait);
1635 }
1636 
1637 LogicalResult TargetUpdateOp::verify() {
1638  LogicalResult verifyDependVars =
1639  verifyDependVarList(*this, getDependKinds(), getDependVars());
1640  return failed(verifyDependVars) ? verifyDependVars
1641  : verifyMapClause(*this, getMapVars());
1642 }
1643 
1644 //===----------------------------------------------------------------------===//
1645 // TargetOp
1646 //===----------------------------------------------------------------------===//
1647 
1648 void TargetOp::build(OpBuilder &builder, OperationState &state,
1649  const TargetOperands &clauses) {
1650  MLIRContext *ctx = builder.getContext();
1651  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
1652  // inReductionByref, inReductionSyms.
1653  TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
1654  makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
1655  clauses.device, clauses.hasDeviceAddrVars, clauses.ifExpr,
1656  /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
1657  /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
1658  clauses.mapVars, clauses.nowait, clauses.privateVars,
1659  makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit);
1660 }
1661 
1662 LogicalResult TargetOp::verify() {
1663  LogicalResult verifyDependVars =
1664  verifyDependVarList(*this, getDependKinds(), getDependVars());
1665  return failed(verifyDependVars) ? verifyDependVars
1666  : verifyMapClause(*this, getMapVars());
1667 }
1668 
1669 //===----------------------------------------------------------------------===//
1670 // ParallelOp
1671 //===----------------------------------------------------------------------===//
1672 
1673 void ParallelOp::build(OpBuilder &builder, OperationState &state,
1674  ArrayRef<NamedAttribute> attributes) {
1675  ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
1676  /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
1677  /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
1678  /*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr,
1679  /*reduction_vars=*/ValueRange(),
1680  /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
1681  state.addAttributes(attributes);
1682 }
1683 
1684 void ParallelOp::build(OpBuilder &builder, OperationState &state,
1685  const ParallelOperands &clauses) {
1686  MLIRContext *ctx = builder.getContext();
1687  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1688  clauses.ifExpr, clauses.numThreads, clauses.privateVars,
1689  makeArrayAttr(ctx, clauses.privateSyms),
1690  clauses.procBindKind, clauses.reductionVars,
1691  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
1692  makeArrayAttr(ctx, clauses.reductionSyms));
1693 }
1694 
1695 template <typename OpType>
1696 static LogicalResult verifyPrivateVarList(OpType &op) {
1697  auto privateVars = op.getPrivateVars();
1698  auto privateSyms = op.getPrivateSymsAttr();
1699 
1700  if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
1701  return success();
1702 
1703  auto numPrivateVars = privateVars.size();
1704  auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
1705 
1706  if (numPrivateVars != numPrivateSyms)
1707  return op.emitError() << "inconsistent number of private variables and "
1708  "privatizer op symbols, private vars: "
1709  << numPrivateVars
1710  << " vs. privatizer op symbols: " << numPrivateSyms;
1711 
1712  for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
1713  Type varType = std::get<0>(privateVarInfo).getType();
1714  SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
1715  PrivateClauseOp privatizerOp =
1716  SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
1717 
1718  if (privatizerOp == nullptr)
1719  return op.emitError() << "failed to lookup privatizer op with symbol: '"
1720  << privateSym << "'";
1721 
1722  Type privatizerType = privatizerOp.getType();
1723 
1724  if (varType != privatizerType)
1725  return op.emitError()
1726  << "type mismatch between a "
1727  << (privatizerOp.getDataSharingType() ==
1728  DataSharingClauseType::Private
1729  ? "private"
1730  : "firstprivate")
1731  << " variable and its privatizer op, var type: " << varType
1732  << " vs. privatizer op type: " << privatizerType;
1733  }
1734 
1735  return success();
1736 }
1737 
1738 LogicalResult ParallelOp::verify() {
1739  if (getAllocateVars().size() != getAllocatorVars().size())
1740  return emitError(
1741  "expected equal sizes for allocate and allocator variables");
1742 
1743  if (failed(verifyPrivateVarList(*this)))
1744  return failure();
1745 
1746  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
1747  getReductionByref());
1748 }
1749 
1750 LogicalResult ParallelOp::verifyRegions() {
1751  auto distributeChildOps = getOps<DistributeOp>();
1752  if (!distributeChildOps.empty()) {
1753  if (!isComposite())
1754  return emitError()
1755  << "'omp.composite' attribute missing from composite operation";
1756 
1757  auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
1758  Operation &distributeOp = **distributeChildOps.begin();
1759  for (Operation &childOp : getOps()) {
1760  if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
1761  continue;
1762 
1763  if (!childOp.hasTrait<OpTrait::IsTerminator>())
1764  return emitError() << "unexpected OpenMP operation inside of composite "
1765  "'omp.parallel'";
1766  }
1767  } else if (isComposite()) {
1768  return emitError()
1769  << "'omp.composite' attribute present in non-composite operation";
1770  }
1771  return success();
1772 }
1773 
1774 //===----------------------------------------------------------------------===//
1775 // TeamsOp
1776 //===----------------------------------------------------------------------===//
1777 
1779  while ((op = op->getParentOp()))
1780  if (isa<OpenMPDialect>(op->getDialect()))
1781  return false;
1782  return true;
1783 }
1784 
1785 void TeamsOp::build(OpBuilder &builder, OperationState &state,
1786  const TeamsOperands &clauses) {
1787  MLIRContext *ctx = builder.getContext();
1788  // TODO Store clauses in op: privateVars, privateSyms.
1789  TeamsOp::build(
1790  builder, state, clauses.allocateVars, clauses.allocatorVars,
1791  clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
1792  /*private_vars=*/{}, /*private_syms=*/nullptr, clauses.reductionVars,
1793  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
1794  makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimit);
1795 }
1796 
1797 LogicalResult TeamsOp::verify() {
1798  // Check parent region
1799  // TODO If nested inside of a target region, also check that it does not
1800  // contain any statements, declarations or directives other than this
1801  // omp.teams construct. The issue is how to support the initialization of
1802  // this operation's own arguments (allow SSA values across omp.target?).
1803  Operation *op = getOperation();
1804  if (!isa<TargetOp>(op->getParentOp()) &&
1806  return emitError("expected to be nested inside of omp.target or not nested "
1807  "in any OpenMP dialect operations");
1808 
1809  // Check for num_teams clause restrictions
1810  if (auto numTeamsLowerBound = getNumTeamsLower()) {
1811  auto numTeamsUpperBound = getNumTeamsUpper();
1812  if (!numTeamsUpperBound)
1813  return emitError("expected num_teams upper bound to be defined if the "
1814  "lower bound is defined");
1815  if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
1816  return emitError(
1817  "expected num_teams upper bound and lower bound to be the same type");
1818  }
1819 
1820  // Check for allocate clause restrictions
1821  if (getAllocateVars().size() != getAllocatorVars().size())
1822  return emitError(
1823  "expected equal sizes for allocate and allocator variables");
1824 
1825  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
1826  getReductionByref());
1827 }
1828 
1829 //===----------------------------------------------------------------------===//
1830 // SectionOp
1831 //===----------------------------------------------------------------------===//
1832 
1833 unsigned SectionOp::numPrivateBlockArgs() {
1834  return getParentOp().numPrivateBlockArgs();
1835 }
1836 
1837 unsigned SectionOp::numReductionBlockArgs() {
1838  return getParentOp().numReductionBlockArgs();
1839 }
1840 
1841 //===----------------------------------------------------------------------===//
1842 // SectionsOp
1843 //===----------------------------------------------------------------------===//
1844 
1845 void SectionsOp::build(OpBuilder &builder, OperationState &state,
1846  const SectionsOperands &clauses) {
1847  MLIRContext *ctx = builder.getContext();
1848  // TODO Store clauses in op: privateVars, privateSyms.
1849  SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1850  clauses.nowait, /*private_vars=*/{},
1851  /*private_syms=*/nullptr, clauses.reductionVars,
1852  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
1853  makeArrayAttr(ctx, clauses.reductionSyms));
1854 }
1855 
1856 LogicalResult SectionsOp::verify() {
1857  if (getAllocateVars().size() != getAllocatorVars().size())
1858  return emitError(
1859  "expected equal sizes for allocate and allocator variables");
1860 
1861  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
1862  getReductionByref());
1863 }
1864 
1865 LogicalResult SectionsOp::verifyRegions() {
1866  for (auto &inst : *getRegion().begin()) {
1867  if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
1868  return emitOpError()
1869  << "expected omp.section op or terminator op inside region";
1870  }
1871  }
1872 
1873  return success();
1874 }
1875 
1876 //===----------------------------------------------------------------------===//
1877 // SingleOp
1878 //===----------------------------------------------------------------------===//
1879 
1880 void SingleOp::build(OpBuilder &builder, OperationState &state,
1881  const SingleOperands &clauses) {
1882  MLIRContext *ctx = builder.getContext();
1883  // TODO Store clauses in op: privateVars, privateSyms.
1884  SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1885  clauses.copyprivateVars,
1886  makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
1887  /*private_vars=*/{}, /*private_syms=*/nullptr);
1888 }
1889 
1890 LogicalResult SingleOp::verify() {
1891  // Check for allocate clause restrictions
1892  if (getAllocateVars().size() != getAllocatorVars().size())
1893  return emitError(
1894  "expected equal sizes for allocate and allocator variables");
1895 
1896  return verifyCopyprivateVarList(*this, getCopyprivateVars(),
1897  getCopyprivateSyms());
1898 }
1899 
1900 //===----------------------------------------------------------------------===//
1901 // LoopWrapperInterface
1902 //===----------------------------------------------------------------------===//
1903 
1904 LogicalResult LoopWrapperInterface::verifyImpl() {
1905  Operation *op = this->getOperation();
1906  if (!op->hasTrait<OpTrait::NoTerminator>() ||
1908  return emitOpError() << "loop wrapper must also have the `NoTerminator` "
1909  "and `SingleBlock` traits";
1910 
1911  if (op->getNumRegions() != 1)
1912  return emitOpError() << "loop wrapper does not contain exactly one region";
1913 
1914  Region &region = op->getRegion(0);
1915  if (range_size(region.getOps()) != 1)
1916  return emitOpError()
1917  << "loop wrapper does not contain exactly one nested op";
1918 
1919  Operation &firstOp = *region.op_begin();
1920  if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
1921  return emitOpError() << "op nested in loop wrapper is not another loop "
1922  "wrapper or `omp.loop_nest`";
1923 
1924  return success();
1925 }
1926 
1927 //===----------------------------------------------------------------------===//
1928 // LoopOp
1929 //===----------------------------------------------------------------------===//
1930 
1931 void LoopOp::build(OpBuilder &builder, OperationState &state,
1932  const LoopOperands &clauses) {
1933  MLIRContext *ctx = builder.getContext();
1934 
1935  LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
1936  makeArrayAttr(ctx, clauses.privateSyms), clauses.order,
1937  clauses.orderMod, clauses.reductionVars,
1938  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
1939  makeArrayAttr(ctx, clauses.reductionSyms));
1940 }
1941 
1942 LogicalResult LoopOp::verify() {
1943  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
1944  getReductionByref());
1945 }
1946 
1947 LogicalResult LoopOp::verifyRegions() {
1948  if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
1949  getNestedWrapper())
1950  return emitError() << "`omp.loop` expected to be a standalone loop wrapper";
1951 
1952  return success();
1953 }
1954 
1955 //===----------------------------------------------------------------------===//
1956 // WsloopOp
1957 //===----------------------------------------------------------------------===//
1958 
1959 void WsloopOp::build(OpBuilder &builder, OperationState &state,
1960  ArrayRef<NamedAttribute> attributes) {
1961  build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
1962  /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
1963  /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
1964  /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
1965  /*reduction_vars=*/ValueRange(), /*reduction_byref=*/nullptr,
1966  /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
1967  /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
1968  /*schedule_simd=*/false);
1969  state.addAttributes(attributes);
1970 }
1971 
1972 void WsloopOp::build(OpBuilder &builder, OperationState &state,
1973  const WsloopOperands &clauses) {
1974  MLIRContext *ctx = builder.getContext();
1975  // TODO: Store clauses in op: allocateVars, allocatorVars, privateVars,
1976  // privateSyms.
1977  WsloopOp::build(
1978  builder, state,
1979  /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
1980  clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
1981  clauses.ordered, /*private_vars=*/{}, /*private_syms=*/nullptr,
1982  clauses.reductionVars,
1983  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
1984  makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
1985  clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
1986 }
1987 
1988 LogicalResult WsloopOp::verify() {
1989  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
1990  getReductionByref());
1991 }
1992 
1993 LogicalResult WsloopOp::verifyRegions() {
1994  bool isCompositeChildLeaf =
1995  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
1996 
1997  if (LoopWrapperInterface nested = getNestedWrapper()) {
1998  if (!isComposite())
1999  return emitError()
2000  << "'omp.composite' attribute missing from composite wrapper";
2001 
2002  // Check for the allowed leaf constructs that may appear in a composite
2003  // construct directly after DO/FOR.
2004  if (!isa<SimdOp>(nested))
2005  return emitError() << "only supported nested wrapper is 'omp.simd'";
2006 
2007  } else if (isComposite() && !isCompositeChildLeaf) {
2008  return emitError()
2009  << "'omp.composite' attribute present in non-composite wrapper";
2010  } else if (!isComposite() && isCompositeChildLeaf) {
2011  return emitError()
2012  << "'omp.composite' attribute missing from composite wrapper";
2013  }
2014 
2015  return success();
2016 }
2017 
2018 //===----------------------------------------------------------------------===//
2019 // Simd construct [2.9.3.1]
2020 //===----------------------------------------------------------------------===//
2021 
2022 void SimdOp::build(OpBuilder &builder, OperationState &state,
2023  const SimdOperands &clauses) {
2024  MLIRContext *ctx = builder.getContext();
2025  // TODO Store clauses in op: linearVars, linearStepVars, privateVars,
2026  // privateSyms.
2027  SimdOp::build(builder, state, clauses.alignedVars,
2028  makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
2029  /*linear_vars=*/{}, /*linear_step_vars=*/{},
2030  clauses.nontemporalVars, clauses.order, clauses.orderMod,
2031  /*private_vars=*/{}, /*private_syms=*/nullptr,
2032  clauses.reductionVars,
2033  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2034  makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
2035  clauses.simdlen);
2036 }
2037 
2038 LogicalResult SimdOp::verify() {
2039  if (getSimdlen().has_value() && getSafelen().has_value() &&
2040  getSimdlen().value() > getSafelen().value())
2041  return emitOpError()
2042  << "simdlen clause and safelen clause are both present, but the "
2043  "simdlen value is not less than or equal to safelen value";
2044 
2045  if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
2046  return failure();
2047 
2048  if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
2049  return failure();
2050 
2051  bool isCompositeChildLeaf =
2052  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2053 
2054  if (!isComposite() && isCompositeChildLeaf)
2055  return emitError()
2056  << "'omp.composite' attribute missing from composite wrapper";
2057 
2058  if (isComposite() && !isCompositeChildLeaf)
2059  return emitError()
2060  << "'omp.composite' attribute present in non-composite wrapper";
2061 
2062  return success();
2063 }
2064 
2065 LogicalResult SimdOp::verifyRegions() {
2066  if (getNestedWrapper())
2067  return emitOpError() << "must wrap an 'omp.loop_nest' directly";
2068 
2069  return success();
2070 }
2071 
2072 //===----------------------------------------------------------------------===//
2073 // Distribute construct [2.9.4.1]
2074 //===----------------------------------------------------------------------===//
2075 
2076 void DistributeOp::build(OpBuilder &builder, OperationState &state,
2077  const DistributeOperands &clauses) {
2078  DistributeOp::build(builder, state, clauses.allocateVars,
2079  clauses.allocatorVars, clauses.distScheduleStatic,
2080  clauses.distScheduleChunkSize, clauses.order,
2081  clauses.orderMod, clauses.privateVars,
2082  makeArrayAttr(builder.getContext(), clauses.privateSyms));
2083 }
2084 
2085 LogicalResult DistributeOp::verify() {
2086  if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2087  return emitOpError() << "chunk size set without "
2088  "dist_schedule_static being present";
2089 
2090  if (getAllocateVars().size() != getAllocatorVars().size())
2091  return emitError(
2092  "expected equal sizes for allocate and allocator variables");
2093 
2094  return success();
2095 }
2096 
2097 LogicalResult DistributeOp::verifyRegions() {
2098  if (LoopWrapperInterface nested = getNestedWrapper()) {
2099  if (!isComposite())
2100  return emitError()
2101  << "'omp.composite' attribute missing from composite wrapper";
2102  // Check for the allowed leaf constructs that may appear in a composite
2103  // construct directly after DISTRIBUTE.
2104  if (isa<WsloopOp>(nested)) {
2105  if (!llvm::dyn_cast_if_present<ParallelOp>((*this)->getParentOp()))
2106  return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
2107  "when 'omp.parallel' is the direct parent";
2108  } else if (!isa<SimdOp>(nested))
2109  return emitError() << "only supported nested wrappers are 'omp.simd' and "
2110  "'omp.wsloop'";
2111  } else if (isComposite()) {
2112  return emitError()
2113  << "'omp.composite' attribute present in non-composite wrapper";
2114  }
2115 
2116  return success();
2117 }
2118 
2119 //===----------------------------------------------------------------------===//
2120 // DeclareReductionOp
2121 //===----------------------------------------------------------------------===//
2122 
2123 LogicalResult DeclareReductionOp::verifyRegions() {
2124  if (!getAllocRegion().empty()) {
2125  for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
2126  if (yieldOp.getResults().size() != 1 ||
2127  yieldOp.getResults().getTypes()[0] != getType())
2128  return emitOpError() << "expects alloc region to yield a value "
2129  "of the reduction type";
2130  }
2131  }
2132 
2133  if (getInitializerRegion().empty())
2134  return emitOpError() << "expects non-empty initializer region";
2135  Block &initializerEntryBlock = getInitializerRegion().front();
2136 
2137  if (initializerEntryBlock.getNumArguments() == 1) {
2138  if (!getAllocRegion().empty())
2139  return emitOpError() << "expects two arguments to the initializer region "
2140  "when an allocation region is used";
2141  } else if (initializerEntryBlock.getNumArguments() == 2) {
2142  if (getAllocRegion().empty())
2143  return emitOpError() << "expects one argument to the initializer region "
2144  "when no allocation region is used";
2145  } else {
2146  return emitOpError()
2147  << "expects one or two arguments to the initializer region";
2148  }
2149 
2150  for (mlir::Value arg : initializerEntryBlock.getArguments())
2151  if (arg.getType() != getType())
2152  return emitOpError() << "expects initializer region argument to match "
2153  "the reduction type";
2154 
2155  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
2156  if (yieldOp.getResults().size() != 1 ||
2157  yieldOp.getResults().getTypes()[0] != getType())
2158  return emitOpError() << "expects initializer region to yield a value "
2159  "of the reduction type";
2160  }
2161 
2162  if (getReductionRegion().empty())
2163  return emitOpError() << "expects non-empty reduction region";
2164  Block &reductionEntryBlock = getReductionRegion().front();
2165  if (reductionEntryBlock.getNumArguments() != 2 ||
2166  reductionEntryBlock.getArgumentTypes()[0] !=
2167  reductionEntryBlock.getArgumentTypes()[1] ||
2168  reductionEntryBlock.getArgumentTypes()[0] != getType())
2169  return emitOpError() << "expects reduction region with two arguments of "
2170  "the reduction type";
2171  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
2172  if (yieldOp.getResults().size() != 1 ||
2173  yieldOp.getResults().getTypes()[0] != getType())
2174  return emitOpError() << "expects reduction region to yield a value "
2175  "of the reduction type";
2176  }
2177 
2178  if (!getAtomicReductionRegion().empty()) {
2179  Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
2180  if (atomicReductionEntryBlock.getNumArguments() != 2 ||
2181  atomicReductionEntryBlock.getArgumentTypes()[0] !=
2182  atomicReductionEntryBlock.getArgumentTypes()[1])
2183  return emitOpError() << "expects atomic reduction region with two "
2184  "arguments of the same type";
2185  auto ptrType = llvm::dyn_cast<PointerLikeType>(
2186  atomicReductionEntryBlock.getArgumentTypes()[0]);
2187  if (!ptrType ||
2188  (ptrType.getElementType() && ptrType.getElementType() != getType()))
2189  return emitOpError() << "expects atomic reduction region arguments to "
2190  "be accumulators containing the reduction type";
2191  }
2192 
2193  if (getCleanupRegion().empty())
2194  return success();
2195  Block &cleanupEntryBlock = getCleanupRegion().front();
2196  if (cleanupEntryBlock.getNumArguments() != 1 ||
2197  cleanupEntryBlock.getArgument(0).getType() != getType())
2198  return emitOpError() << "expects cleanup region with one argument "
2199  "of the reduction type";
2200 
2201  return success();
2202 }
2203 
2204 //===----------------------------------------------------------------------===//
2205 // TaskOp
2206 //===----------------------------------------------------------------------===//
2207 
2208 void TaskOp::build(OpBuilder &builder, OperationState &state,
2209  const TaskOperands &clauses) {
2210  MLIRContext *ctx = builder.getContext();
2211  TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2212  makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2213  clauses.final, clauses.ifExpr, clauses.inReductionVars,
2214  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2215  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2216  clauses.priority, /*private_vars=*/clauses.privateVars,
2217  /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
2218  clauses.untied);
2219 }
2220 
2221 LogicalResult TaskOp::verify() {
2222  LogicalResult verifyDependVars =
2223  verifyDependVarList(*this, getDependKinds(), getDependVars());
2224  return failed(verifyDependVars)
2225  ? verifyDependVars
2226  : verifyReductionVarList(*this, getInReductionSyms(),
2227  getInReductionVars(),
2228  getInReductionByref());
2229 }
2230 
2231 //===----------------------------------------------------------------------===//
2232 // TaskgroupOp
2233 //===----------------------------------------------------------------------===//
2234 
2235 void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
2236  const TaskgroupOperands &clauses) {
2237  MLIRContext *ctx = builder.getContext();
2238  TaskgroupOp::build(builder, state, clauses.allocateVars,
2239  clauses.allocatorVars, clauses.taskReductionVars,
2240  makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
2241  makeArrayAttr(ctx, clauses.taskReductionSyms));
2242 }
2243 
2244 LogicalResult TaskgroupOp::verify() {
2245  return verifyReductionVarList(*this, getTaskReductionSyms(),
2246  getTaskReductionVars(),
2247  getTaskReductionByref());
2248 }
2249 
2250 //===----------------------------------------------------------------------===//
2251 // TaskloopOp
2252 //===----------------------------------------------------------------------===//
2253 
2254 void TaskloopOp::build(OpBuilder &builder, OperationState &state,
2255  const TaskloopOperands &clauses) {
2256  MLIRContext *ctx = builder.getContext();
2257  // TODO Store clauses in op: privateVars, privateSyms.
2258  TaskloopOp::build(
2259  builder, state, clauses.allocateVars, clauses.allocatorVars,
2260  clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
2261  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2262  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2263  clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
2264  /*private_syms=*/nullptr, clauses.reductionVars,
2265  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2266  makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
2267 }
2268 
2269 SmallVector<Value> TaskloopOp::getAllReductionVars() {
2270  SmallVector<Value> allReductionNvars(getInReductionVars().begin(),
2271  getInReductionVars().end());
2272  allReductionNvars.insert(allReductionNvars.end(), getReductionVars().begin(),
2273  getReductionVars().end());
2274  return allReductionNvars;
2275 }
2276 
2277 LogicalResult TaskloopOp::verify() {
2278  if (getAllocateVars().size() != getAllocatorVars().size())
2279  return emitError(
2280  "expected equal sizes for allocate and allocator variables");
2281  if (failed(verifyReductionVarList(*this, getReductionSyms(),
2282  getReductionVars(), getReductionByref())) ||
2283  failed(verifyReductionVarList(*this, getInReductionSyms(),
2284  getInReductionVars(),
2285  getInReductionByref())))
2286  return failure();
2287 
2288  if (!getReductionVars().empty() && getNogroup())
2289  return emitError("if a reduction clause is present on the taskloop "
2290  "directive, the nogroup clause must not be specified");
2291  for (auto var : getReductionVars()) {
2292  if (llvm::is_contained(getInReductionVars(), var))
2293  return emitError("the same list item cannot appear in both a reduction "
2294  "and an in_reduction clause");
2295  }
2296 
2297  if (getGrainsize() && getNumTasks()) {
2298  return emitError(
2299  "the grainsize clause and num_tasks clause are mutually exclusive and "
2300  "may not appear on the same taskloop directive");
2301  }
2302 
2303  return success();
2304 }
2305 
2306 LogicalResult TaskloopOp::verifyRegions() {
2307  if (LoopWrapperInterface nested = getNestedWrapper()) {
2308  if (!isComposite())
2309  return emitError()
2310  << "'omp.composite' attribute missing from composite wrapper";
2311 
2312  // Check for the allowed leaf constructs that may appear in a composite
2313  // construct directly after TASKLOOP.
2314  if (!isa<SimdOp>(nested))
2315  return emitError() << "only supported nested wrapper is 'omp.simd'";
2316  } else if (isComposite()) {
2317  return emitError()
2318  << "'omp.composite' attribute present in non-composite wrapper";
2319  }
2320 
2321  return success();
2322 }
2323 
2324 //===----------------------------------------------------------------------===//
2325 // LoopNestOp
2326 //===----------------------------------------------------------------------===//
2327 
2328 ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
2329  // Parse an opening `(` followed by induction variables followed by `)`
2332  Type loopVarType;
2334  parser.parseColonType(loopVarType) ||
2335  // Parse loop bounds.
2336  parser.parseEqual() ||
2337  parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
2338  parser.parseKeyword("to") ||
2339  parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
2340  return failure();
2341 
2342  for (auto &iv : ivs)
2343  iv.type = loopVarType;
2344 
2345  // Parse "inclusive" flag.
2346  if (succeeded(parser.parseOptionalKeyword("inclusive")))
2347  result.addAttribute("loop_inclusive",
2348  UnitAttr::get(parser.getBuilder().getContext()));
2349 
2350  // Parse step values.
2352  if (parser.parseKeyword("step") ||
2353  parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
2354  return failure();
2355 
2356  // Parse the body.
2357  Region *region = result.addRegion();
2358  if (parser.parseRegion(*region, ivs))
2359  return failure();
2360 
2361  // Resolve operands.
2362  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
2363  parser.resolveOperands(ubs, loopVarType, result.operands) ||
2364  parser.resolveOperands(steps, loopVarType, result.operands))
2365  return failure();
2366 
2367  // Parse the optional attribute list.
2368  return parser.parseOptionalAttrDict(result.attributes);
2369 }
2370 
2372  Region &region = getRegion();
2373  auto args = region.getArguments();
2374  p << " (" << args << ") : " << args[0].getType() << " = ("
2375  << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
2376  if (getLoopInclusive())
2377  p << "inclusive ";
2378  p << "step (" << getLoopSteps() << ") ";
2379  p.printRegion(region, /*printEntryBlockArgs=*/false);
2380 }
2381 
2382 void LoopNestOp::build(OpBuilder &builder, OperationState &state,
2383  const LoopNestOperands &clauses) {
2384  LoopNestOp::build(builder, state, clauses.loopLowerBounds,
2385  clauses.loopUpperBounds, clauses.loopSteps,
2386  clauses.loopInclusive);
2387 }
2388 
2389 LogicalResult LoopNestOp::verify() {
2390  if (getLoopLowerBounds().empty())
2391  return emitOpError() << "must represent at least one loop";
2392 
2393  if (getLoopLowerBounds().size() != getIVs().size())
2394  return emitOpError() << "number of range arguments and IVs do not match";
2395 
2396  for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
2397  if (lb.getType() != iv.getType())
2398  return emitOpError()
2399  << "range argument type does not match corresponding IV type";
2400  }
2401 
2402  if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
2403  return emitOpError() << "expects parent op to be a loop wrapper";
2404 
2405  return success();
2406 }
2407 
2408 void LoopNestOp::gatherWrappers(
2410  Operation *parent = (*this)->getParentOp();
2411  while (auto wrapper =
2412  llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
2413  wrappers.push_back(wrapper);
2414  parent = parent->getParentOp();
2415  }
2416 }
2417 
2418 //===----------------------------------------------------------------------===//
2419 // Critical construct (2.17.1)
2420 //===----------------------------------------------------------------------===//
2421 
2422 void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
2423  const CriticalDeclareOperands &clauses) {
2424  CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
2425 }
2426 
2427 LogicalResult CriticalDeclareOp::verify() {
2428  return verifySynchronizationHint(*this, getHint());
2429 }
2430 
2431 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2432  if (getNameAttr()) {
2433  SymbolRefAttr symbolRef = getNameAttr();
2434  auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
2435  *this, symbolRef);
2436  if (!decl) {
2437  return emitOpError() << "expected symbol reference " << symbolRef
2438  << " to point to a critical declaration";
2439  }
2440  }
2441 
2442  return success();
2443 }
2444 
2445 //===----------------------------------------------------------------------===//
2446 // Ordered construct
2447 //===----------------------------------------------------------------------===//
2448 
2449 static LogicalResult verifyOrderedParent(Operation &op) {
2450  bool hasRegion = op.getNumRegions() > 0;
2451  auto loopOp = op.getParentOfType<LoopNestOp>();
2452  if (!loopOp) {
2453  if (hasRegion)
2454  return success();
2455 
2456  // TODO: Consider if this needs to be the case only for the standalone
2457  // variant of the ordered construct.
2458  return op.emitOpError() << "must be nested inside of a loop";
2459  }
2460 
2461  Operation *wrapper = loopOp->getParentOp();
2462  if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
2463  IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
2464  if (!orderedAttr)
2465  return op.emitOpError() << "the enclosing worksharing-loop region must "
2466  "have an ordered clause";
2467 
2468  if (hasRegion && orderedAttr.getInt() != 0)
2469  return op.emitOpError() << "the enclosing loop's ordered clause must not "
2470  "have a parameter present";
2471 
2472  if (!hasRegion && orderedAttr.getInt() == 0)
2473  return op.emitOpError() << "the enclosing loop's ordered clause must "
2474  "have a parameter present";
2475  } else if (!isa<SimdOp>(wrapper)) {
2476  return op.emitOpError() << "must be nested inside of a worksharing, simd "
2477  "or worksharing simd loop";
2478  }
2479  return success();
2480 }
2481 
2482 void OrderedOp::build(OpBuilder &builder, OperationState &state,
2483  const OrderedOperands &clauses) {
2484  OrderedOp::build(builder, state, clauses.doacrossDependType,
2485  clauses.doacrossNumLoops, clauses.doacrossDependVars);
2486 }
2487 
2488 LogicalResult OrderedOp::verify() {
2489  if (failed(verifyOrderedParent(**this)))
2490  return failure();
2491 
2492  auto wrapper = (*this)->getParentOfType<WsloopOp>();
2493  if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
2494  return emitOpError() << "number of variables in depend clause does not "
2495  << "match number of iteration variables in the "
2496  << "doacross loop";
2497 
2498  return success();
2499 }
2500 
2501 void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
2502  const OrderedRegionOperands &clauses) {
2503  OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
2504 }
2505 
2506 LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
2507 
2508 //===----------------------------------------------------------------------===//
2509 // TaskwaitOp
2510 //===----------------------------------------------------------------------===//
2511 
2512 void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
2513  const TaskwaitOperands &clauses) {
2514  // TODO Store clauses in op: dependKinds, dependVars, nowait.
2515  TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
2516  /*depend_vars=*/{}, /*nowait=*/nullptr);
2517 }
2518 
2519 //===----------------------------------------------------------------------===//
2520 // Verifier for AtomicReadOp
2521 //===----------------------------------------------------------------------===//
2522 
2523 LogicalResult AtomicReadOp::verify() {
2524  if (verifyCommon().failed())
2525  return mlir::failure();
2526 
2527  if (auto mo = getMemoryOrder()) {
2528  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2529  *mo == ClauseMemoryOrderKind::Release) {
2530  return emitError(
2531  "memory-order must not be acq_rel or release for atomic reads");
2532  }
2533  }
2534  return verifySynchronizationHint(*this, getHint());
2535 }
2536 
2537 //===----------------------------------------------------------------------===//
2538 // Verifier for AtomicWriteOp
2539 //===----------------------------------------------------------------------===//
2540 
2541 LogicalResult AtomicWriteOp::verify() {
2542  if (verifyCommon().failed())
2543  return mlir::failure();
2544 
2545  if (auto mo = getMemoryOrder()) {
2546  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2547  *mo == ClauseMemoryOrderKind::Acquire) {
2548  return emitError(
2549  "memory-order must not be acq_rel or acquire for atomic writes");
2550  }
2551  }
2552  return verifySynchronizationHint(*this, getHint());
2553 }
2554 
2555 //===----------------------------------------------------------------------===//
2556 // Verifier for AtomicUpdateOp
2557 //===----------------------------------------------------------------------===//
2558 
2559 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2560  PatternRewriter &rewriter) {
2561  if (op.isNoOp()) {
2562  rewriter.eraseOp(op);
2563  return success();
2564  }
2565  if (Value writeVal = op.getWriteOpVal()) {
2566  rewriter.replaceOpWithNewOp<AtomicWriteOp>(
2567  op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
2568  return success();
2569  }
2570  return failure();
2571 }
2572 
2573 LogicalResult AtomicUpdateOp::verify() {
2574  if (verifyCommon().failed())
2575  return mlir::failure();
2576 
2577  if (auto mo = getMemoryOrder()) {
2578  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2579  *mo == ClauseMemoryOrderKind::Acquire) {
2580  return emitError(
2581  "memory-order must not be acq_rel or acquire for atomic updates");
2582  }
2583  }
2584 
2585  return verifySynchronizationHint(*this, getHint());
2586 }
2587 
2588 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
2589 
2590 //===----------------------------------------------------------------------===//
2591 // Verifier for AtomicCaptureOp
2592 //===----------------------------------------------------------------------===//
2593 
2594 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2595  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2596  return op;
2597  return dyn_cast<AtomicReadOp>(getSecondOp());
2598 }
2599 
2600 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2601  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2602  return op;
2603  return dyn_cast<AtomicWriteOp>(getSecondOp());
2604 }
2605 
2606 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2607  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2608  return op;
2609  return dyn_cast<AtomicUpdateOp>(getSecondOp());
2610 }
2611 
2612 LogicalResult AtomicCaptureOp::verify() {
2613  return verifySynchronizationHint(*this, getHint());
2614 }
2615 
2616 LogicalResult AtomicCaptureOp::verifyRegions() {
2617  if (verifyRegionsCommon().failed())
2618  return mlir::failure();
2619 
2620  if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
2621  return emitOpError(
2622  "operations inside capture region must not have hint clause");
2623 
2624  if (getFirstOp()->getAttr("memory_order") ||
2625  getSecondOp()->getAttr("memory_order"))
2626  return emitOpError(
2627  "operations inside capture region must not have memory_order clause");
2628  return success();
2629 }
2630 
2631 //===----------------------------------------------------------------------===//
2632 // CancelOp
2633 //===----------------------------------------------------------------------===//
2634 
2635 void CancelOp::build(OpBuilder &builder, OperationState &state,
2636  const CancelOperands &clauses) {
2637  CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
2638 }
2639 
2640 LogicalResult CancelOp::verify() {
2641  ClauseCancellationConstructType cct = getCancelDirective();
2642  Operation *parentOp = (*this)->getParentOp();
2643 
2644  if (!parentOp) {
2645  return emitOpError() << "must be used within a region supporting "
2646  "cancel directive";
2647  }
2648 
2649  if ((cct == ClauseCancellationConstructType::Parallel) &&
2650  !isa<ParallelOp>(parentOp)) {
2651  return emitOpError() << "cancel parallel must appear "
2652  << "inside a parallel region";
2653  }
2654  if (cct == ClauseCancellationConstructType::Loop) {
2655  auto loopOp = dyn_cast<LoopNestOp>(parentOp);
2656  auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
2657  loopOp ? loopOp->getParentOp() : nullptr);
2658 
2659  if (!wsloopOp) {
2660  return emitOpError()
2661  << "cancel loop must appear inside a worksharing-loop region";
2662  }
2663  if (wsloopOp.getNowaitAttr()) {
2664  return emitError() << "A worksharing construct that is canceled "
2665  << "must not have a nowait clause";
2666  }
2667  if (wsloopOp.getOrderedAttr()) {
2668  return emitError() << "A worksharing construct that is canceled "
2669  << "must not have an ordered clause";
2670  }
2671 
2672  } else if (cct == ClauseCancellationConstructType::Sections) {
2673  if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
2674  return emitOpError() << "cancel sections must appear "
2675  << "inside a sections region";
2676  }
2677  if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
2678  cast<SectionsOp>(parentOp->getParentOp()).getNowaitAttr()) {
2679  return emitError() << "A sections construct that is canceled "
2680  << "must not have a nowait clause";
2681  }
2682  }
2683  // TODO : Add more when we support taskgroup.
2684  return success();
2685 }
2686 
2687 //===----------------------------------------------------------------------===//
2688 // CancellationPointOp
2689 //===----------------------------------------------------------------------===//
2690 
2691 void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
2692  const CancellationPointOperands &clauses) {
2693  CancellationPointOp::build(builder, state, clauses.cancelDirective);
2694 }
2695 
2696 LogicalResult CancellationPointOp::verify() {
2697  ClauseCancellationConstructType cct = getCancelDirective();
2698  Operation *parentOp = (*this)->getParentOp();
2699 
2700  if (!parentOp) {
2701  return emitOpError() << "must be used within a region supporting "
2702  "cancellation point directive";
2703  }
2704 
2705  if ((cct == ClauseCancellationConstructType::Parallel) &&
2706  !(isa<ParallelOp>(parentOp))) {
2707  return emitOpError() << "cancellation point parallel must appear "
2708  << "inside a parallel region";
2709  }
2710  if ((cct == ClauseCancellationConstructType::Loop) &&
2711  (!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->getParentOp()))) {
2712  return emitOpError() << "cancellation point loop must appear "
2713  << "inside a worksharing-loop region";
2714  }
2715  if ((cct == ClauseCancellationConstructType::Sections) &&
2716  !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
2717  return emitOpError() << "cancellation point sections must appear "
2718  << "inside a sections region";
2719  }
2720  // TODO : Add more when we support taskgroup.
2721  return success();
2722 }
2723 
2724 //===----------------------------------------------------------------------===//
2725 // MapBoundsOp
2726 //===----------------------------------------------------------------------===//
2727 
2728 LogicalResult MapBoundsOp::verify() {
2729  auto extent = getExtent();
2730  auto upperbound = getUpperBound();
2731  if (!extent && !upperbound)
2732  return emitError("expected extent or upperbound.");
2733  return success();
2734 }
2735 
2736 void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2737  TypeRange /*result_types*/, StringAttr symName,
2738  TypeAttr type) {
2739  PrivateClauseOp::build(
2740  odsBuilder, odsState, symName, type,
2742  DataSharingClauseType::Private));
2743 }
2744 
2745 LogicalResult PrivateClauseOp::verifyRegions() {
2746  Type symType = getType();
2747 
2748  auto verifyTerminator = [&](Operation *terminator,
2749  bool yieldsValue) -> LogicalResult {
2750  if (!terminator->getBlock()->getSuccessors().empty())
2751  return success();
2752 
2753  if (!llvm::isa<YieldOp>(terminator))
2754  return mlir::emitError(terminator->getLoc())
2755  << "expected exit block terminator to be an `omp.yield` op.";
2756 
2757  YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
2758  TypeRange yieldedTypes = yieldOp.getResults().getTypes();
2759 
2760  if (!yieldsValue) {
2761  if (yieldedTypes.empty())
2762  return success();
2763 
2764  return mlir::emitError(terminator->getLoc())
2765  << "Did not expect any values to be yielded.";
2766  }
2767 
2768  if (yieldedTypes.size() == 1 && yieldedTypes.front() == symType)
2769  return success();
2770 
2771  auto error = mlir::emitError(yieldOp.getLoc())
2772  << "Invalid yielded value. Expected type: " << symType
2773  << ", got: ";
2774 
2775  if (yieldedTypes.empty())
2776  error << "None";
2777  else
2778  error << yieldedTypes;
2779 
2780  return error;
2781  };
2782 
2783  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
2784  StringRef regionName,
2785  bool yieldsValue) -> LogicalResult {
2786  assert(!region.empty());
2787 
2788  if (region.getNumArguments() != expectedNumArgs)
2789  return mlir::emitError(region.getLoc())
2790  << "`" << regionName << "`: "
2791  << "expected " << expectedNumArgs
2792  << " region arguments, got: " << region.getNumArguments();
2793 
2794  for (Block &block : region) {
2795  // MLIR will verify the absence of the terminator for us.
2796  if (!block.mightHaveTerminator())
2797  continue;
2798 
2799  if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
2800  return failure();
2801  }
2802 
2803  return success();
2804  };
2805 
2806  if (failed(verifyRegion(getAllocRegion(), /*expectedNumArgs=*/1, "alloc",
2807  /*yieldsValue=*/true)))
2808  return failure();
2809 
2810  DataSharingClauseType dsType = getDataSharingType();
2811 
2812  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
2813  return emitError("`private` clauses require only an `alloc` region.");
2814 
2815  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
2816  return emitError(
2817  "`firstprivate` clauses require both `alloc` and `copy` regions.");
2818 
2819  if (dsType == DataSharingClauseType::FirstPrivate &&
2820  failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
2821  /*yieldsValue=*/true)))
2822  return failure();
2823 
2824  if (!getDeallocRegion().empty() &&
2825  failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
2826  /*yieldsValue=*/false)))
2827  return failure();
2828 
2829  return success();
2830 }
2831 
2832 //===----------------------------------------------------------------------===//
2833 // Spec 5.2: Masked construct (10.5)
2834 //===----------------------------------------------------------------------===//
2835 
2836 void MaskedOp::build(OpBuilder &builder, OperationState &state,
2837  const MaskedOperands &clauses) {
2838  MaskedOp::build(builder, state, clauses.filteredThreadId);
2839 }
2840 
2841 #define GET_ATTRDEF_CLASSES
2842 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
2843 
2844 #define GET_OP_CLASSES
2845 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
2846 
2847 #define GET_TYPEDEF_CLASSES
2848 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:722
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:63
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1191
static MLIRContext * getContext(OpFoldResult val)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseBoolArrayAttr byref=nullptr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseBoolArrayAttr *byref=nullptr)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printInReductionMapPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseInReductionMapPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseMembersIndex(OpAsmParser &parser, ArrayAttr &membersIdx)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
SuccessorRange getSuccessors()
Definition: Block.h:267
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
IntegerType getI64Type()
Definition: Builders.cpp:109
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
MLIRContext * getContext() const
Definition: Builders.h:55
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:215
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
This class indicates that the regions associated with this op don't have terminators.
Definition: OpDefinition.h:760
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
BlockArgListType getArguments()
Definition: Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
Runtime
Potential runtimes for AMD GPU kernels.
Definition: Runtimes.h:15
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
Definition: OpDefinition.h:871
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.