MLIR  20.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Attributes.h"
24 
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/BitVector.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/STLForwardCompat.h"
29 #include "llvm/ADT/SmallString.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Frontend/OpenMP/OMPConstants.h"
34 #include <cstddef>
35 #include <iterator>
36 #include <optional>
37 #include <variant>
38 
39 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
40 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
41 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
42 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
43 
44 using namespace mlir;
45 using namespace mlir::omp;
46 
// Wraps `attrs` in an ArrayAttr created in `context`. An empty list yields a
// null attribute rather than an empty array.
// NOTE(review): the signature line declaring `attrs` is not present in this
// listing; confirm its type against the original file.
47 static ArrayAttr makeArrayAttr(MLIRContext *context,
49  return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
50 }
51 
// Wraps `boolArray` in a DenseBoolArrayAttr created in `ctx`. An empty list
// yields a null attribute rather than an empty array.
// NOTE(review): the line carrying the function name and parameters is not
// present in this listing; confirm against the original file.
52 static DenseBoolArrayAttr
54  return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
55 }
56 
57 namespace {
58 struct MemRefPointerLikeModel
59  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
60  MemRefType> {
61  Type getElementType(Type pointer) const {
62  return llvm::cast<MemRefType>(pointer).getElementType();
63  }
64 };
65 
66 struct LLVMPointerPointerLikeModel
67  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
68  LLVM::LLVMPointerType> {
69  Type getElementType(Type pointer) const { return Type(); }
70 };
71 } // namespace
72 
// Registers the dialect's operations, attributes and types (via the generated
// .inc lists), declares the promised ConvertToLLVMPatternInterface, and
// attaches the external interface models used by the dialect.
// NOTE(review): the template-argument lines of the three trailing
// attachInterface calls are not present in this listing.
73 void OpenMPDialect::initialize() {
74  addOperations<
75 #define GET_OP_LIST
76 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
77  >();
78  addAttributes<
79 #define GET_ATTRDEF_LIST
80 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
81  >();
82  addTypes<
83 #define GET_TYPEDEF_LIST
84 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
85  >();
86 
87  declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
88 
// Make MemRefType and LLVM pointer types usable as PointerLikeType.
89  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
90  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
91  *getContext());
92 
93  // Attach the default offload module interface to the module op so that
94  // offload functionality can be accessed through it.
95  mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
96  *getContext());
97 
98  // Attach default declare target interfaces to operations which can be marked
99  // as declare target (global operations and functions/subroutines in dialects
100  // that Fortran, or other languages that lower to MLIR, translates to).
101  mlir::LLVM::GlobalOp::attachInterface<
103  *getContext());
104  mlir::LLVM::LLVMFuncOp::attachInterface<
106  *getContext());
107  mlir::func::FuncOp::attachInterface<
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // Parser and printer for Allocate Clause
113 //===----------------------------------------------------------------------===//
114 
115 /// Parse an allocate clause with allocators and a list of operands with types.
116 ///
117 /// allocate-operand-list ::= allocate-operand |
118 /// allocate-operand `,` allocate-operand-list
119 /// allocate-operand ::= ssa-id-and-type `->` ssa-id-and-type
120 /// ssa-id-and-type ::= ssa-id `:` type
///
/// In each element the operand before the arrow is appended to
/// `allocatorVars`/`allocatorTypes`, and the operand after the arrow is
/// appended to `allocateVars`/`allocateTypes`.
/// NOTE(review): the parameter lines declaring `allocateVars` and
/// `allocatorVars`, and the declaration of the local `operand`, are not
/// present in this listing.
121 static ParseResult parseAllocateAndAllocator(
122  OpAsmParser &parser,
124  SmallVectorImpl<Type> &allocateTypes,
126  SmallVectorImpl<Type> &allocatorTypes) {
127 
128  return parser.parseCommaSeparatedList([&]() {
130  Type type;
// First pair: the allocator and its type.
131  if (parser.parseOperand(operand) || parser.parseColonType(type))
132  return failure();
133  allocatorVars.push_back(operand);
134  allocatorTypes.push_back(type);
135  if (parser.parseArrow())
136  return failure();
// Second pair: the allocated variable and its type.
137  if (parser.parseOperand(operand) || parser.parseColonType(type))
138  return failure();
139 
140  allocateVars.push_back(operand);
141  allocateTypes.push_back(type);
142  return success();
143  });
144 }
145 
146 /// Print allocate clause
/// Emits one `allocator : type -> variable : type` entry per element of
/// `allocateVars`, separated by ", ". `allocatorVars`, `allocatorTypes` and
/// `allocateTypes` are indexed with the same index as `allocateVars`.
/// NOTE(review): the line carrying the function name and its leading
/// parameters (including the printer `p`) is not present in this listing.
148  OperandRange allocateVars,
149  TypeRange allocateTypes,
150  OperandRange allocatorVars,
151  TypeRange allocatorTypes) {
152  for (unsigned i = 0; i < allocateVars.size(); ++i) {
// No separator after the last entry.
153  std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
154  p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
155  p << allocateVars[i] << " : " << allocateTypes[i] << separator;
156  }
157 }
158 
159 //===----------------------------------------------------------------------===//
160 // Parser and printer for a clause attribute (StringEnumAttr)
161 //===----------------------------------------------------------------------===//
162 
163 template <typename ClauseAttr>
164 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
165  using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
166  StringRef enumStr;
167  SMLoc loc = parser.getCurrentLocation();
168  if (parser.parseKeyword(&enumStr))
169  return failure();
170  if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
171  attr = ClauseAttr::get(parser.getContext(), *enumValue);
172  return success();
173  }
174  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
175 }
176 
177 template <typename ClauseAttr>
178 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
179  p << stringifyEnum(attr.getValue());
180 }
181 
182 //===----------------------------------------------------------------------===//
183 // Parser and printer for Linear Clause
184 //===----------------------------------------------------------------------===//
185 
186 /// linear ::= `linear` `(` linear-list `)`
187 /// linear-list := linear-val | linear-val linear-list
188 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
///
/// As implemented below, each comma-separated element is parsed as
/// `var = step : type`, i.e. a single type follows the step operand.
/// NOTE(review): the parameter lines declaring `linearVars` and
/// `linearStepVars`, and the declarations of the locals `var` and `stepVar`,
/// are not present in this listing.
189 static ParseResult parseLinearClause(
190  OpAsmParser &parser,
192  SmallVectorImpl<Type> &linearTypes,
194  return parser.parseCommaSeparatedList([&]() {
196  Type type;
198  if (parser.parseOperand(var) || parser.parseEqual() ||
199  parser.parseOperand(stepVar) || parser.parseColonType(type))
200  return failure();
201 
// Record the variable, its type and its step in parallel lists.
202  linearVars.push_back(var);
203  linearTypes.push_back(type);
204  linearStepVars.push_back(stepVar);
205  return success();
206  });
207 }
208 
209 /// Print Linear Clause
/// Emits `var [= step] : type` for every linear variable, separated by ", ".
/// The printed type is taken from the variable itself; `linearTypes` is not
/// read in the visible body.
/// NOTE(review): the line carrying the function name and its leading
/// parameters (including the printer `p`) is not present in this listing.
211  ValueRange linearVars, TypeRange linearTypes,
212  ValueRange linearStepVars) {
213  size_t linearVarsSize = linearVars.size();
214  for (unsigned i = 0; i < linearVarsSize; ++i) {
// No separator after the last entry.
215  std::string separator = i == linearVarsSize - 1 ? "" : ", ";
216  p << linearVars[i];
// The step is only printed when a step operand exists at this index.
217  if (linearStepVars.size() > i)
218  p << " = " << linearStepVars[i];
219  p << " : " << linearVars[i].getType() << separator;
220  }
221 }
222 
223 //===----------------------------------------------------------------------===//
224 // Verifier for Nontemporal Clause
225 //===----------------------------------------------------------------------===//
226 
227 static LogicalResult verifyNontemporalClause(Operation *op,
228  OperandRange nontemporalVars) {
229 
230  // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
231  DenseSet<Value> nontemporalItems;
232  for (const auto &it : nontemporalVars)
233  if (!nontemporalItems.insert(it).second)
234  return op->emitOpError() << "nontemporal variable used more than once";
235 
236  return success();
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // Parser, verifier and printer for Aligned Clause
241 //===----------------------------------------------------------------------===//
242 static LogicalResult verifyAlignedClause(Operation *op,
243  std::optional<ArrayAttr> alignments,
244  OperandRange alignedVars) {
245  // Check if number of alignment values equals to number of aligned variables
246  if (!alignedVars.empty()) {
247  if (!alignments || alignments->size() != alignedVars.size())
248  return op->emitOpError()
249  << "expected as many alignment values as aligned variables";
250  } else {
251  if (alignments)
252  return op->emitOpError() << "unexpected alignment values attribute";
253  return success();
254  }
255 
256  // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
257  DenseSet<Value> alignedItems;
258  for (auto it : alignedVars)
259  if (!alignedItems.insert(it).second)
260  return op->emitOpError() << "aligned variable used more than once";
261 
262  if (!alignments)
263  return success();
264 
265  // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
266  for (unsigned i = 0; i < (*alignments).size(); ++i) {
267  if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
268  if (intAttr.getValue().sle(0))
269  return op->emitOpError() << "alignment should be greater than 0";
270  } else {
271  return op->emitOpError() << "expected integer alignment";
272  }
273  }
274 
275  return success();
276 }
277 
278 /// aligned ::= `aligned` `(` aligned-list `)`
279 /// aligned-list := aligned-val | aligned-val aligned-list
280 /// aligned-val := ssa-id-and-type `->` alignment
///
/// As implemented below, the elements are comma-separated and each one is
/// parsed as `var : type -> alignment-attribute`. The collected alignments are
/// stored in `alignmentsAttr` as an ArrayAttr.
/// NOTE(review): the line carrying the function name and its leading
/// parameters (including `parser` and `alignedVars`) is not present in this
/// listing.
281 static ParseResult
284  SmallVectorImpl<Type> &alignedTypes,
285  ArrayAttr &alignmentsAttr) {
286  SmallVector<Attribute> alignmentVec;
287  if (failed(parser.parseCommaSeparatedList([&]() {
// Each output slot is created first and then filled in place by the parser.
288  if (parser.parseOperand(alignedVars.emplace_back()) ||
289  parser.parseColonType(alignedTypes.emplace_back()) ||
290  parser.parseArrow() ||
291  parser.parseAttribute(alignmentVec.emplace_back())) {
292  return failure();
293  }
294  return success();
295  })))
296  return failure();
297  SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
298  alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
299  return success();
300 }
301 
302 /// Print Aligned Clause
/// Emits `var : type -> alignment` for every aligned variable, separated by
/// ", ". The printed type is taken from the variable itself.
/// NOTE(review): `alignments` is dereferenced without a check, so it must hold
/// at least `alignedVars.size()` entries whenever `alignedVars` is non-empty.
/// NOTE(review): the line carrying the function name and its leading
/// parameters (including the printer `p`) is not present in this listing.
304  ValueRange alignedVars, TypeRange alignedTypes,
305  std::optional<ArrayAttr> alignments) {
306  for (unsigned i = 0; i < alignedVars.size(); ++i) {
307  if (i != 0)
308  p << ", ";
309  p << alignedVars[i] << " : " << alignedVars[i].getType();
310  p << " -> " << (*alignments)[i];
311  }
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // Parser, printer and verifier for Schedule Clause
316 //===----------------------------------------------------------------------===//
317 
// Validates the parsed schedule modifier strings and normalizes them in place.
// At most two modifiers are accepted and each must be a known
// ScheduleModifier. A lone "simd" is rewritten to {"none", "simd"}; with two
// modifiers, the first must not be simd and the second must be simd.
// NOTE(review): the signature line carrying the function name and the `parser`
// parameter is not present in this listing.
318 static ParseResult
320  SmallVectorImpl<SmallString<12>> &modifiers) {
321  if (modifiers.size() > 2)
322  return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
323  for (const auto &mod : modifiers) {
324  // Translate the string. If it has no value, then it was not a valid
325  // modifier!
326  auto symbol = symbolizeScheduleModifier(mod);
327  if (!symbol)
328  return parser.emitError(parser.getNameLoc())
329  << " unknown modifier type: " << mod;
330  }
331 
332  // If we have one modifier that is "simd", then stick a "none" modifier in
333  // index 0.
334  if (modifiers.size() == 1) {
335  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
336  modifiers.push_back(modifiers[0]);
337  modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
338  }
339  } else if (modifiers.size() == 2) {
340  // If there are two modifiers:
341  // the first modifier should not be simd, the second one should be simd.
342  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
343  symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
344  return parser.emitError(parser.getNameLoc())
345  << " incorrect modifier order";
346  }
347  return success();
348 }
349 
350 /// schedule ::= `schedule` `(` sched-list `)`
351 /// sched-list ::= sched-val | sched-val sched-list |
352 /// sched-val `,` sched-modifier
353 /// sched-val ::= sched-with-chunk | sched-wo-chunk
354 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
355 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
356 /// sched-wo-chunk ::= `auto` | `runtime`
357 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
358 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
///
/// Outputs: `scheduleAttr` receives the schedule kind, `chunkSize`/`chunkType`
/// the optional chunk operand and its type, `scheduleMod` the first modifier
/// (if any), and `scheduleSimd` is set when a second modifier is present.
/// NOTE(review): one `case` label line of the switch below (between `Auto` and
/// the final assignment) is not present in this listing.
359 static ParseResult
360 parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
361  ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
362  std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
363  Type &chunkType) {
364  StringRef keyword;
365  if (parser.parseKeyword(&keyword))
366  return failure();
367  std::optional<mlir::omp::ClauseScheduleKind> schedule =
368  symbolizeClauseScheduleKind(keyword);
369  if (!schedule)
370  return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
371 
372  scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
373  switch (*schedule) {
// These kinds accept an optional `= chunk : type` suffix.
374  case ClauseScheduleKind::Static:
375  case ClauseScheduleKind::Dynamic:
376  case ClauseScheduleKind::Guided:
377  if (succeeded(parser.parseOptionalEqual())) {
378  chunkSize = OpAsmParser::UnresolvedOperand{};
379  if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
380  return failure();
381  } else {
382  chunkSize = std::nullopt;
383  }
384  break;
// The remaining kinds never carry a chunk size.
385  case ClauseScheduleKind::Auto:
387  chunkSize = std::nullopt;
388  }
389 
390  // If there is a comma, we have one or more modifiers.
391  SmallVector<SmallString<12>> modifiers;
392  while (succeeded(parser.parseOptionalComma())) {
393  StringRef mod;
394  if (parser.parseKeyword(&mod))
395  return failure();
396  modifiers.push_back(mod);
397  }
398 
// Validates the modifiers and normalizes their order; see
// verifyScheduleModifiers.
399  if (verifyScheduleModifiers(parser, modifiers))
400  return failure();
401 
402  if (!modifiers.empty()) {
403  SMLoc loc = parser.getCurrentLocation();
404  if (std::optional<ScheduleModifier> mod =
405  symbolizeScheduleModifier(modifiers[0])) {
406  scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
407  } else {
408  return parser.emitError(loc, "invalid schedule modifier");
409  }
410  // Only SIMD attribute is allowed here!
411  if (modifiers.size() > 1) {
412  assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
413  scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
414  }
415  }
416 
417  return success();
418 }
419 
420 /// Print schedule clause
/// Emits the schedule kind, then ` = chunk : type` when a chunk value exists,
/// then `, <modifier>` when a modifier attribute is set, then `, simd` when
/// the simd unit attribute is set. The printed chunk type is taken from the
/// chunk value itself; `scheduleChunkType` is not read in the visible body.
/// NOTE(review): the line carrying the function name and its leading
/// parameters (including the printer `p`) is not present in this listing.
422  ClauseScheduleKindAttr scheduleKind,
423  ScheduleModifierAttr scheduleMod,
424  UnitAttr scheduleSimd, Value scheduleChunk,
425  Type scheduleChunkType) {
426  p << stringifyClauseScheduleKind(scheduleKind.getValue());
427  if (scheduleChunk)
428  p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
429  if (scheduleMod)
430  p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
431  if (scheduleSimd)
432  p << ", simd";
433 }
434 
435 //===----------------------------------------------------------------------===//
436 // Parser and printer for Order Clause
437 //===----------------------------------------------------------------------===//
438 
439 // order ::= `order` `(` [order-modifier ':'] concurrent `)`
440 // order-modifier ::= reproducible | unconstrained
441 static ParseResult parseOrderClause(OpAsmParser &parser,
442  ClauseOrderKindAttr &order,
443  OrderModifierAttr &orderMod) {
444  StringRef enumStr;
445  SMLoc loc = parser.getCurrentLocation();
446  if (parser.parseKeyword(&enumStr))
447  return failure();
448  if (std::optional<OrderModifier> enumValue =
449  symbolizeOrderModifier(enumStr)) {
450  orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
451  if (parser.parseOptionalColon())
452  return failure();
453  loc = parser.getCurrentLocation();
454  if (parser.parseKeyword(&enumStr))
455  return failure();
456  }
457  if (std::optional<ClauseOrderKind> enumValue =
458  symbolizeClauseOrderKind(enumStr)) {
459  order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
460  return success();
461  }
462  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
463 }
464 
// Prints `<modifier>:` when `orderMod` is set, followed by the order kind when
// `order` is set.
// NOTE(review): the line carrying the function name and its leading parameters
// (including the printer `p`) is not present in this listing.
466  ClauseOrderKindAttr order,
467  OrderModifierAttr orderMod) {
468  if (orderMod)
469  p << stringifyOrderModifier(orderMod.getValue()) << ":";
470  if (order)
471  p << stringifyClauseOrderKind(order.getValue());
472 }
473 
474 //===----------------------------------------------------------------------===//
475 // Parsers for operations including clauses that define entry block arguments.
476 //===----------------------------------------------------------------------===//
477 
// Bundles of output references filled in while parsing clauses that define
// entry block arguments. Each struct stores references (or a pointer) to
// caller-owned storage, so it must not outlive that storage.
// NOTE(review): several member declarations (e.g. `vars`) and constructor
// header lines are not present in this listing.
478 namespace {
// Outputs of a clause that only carries variables and their types.
479 struct MapParseArgs {
481  SmallVectorImpl<Type> &types;
483  SmallVectorImpl<Type> &types)
484  : vars(vars), types(types) {}
485 };
// Outputs of a private clause: variables, types, symbols and optional map
// indices (`mapIndices` may be null).
486 struct PrivateParseArgs {
489  ArrayAttr &syms;
490  DenseI64ArrayAttr *mapIndices;
492  SmallVectorImpl<Type> &types, ArrayAttr &syms,
493  DenseI64ArrayAttr *mapIndices = nullptr)
494  : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
495 };
// Outputs of a reduction-style clause: variables, types, byref flags and
// symbols.
496 struct ReductionParseArgs {
498  SmallVectorImpl<Type> &types;
499  DenseBoolArrayAttr &byref;
500  ArrayAttr &syms;
501  ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
503  ArrayAttr &syms)
504  : vars(vars), types(types), byref(byref), syms(syms) {}
505 };
// One optional entry per supported clause; an empty optional means the clause
// is not accepted by the operation being parsed.
506 struct AllRegionParseArgs {
507  std::optional<ReductionParseArgs> inReductionArgs;
508  std::optional<MapParseArgs> mapArgs;
509  std::optional<PrivateParseArgs> privateArgs;
510  std::optional<ReductionParseArgs> reductionArgs;
511  std::optional<ReductionParseArgs> taskReductionArgs;
512  std::optional<MapParseArgs> useDeviceAddrArgs;
513  std::optional<MapParseArgs> useDevicePtrArgs;
514 };
515 } // namespace
516 
// Parses the parenthesized body of a clause that defines entry block
// arguments:
//   `(` entry (`,` entry)* `:` type (`,` type)* `)`
// where each entry is
//   [`byref`] [symbol] operand `->` region-argument [`[` `map_idx` `=` int `]`]
// The `byref` keyword is only looked for when `byref` is non-null, the symbol
// only when `symbols` is non-null, and the map index only when `mapIndices` is
// non-null. Parsed region arguments are appended to `regionPrivateArgs` and
// given the parsed types.
// NOTE(review): the parameter line declaring `operands` is not present in this
// listing.
517 static ParseResult parseClauseWithRegionArgs(
518  OpAsmParser &parser,
520  SmallVectorImpl<Type> &types,
521  SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
522  ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
523  DenseBoolArrayAttr *byref = nullptr) {
524  SmallVector<SymbolRefAttr> symbolVec;
525  SmallVector<int64_t> mapIndicesVec;
526  SmallVector<bool> isByRefVec;
// Arguments of earlier clauses may already be in the list; remember where this
// clause's arguments start.
527  unsigned regionArgOffset = regionPrivateArgs.size();
528 
529  if (parser.parseLParen())
530  return failure();
531 
532  if (parser.parseCommaSeparatedList([&]() {
533  if (byref)
534  isByRefVec.push_back(
535  parser.parseOptionalKeyword("byref").succeeded());
536 
537  if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
538  return failure();
539 
540  if (parser.parseOperand(operands.emplace_back()) ||
541  parser.parseArrow() ||
542  parser.parseArgument(regionPrivateArgs.emplace_back()))
543  return failure();
544 
545  if (mapIndices) {
546  if (parser.parseOptionalLSquare().succeeded()) {
547  if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
548  parser.parseInteger(mapIndicesVec.emplace_back()) ||
549  parser.parseRSquare())
550  return failure();
551  } else
// -1 marks an entry without an explicit map index.
552  mapIndicesVec.push_back(-1);
553  }
554 
555  return success();
556  }))
557  return failure();
558 
559  if (parser.parseColon())
560  return failure();
561 
562  if (parser.parseCommaSeparatedList([&]() {
563  if (parser.parseType(types.emplace_back()))
564  return failure();
565 
566  return success();
567  }))
568  return failure();
569 
// NOTE(review): this failure is returned without emitting a diagnostic.
570  if (operands.size() != types.size())
571  return failure();
572 
573  if (parser.parseRParen())
574  return failure();
575 
// Assign the parsed types to the region arguments added by this clause.
576  auto *argsBegin = regionPrivateArgs.begin();
577  MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
578  argsBegin + regionArgOffset + types.size());
579  for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
580  prv.type = type;
581  }
582 
583  if (symbols) {
584  SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
585  *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
586  }
587 
// `mapIndicesVec` is only filled when `mapIndices` is non-null, so the
// dereference below is guarded by the emptiness check.
588  if (!mapIndicesVec.empty())
589  *mapIndices =
590  mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
591 
592  if (byref)
593  *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
594 
595  return success();
596 }
597 
// Parses an optional `keyword(...)` clause that only carries variables and
// types. If the keyword is present but `mapArgs` is empty (the clause is not
// supported by the caller), failure is returned without a diagnostic. If the
// keyword is absent, nothing is consumed and success is returned.
// NOTE(review): the parameter line declaring `entryBlockArgs` is not present
// in this listing.
598 static ParseResult parseBlockArgClause(
599  OpAsmParser &parser,
601  StringRef keyword, std::optional<MapParseArgs> mapArgs) {
602  if (succeeded(parser.parseOptionalKeyword(keyword))) {
603  if (!mapArgs)
604  return failure();
605 
606  if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
607  entryBlockArgs)))
608  return failure();
609  }
610  return success();
611 }
612 
// Parses an optional `keyword(...)` private-style clause, additionally
// collecting symbols and (when `privateArgs->mapIndices` is non-null) map
// indices. If the keyword is present but `privateArgs` is empty, failure is
// returned without a diagnostic. If the keyword is absent, success is returned.
// NOTE(review): the parameter line declaring `entryBlockArgs` is not present
// in this listing.
613 static ParseResult parseBlockArgClause(
614  OpAsmParser &parser,
616  StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
617  if (succeeded(parser.parseOptionalKeyword(keyword))) {
618  if (!privateArgs)
619  return failure();
620 
621  if (failed(parseClauseWithRegionArgs(
622  parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
623  &privateArgs->syms, privateArgs->mapIndices)))
624  return failure();
625  }
626  return success();
627 }
628 
// Parses an optional `keyword(...)` reduction-style clause, additionally
// collecting symbols and byref flags. If the keyword is present but
// `reductionArgs` is empty, failure is returned without a diagnostic. If the
// keyword is absent, success is returned.
// NOTE(review): the parameter line declaring `entryBlockArgs` is not present
// in this listing.
629 static ParseResult parseBlockArgClause(
630  OpAsmParser &parser,
632  StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
633  if (succeeded(parser.parseOptionalKeyword(keyword))) {
634  if (!reductionArgs)
635  return failure();
636 
637  if (failed(parseClauseWithRegionArgs(
638  parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
639  &reductionArgs->syms, /*mapIndices=*/nullptr,
640  &reductionArgs->byref)))
641  return failure();
642  }
643  return success();
644 }
645 
646 static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
647  AllRegionParseArgs args) {
649 
650  if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
651  args.inReductionArgs)))
652  return parser.emitError(parser.getCurrentLocation())
653  << "invalid `in_reduction` format";
654 
655  if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
656  args.mapArgs)))
657  return parser.emitError(parser.getCurrentLocation())
658  << "invalid `map_entries` format";
659 
660  if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
661  args.privateArgs)))
662  return parser.emitError(parser.getCurrentLocation())
663  << "invalid `private` format";
664 
665  if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
666  args.reductionArgs)))
667  return parser.emitError(parser.getCurrentLocation())
668  << "invalid `reduction` format";
669 
670  if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
671  args.taskReductionArgs)))
672  return parser.emitError(parser.getCurrentLocation())
673  << "invalid `task_reduction` format";
674 
675  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
676  args.useDeviceAddrArgs)))
677  return parser.emitError(parser.getCurrentLocation())
678  << "invalid `use_device_addr` format";
679 
680  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
681  args.useDevicePtrArgs)))
682  return parser.emitError(parser.getCurrentLocation())
683  << "invalid `use_device_addr` format";
684 
685  return parser.parseRegion(region, entryBlockArgs);
686 }
687 
// Region parser accepting the in_reduction, map and private clauses: bundles
// the caller's output storage into AllRegionParseArgs and defers to
// parseBlockArgRegion. `privateMaps` receives the private map indices.
// NOTE(review): the line carrying the function name and the parameter lines
// declaring `inReductionVars`, `mapVars` and `privateVars` are not present in
// this listing.
689  OpAsmParser &parser, Region &region,
691  SmallVectorImpl<Type> &inReductionTypes,
692  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
694  SmallVectorImpl<Type> &mapTypes,
696  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
697  DenseI64ArrayAttr &privateMaps) {
698  AllRegionParseArgs args;
699  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
700  inReductionByref, inReductionSyms);
701  args.mapArgs.emplace(mapVars, mapTypes);
702  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
703  &privateMaps);
704  return parseBlockArgRegion(parser, region, args);
705 }
706 
// Region parser accepting the in_reduction and private clauses: bundles the
// caller's output storage into AllRegionParseArgs and defers to
// parseBlockArgRegion. No map indices are collected for the private clause.
// NOTE(review): the parameter lines declaring `inReductionVars` and
// `privateVars` are not present in this listing.
707 static ParseResult parseInReductionPrivateRegion(
708  OpAsmParser &parser, Region &region,
710  SmallVectorImpl<Type> &inReductionTypes,
711  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
713  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
714  AllRegionParseArgs args;
715  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
716  inReductionByref, inReductionSyms);
717  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
718  return parseBlockArgRegion(parser, region, args);
719 }
720 
// Region parser accepting the in_reduction, private and reduction clauses:
// bundles the caller's output storage into AllRegionParseArgs and defers to
// parseBlockArgRegion.
// NOTE(review): the line carrying the function name and the parameter lines
// declaring `inReductionVars`, `privateVars` and `reductionVars` are not
// present in this listing.
722  OpAsmParser &parser, Region &region,
724  SmallVectorImpl<Type> &inReductionTypes,
725  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
727  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
729  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
730  ArrayAttr &reductionSyms) {
731  AllRegionParseArgs args;
732  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
733  inReductionByref, inReductionSyms);
734  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
735  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
736  reductionSyms);
737  return parseBlockArgRegion(parser, region, args);
738 }
739 
// Region parser accepting only the private clause: bundles the caller's output
// storage into AllRegionParseArgs and defers to parseBlockArgRegion.
// NOTE(review): the parameter line declaring `privateVars` is not present in
// this listing.
740 static ParseResult parsePrivateRegion(
741  OpAsmParser &parser, Region &region,
743  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
744  AllRegionParseArgs args;
745  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
746  return parseBlockArgRegion(parser, region, args);
747 }
748 
// Region parser accepting the private and reduction clauses: bundles the
// caller's output storage into AllRegionParseArgs and defers to
// parseBlockArgRegion.
// NOTE(review): the parameter lines declaring `privateVars` and
// `reductionVars` are not present in this listing.
749 static ParseResult parsePrivateReductionRegion(
750  OpAsmParser &parser, Region &region,
752  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
754  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
755  ArrayAttr &reductionSyms) {
756  AllRegionParseArgs args;
757  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
758  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
759  reductionSyms);
760  return parseBlockArgRegion(parser, region, args);
761 }
762 
// Region parser accepting only the task_reduction clause: bundles the caller's
// output storage into AllRegionParseArgs and defers to parseBlockArgRegion.
// NOTE(review): the parameter line declaring `taskReductionVars` is not
// present in this listing.
763 static ParseResult parseTaskReductionRegion(
764  OpAsmParser &parser, Region &region,
766  SmallVectorImpl<Type> &taskReductionTypes,
767  DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
768  AllRegionParseArgs args;
769  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
770  taskReductionByref, taskReductionSyms);
771  return parseBlockArgRegion(parser, region, args);
772 }
773 
// Region parser accepting the use_device_addr and use_device_ptr clauses:
// bundles the caller's output storage into AllRegionParseArgs and defers to
// parseBlockArgRegion.
// NOTE(review): the line carrying the function name and the parameter lines
// declaring `useDeviceAddrVars` and `useDevicePtrVars` are not present in this
// listing.
775  OpAsmParser &parser, Region &region,
777  SmallVectorImpl<Type> &useDeviceAddrTypes,
779  SmallVectorImpl<Type> &useDevicePtrTypes) {
780  AllRegionParseArgs args;
781  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
782  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
783  return parseBlockArgRegion(parser, region, args);
784 }
785 
786 //===----------------------------------------------------------------------===//
787 // Printers for operations including clauses that define entry block arguments.
788 //===----------------------------------------------------------------------===//
789 
790 namespace {
791 struct MapPrintArgs {
792  ValueRange vars;
793  TypeRange types;
794  MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
795 };
796 struct PrivatePrintArgs {
797  ValueRange vars;
798  TypeRange types;
799  ArrayAttr syms;
800  DenseI64ArrayAttr mapIndices;
801  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
802  DenseI64ArrayAttr mapIndices)
803  : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
804 };
805 struct ReductionPrintArgs {
806  ValueRange vars;
807  TypeRange types;
808  DenseBoolArrayAttr byref;
809  ArrayAttr syms;
810  ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
811  ArrayAttr syms)
812  : vars(vars), types(types), byref(byref), syms(syms) {}
813 };
814 struct AllRegionPrintArgs {
815  std::optional<ReductionPrintArgs> inReductionArgs;
816  std::optional<MapPrintArgs> mapArgs;
817  std::optional<PrivatePrintArgs> privateArgs;
818  std::optional<ReductionPrintArgs> reductionArgs;
819  std::optional<ReductionPrintArgs> taskReductionArgs;
820  std::optional<MapPrintArgs> useDeviceAddrArgs;
821  std::optional<MapPrintArgs> useDevicePtrArgs;
822 };
823 } // namespace
824 
// Prints one clause that defines entry block arguments, in the form
//   name([byref ][sym ]operand -> arg[ [map_idx=N]], ... : type, ...)
// followed by a trailing space. Nothing is printed when `argsSubrange` is
// empty. Missing `symbols`, `mapIndices` or `byref` attributes are replaced by
// placeholder arrays (null symbols, -1 indices, false flags) so that all
// ranges can be zipped together; the placeholders print nothing.
// NOTE(review): the line carrying the function name and its leading parameters
// (including the printer `p` and `ctx`) is not present in this listing.
826  StringRef clauseName,
827  ValueRange argsSubrange,
828  ValueRange operands, TypeRange types,
829  ArrayAttr symbols = nullptr,
830  DenseI64ArrayAttr mapIndices = nullptr,
831  DenseBoolArrayAttr byref = nullptr) {
832  if (argsSubrange.empty())
833  return;
834 
835  p << clauseName << "(";
836 
837  if (!symbols) {
838  llvm::SmallVector<Attribute> values(operands.size(), nullptr);
839  symbols = ArrayAttr::get(ctx, values);
840  }
841 
842  if (!mapIndices) {
843  llvm::SmallVector<int64_t> values(operands.size(), -1);
844  mapIndices = DenseI64ArrayAttr::get(ctx, values);
845  }
846 
847  if (!byref) {
848  mlir::SmallVector<bool> values(operands.size(), false);
849  byref = DenseBoolArrayAttr::get(ctx, values);
850  }
851 
// zip_equal requires all five ranges to have the same length.
852  llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
853  mapIndices.asArrayRef(),
854  byref.asArrayRef()),
855  p, [&p](auto t) {
856  auto [op, arg, sym, map, isByRef] = t;
857  if (isByRef)
858  p << "byref ";
859  if (sym)
860  p << sym << " ";
861 
862  p << op << " -> " << arg;
863 
// -1 means "no map index" and is not printed.
864  if (map != -1)
865  p << " [map_idx=" << map << "]";
866  });
867  p << " : ";
868  llvm::interleaveComma(types, p);
869  p << ") ";
870 }
871 
// Prints a clause that only carries variables and types; does nothing when
// `mapArgs` is empty.
// NOTE(review): the line carrying the function name and its leading parameters
// (including the printer `p` and `ctx`) is not present in this listing.
873  StringRef clauseName, ValueRange argsSubrange,
874  std::optional<MapPrintArgs> mapArgs) {
875  if (mapArgs)
876  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
877  mapArgs->types);
878 }
879 
// Prints a private-style clause including its symbols and map indices; does
// nothing when `privateArgs` is empty.
// NOTE(review): the line carrying the function name and its leading parameters
// (including the printer `p` and `ctx`) is not present in this listing.
881  StringRef clauseName, ValueRange argsSubrange,
882  std::optional<PrivatePrintArgs> privateArgs) {
883  if (privateArgs)
884  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
885  privateArgs->vars, privateArgs->types,
886  privateArgs->syms, privateArgs->mapIndices);
887 }
888 
889 static void
890 printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
891  ValueRange argsSubrange,
892  std::optional<ReductionPrintArgs> reductionArgs) {
893  if (reductionArgs)
894  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
895  reductionArgs->vars, reductionArgs->types,
896  reductionArgs->syms, /*mapIndices=*/nullptr,
897  reductionArgs->byref);
898 }
899 
900 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
901  const AllRegionPrintArgs &args) {
902  auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
903  MLIRContext *ctx = op->getContext();
904 
905  printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
906  args.inReductionArgs);
907  printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
908  args.mapArgs);
909  printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
910  args.privateArgs);
911  printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
912  args.reductionArgs);
913  printBlockArgClause(p, ctx, "task_reduction",
914  iface.getTaskReductionBlockArgs(),
915  args.taskReductionArgs);
916  printBlockArgClause(p, ctx, "use_device_addr",
917  iface.getUseDeviceAddrBlockArgs(),
918  args.useDeviceAddrArgs);
919  printBlockArgClause(p, ctx, "use_device_ptr",
920  iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
921 
922  p.printRegion(region, /*printEntryBlockArgs=*/false);
923 }
924 
// Region printer for the in_reduction, map and private clauses: bundles the
// arguments into AllRegionPrintArgs and defers to printBlockArgRegion.
// NOTE(review): the line carrying the function name and return type is not
// present in this listing.
926  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
927  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
928  ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
929  ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
930  DenseI64ArrayAttr privateMaps) {
931  AllRegionPrintArgs args;
932  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
933  inReductionByref, inReductionSyms);
934  args.mapArgs.emplace(mapVars, mapTypes);
935  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
936  printBlockArgRegion(p, op, region, args);
937 }
938 
940  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
941  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
942  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
943  ArrayAttr privateSyms) {
944  AllRegionPrintArgs args;
945  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
946  inReductionByref, inReductionSyms);
947  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
948  /*mapIndices=*/nullptr);
949  printBlockArgRegion(p, op, region, args);
950 }
951 
953  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
954  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
955  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
956  ArrayAttr privateSyms, ValueRange reductionVars, TypeRange reductionTypes,
957  DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
958  AllRegionPrintArgs args;
959  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
960  inReductionByref, inReductionSyms);
961  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
962  /*mapIndices=*/nullptr);
963  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
964  reductionSyms);
965  printBlockArgRegion(p, op, region, args);
966 }
967 
968 static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
969  ValueRange privateVars, TypeRange privateTypes,
970  ArrayAttr privateSyms) {
971  AllRegionPrintArgs args;
972  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
973  /*mapIndices=*/nullptr);
974  printBlockArgRegion(p, op, region, args);
975 }
976 
978  OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
979  TypeRange privateTypes, ArrayAttr privateSyms, ValueRange reductionVars,
980  TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
981  ArrayAttr reductionSyms) {
982  AllRegionPrintArgs args;
983  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
984  /*mapIndices=*/nullptr);
985  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
986  reductionSyms);
987  printBlockArgRegion(p, op, region, args);
988 }
989 
991  Region &region,
992  ValueRange taskReductionVars,
993  TypeRange taskReductionTypes,
994  DenseBoolArrayAttr taskReductionByref,
995  ArrayAttr taskReductionSyms) {
996  AllRegionPrintArgs args;
997  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
998  taskReductionByref, taskReductionSyms);
999  printBlockArgRegion(p, op, region, args);
1000 }
1001 
1003  Region &region,
1004  ValueRange useDeviceAddrVars,
1005  TypeRange useDeviceAddrTypes,
1006  ValueRange useDevicePtrVars,
1007  TypeRange useDevicePtrTypes) {
1008  AllRegionPrintArgs args;
1009  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1010  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1011  printBlockArgRegion(p, op, region, args);
1012 }
1013 
1014 /// Verifies Reduction Clause
1015 static LogicalResult
1016 verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1017  OperandRange reductionVars,
1018  std::optional<ArrayRef<bool>> reductionByref) {
1019  if (!reductionVars.empty()) {
1020  if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1021  return op->emitOpError()
1022  << "expected as many reduction symbol references "
1023  "as reduction variables";
1024  if (reductionByref && reductionByref->size() != reductionVars.size())
1025  return op->emitError() << "expected as many reduction variable by "
1026  "reference attributes as reduction variables";
1027  } else {
1028  if (reductionSyms)
1029  return op->emitOpError() << "unexpected reduction symbol references";
1030  return success();
1031  }
1032 
1033  // TODO: The followings should be done in
1034  // SymbolUserOpInterface::verifySymbolUses.
1035  DenseSet<Value> accumulators;
1036  for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1037  Value accum = std::get<0>(args);
1038 
1039  if (!accumulators.insert(accum).second)
1040  return op->emitOpError() << "accumulator variable used more than once";
1041 
1042  Type varType = accum.getType();
1043  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1044  auto decl =
1045  SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1046  if (!decl)
1047  return op->emitOpError() << "expected symbol reference " << symbolRef
1048  << " to point to a reduction declaration";
1049 
1050  if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1051  return op->emitOpError()
1052  << "expected accumulator (" << varType
1053  << ") to be the same type as reduction declaration ("
1054  << decl.getAccumulatorType() << ")";
1055  }
1056 
1057  return success();
1058 }
1059 
1060 //===----------------------------------------------------------------------===//
1061 // Parser, printer and verifier for Copyprivate
1062 //===----------------------------------------------------------------------===//
1063 
1064 /// copyprivate-entry-list ::= copyprivate-entry
1065 /// | copyprivate-entry-list `,` copyprivate-entry
1066 /// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1067 static ParseResult parseCopyprivate(
1068  OpAsmParser &parser,
1070  SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1072  if (failed(parser.parseCommaSeparatedList([&]() {
1073  if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1074  parser.parseArrow() ||
1075  parser.parseAttribute(symsVec.emplace_back()) ||
1076  parser.parseColonType(copyprivateTypes.emplace_back()))
1077  return failure();
1078  return success();
1079  })))
1080  return failure();
1081  SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1082  copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1083  return success();
1084 }
1085 
1086 /// Print Copyprivate clause
1088  OperandRange copyprivateVars,
1089  TypeRange copyprivateTypes,
1090  std::optional<ArrayAttr> copyprivateSyms) {
1091  if (!copyprivateSyms.has_value())
1092  return;
1093  llvm::interleaveComma(
1094  llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1095  [&](const auto &args) {
1096  p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1097  << std::get<2>(args);
1098  });
1099 }
1100 
1101 /// Verifies CopyPrivate Clause
1102 static LogicalResult
1104  std::optional<ArrayAttr> copyprivateSyms) {
1105  size_t copyprivateSymsSize =
1106  copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1107  if (copyprivateSymsSize != copyprivateVars.size())
1108  return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1109  << copyprivateVars.size()
1110  << ") and functions (= " << copyprivateSymsSize
1111  << "), both must be equal";
1112  if (!copyprivateSyms.has_value())
1113  return success();
1114 
1115  for (auto copyprivateVarAndSym :
1116  llvm::zip(copyprivateVars, *copyprivateSyms)) {
1117  auto symbolRef =
1118  llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1119  std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1120  funcOp;
1121  if (mlir::func::FuncOp mlirFuncOp =
1122  SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1123  symbolRef))
1124  funcOp = mlirFuncOp;
1125  else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1126  SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1127  op, symbolRef))
1128  funcOp = llvmFuncOp;
1129 
1130  auto getNumArguments = [&] {
1131  return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1132  };
1133 
1134  auto getArgumentType = [&](unsigned i) {
1135  return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1136  *funcOp);
1137  };
1138 
1139  if (!funcOp)
1140  return op->emitOpError() << "expected symbol reference " << symbolRef
1141  << " to point to a copy function";
1142 
1143  if (getNumArguments() != 2)
1144  return op->emitOpError()
1145  << "expected copy function " << symbolRef << " to have 2 operands";
1146 
1147  Type argTy = getArgumentType(0);
1148  if (argTy != getArgumentType(1))
1149  return op->emitOpError() << "expected copy function " << symbolRef
1150  << " arguments to have the same type";
1151 
1152  Type varType = std::get<0>(copyprivateVarAndSym).getType();
1153  if (argTy != varType)
1154  return op->emitOpError()
1155  << "expected copy function arguments' type (" << argTy
1156  << ") to be the same as copyprivate variable's type (" << varType
1157  << ")";
1158  }
1159 
1160  return success();
1161 }
1162 
1163 //===----------------------------------------------------------------------===//
1164 // Parser, printer and verifier for DependVarList
1165 //===----------------------------------------------------------------------===//
1166 
1167 /// depend-entry-list ::= depend-entry
1168 /// | depend-entry-list `,` depend-entry
1169 /// depend-entry ::= depend-kind `->` ssa-id `:` type
1170 static ParseResult
1173  SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
1175  if (failed(parser.parseCommaSeparatedList([&]() {
1176  StringRef keyword;
1177  if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1178  parser.parseOperand(dependVars.emplace_back()) ||
1179  parser.parseColonType(dependTypes.emplace_back()))
1180  return failure();
1181  if (std::optional<ClauseTaskDepend> keywordDepend =
1182  (symbolizeClauseTaskDepend(keyword)))
1183  kindsVec.emplace_back(
1184  ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1185  else
1186  return failure();
1187  return success();
1188  })))
1189  return failure();
1190  SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1191  dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1192  return success();
1193 }
1194 
1195 /// Print Depend clause
1197  OperandRange dependVars, TypeRange dependTypes,
1198  std::optional<ArrayAttr> dependKinds) {
1199 
1200  for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1201  if (i != 0)
1202  p << ", ";
1203  p << stringifyClauseTaskDepend(
1204  llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1205  .getValue())
1206  << " -> " << dependVars[i] << " : " << dependTypes[i];
1207  }
1208 }
1209 
1210 /// Verifies Depend clause
1211 static LogicalResult verifyDependVarList(Operation *op,
1212  std::optional<ArrayAttr> dependKinds,
1213  OperandRange dependVars) {
1214  if (!dependVars.empty()) {
1215  if (!dependKinds || dependKinds->size() != dependVars.size())
1216  return op->emitOpError() << "expected as many depend values"
1217  " as depend variables";
1218  } else {
1219  if (dependKinds && !dependKinds->empty())
1220  return op->emitOpError() << "unexpected depend values";
1221  return success();
1222  }
1223 
1224  return success();
1225 }
1226 
1227 //===----------------------------------------------------------------------===//
1228 // Parser, printer and verifier for Synchronization Hint (2.17.12)
1229 //===----------------------------------------------------------------------===//
1230 
1231 /// Parses a Synchronization Hint clause. The value of hint is an integer
1232 /// which is a combination of different hints from `omp_sync_hint_t`.
1233 ///
1234 /// hint-clause = `hint` `(` hint-value `)`
1235 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1236  IntegerAttr &hintAttr) {
1237  StringRef hintKeyword;
1238  int64_t hint = 0;
1239  if (succeeded(parser.parseOptionalKeyword("none"))) {
1240  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1241  return success();
1242  }
1243  auto parseKeyword = [&]() -> ParseResult {
1244  if (failed(parser.parseKeyword(&hintKeyword)))
1245  return failure();
1246  if (hintKeyword == "uncontended")
1247  hint |= 1;
1248  else if (hintKeyword == "contended")
1249  hint |= 2;
1250  else if (hintKeyword == "nonspeculative")
1251  hint |= 4;
1252  else if (hintKeyword == "speculative")
1253  hint |= 8;
1254  else
1255  return parser.emitError(parser.getCurrentLocation())
1256  << hintKeyword << " is not a valid hint";
1257  return success();
1258  };
1259  if (parser.parseCommaSeparatedList(parseKeyword))
1260  return failure();
1261  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1262  return success();
1263 }
1264 
1265 /// Prints a Synchronization Hint clause
1267  IntegerAttr hintAttr) {
1268  int64_t hint = hintAttr.getInt();
1269 
1270  if (hint == 0) {
1271  p << "none";
1272  return;
1273  }
1274 
1275  // Helper function to get n-th bit from the right end of `value`
1276  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1277 
1278  bool uncontended = bitn(hint, 0);
1279  bool contended = bitn(hint, 1);
1280  bool nonspeculative = bitn(hint, 2);
1281  bool speculative = bitn(hint, 3);
1282 
1283  SmallVector<StringRef> hints;
1284  if (uncontended)
1285  hints.push_back("uncontended");
1286  if (contended)
1287  hints.push_back("contended");
1288  if (nonspeculative)
1289  hints.push_back("nonspeculative");
1290  if (speculative)
1291  hints.push_back("speculative");
1292 
1293  llvm::interleaveComma(hints, p);
1294 }
1295 
1296 /// Verifies a synchronization hint clause
1297 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1298 
1299  // Helper function to get n-th bit from the right end of `value`
1300  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1301 
1302  bool uncontended = bitn(hint, 0);
1303  bool contended = bitn(hint, 1);
1304  bool nonspeculative = bitn(hint, 2);
1305  bool speculative = bitn(hint, 3);
1306 
1307  if (uncontended && contended)
1308  return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1309  "omp_sync_hint_contended cannot be combined";
1310  if (nonspeculative && speculative)
1311  return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1312  "omp_sync_hint_speculative cannot be combined.";
1313  return success();
1314 }
1315 
1316 //===----------------------------------------------------------------------===//
1317 // Parser, printer and verifier for Target
1318 //===----------------------------------------------------------------------===//
1319 
1320 // Helper function to get bitwise AND of `value` and 'flag'
1321 uint64_t mapTypeToBitFlag(uint64_t value,
1322  llvm::omp::OpenMPOffloadMappingFlags flag) {
1323  return value & llvm::to_underlying(flag);
1324 }
1325 
1326 /// Parses a map_entries map type from a string format back into its numeric
1327 /// value.
1328 ///
1329 /// map-clause = `map_clauses ( ( `(` `always, `? `close, `? `present, `? (
1330 /// `to` | `from` | `delete` `)` )+ `)` )
1331 static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
1332  llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1333  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1334 
1335  // This simply verifies the correct keyword is read in, the
1336  // keyword itself is stored inside of the operation
1337  auto parseTypeAndMod = [&]() -> ParseResult {
1338  StringRef mapTypeMod;
1339  if (parser.parseKeyword(&mapTypeMod))
1340  return failure();
1341 
1342  if (mapTypeMod == "always")
1343  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1344 
1345  if (mapTypeMod == "implicit")
1346  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1347 
1348  if (mapTypeMod == "close")
1349  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1350 
1351  if (mapTypeMod == "present")
1352  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1353 
1354  if (mapTypeMod == "to")
1355  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1356 
1357  if (mapTypeMod == "from")
1358  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1359 
1360  if (mapTypeMod == "tofrom")
1361  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1362  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1363 
1364  if (mapTypeMod == "delete")
1365  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1366 
1367  return success();
1368  };
1369 
1370  if (parser.parseCommaSeparatedList(parseTypeAndMod))
1371  return failure();
1372 
1373  mapType = parser.getBuilder().getIntegerAttr(
1374  parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
1375  llvm::to_underlying(mapTypeBits));
1376 
1377  return success();
1378 }
1379 
1380 /// Prints a map_entries map type from its numeric value out into its string
1381 /// format.
1383  IntegerAttr mapType) {
1384  uint64_t mapTypeBits = mapType.getUInt();
1385 
1386  bool emitAllocRelease = true;
1388 
1389  // handling of always, close, present placed at the beginning of the string
1390  // to aid readability
1391  if (mapTypeToBitFlag(mapTypeBits,
1392  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1393  mapTypeStrs.push_back("always");
1394  if (mapTypeToBitFlag(mapTypeBits,
1395  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1396  mapTypeStrs.push_back("implicit");
1397  if (mapTypeToBitFlag(mapTypeBits,
1398  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1399  mapTypeStrs.push_back("close");
1400  if (mapTypeToBitFlag(mapTypeBits,
1401  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1402  mapTypeStrs.push_back("present");
1403 
1404  // special handling of to/from/tofrom/delete and release/alloc, release +
1405  // alloc are the abscense of one of the other flags, whereas tofrom requires
1406  // both the to and from flag to be set.
1407  bool to = mapTypeToBitFlag(mapTypeBits,
1408  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1409  bool from = mapTypeToBitFlag(
1410  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1411  if (to && from) {
1412  emitAllocRelease = false;
1413  mapTypeStrs.push_back("tofrom");
1414  } else if (from) {
1415  emitAllocRelease = false;
1416  mapTypeStrs.push_back("from");
1417  } else if (to) {
1418  emitAllocRelease = false;
1419  mapTypeStrs.push_back("to");
1420  }
1421  if (mapTypeToBitFlag(mapTypeBits,
1422  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1423  emitAllocRelease = false;
1424  mapTypeStrs.push_back("delete");
1425  }
1426  if (emitAllocRelease)
1427  mapTypeStrs.push_back("exit_release_or_enter_alloc");
1428 
1429  for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1430  p << mapTypeStrs[i];
1431  if (i + 1 < mapTypeStrs.size()) {
1432  p << ", ";
1433  }
1434  }
1435 }
1436 
1437 static ParseResult parseMembersIndex(OpAsmParser &parser,
1438  ArrayAttr &membersIdx) {
1439  SmallVector<Attribute> values, memberIdxs;
1440 
1441  auto parseIndices = [&]() -> ParseResult {
1442  int64_t value;
1443  if (parser.parseInteger(value))
1444  return failure();
1445  values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1446  APInt(64, value, /*isSigned=*/false)));
1447  return success();
1448  };
1449 
1450  do {
1451  if (failed(parser.parseLSquare()))
1452  return failure();
1453 
1454  if (parser.parseCommaSeparatedList(parseIndices))
1455  return failure();
1456 
1457  if (failed(parser.parseRSquare()))
1458  return failure();
1459 
1460  memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1461  values.clear();
1462  } while (succeeded(parser.parseOptionalComma()));
1463 
1464  if (!memberIdxs.empty())
1465  membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
1466 
1467  return success();
1468 }
1469 
1470 static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1471  ArrayAttr membersIdx) {
1472  if (!membersIdx)
1473  return;
1474 
1475  llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
1476  p << "[";
1477  auto memberIdx = cast<ArrayAttr>(v);
1478  llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
1479  p << cast<IntegerAttr>(v2).getInt();
1480  });
1481  p << "]";
1482  });
1483 }
1484 
1486  VariableCaptureKindAttr mapCaptureType) {
1487  std::string typeCapStr;
1488  llvm::raw_string_ostream typeCap(typeCapStr);
1489  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1490  typeCap << "ByRef";
1491  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1492  typeCap << "ByCopy";
1493  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1494  typeCap << "VLAType";
1495  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1496  typeCap << "This";
1497  p << typeCapStr;
1498 }
1499 
1500 static ParseResult parseCaptureType(OpAsmParser &parser,
1501  VariableCaptureKindAttr &mapCaptureType) {
1502  StringRef mapCaptureKey;
1503  if (parser.parseKeyword(&mapCaptureKey))
1504  return failure();
1505 
1506  if (mapCaptureKey == "This")
1507  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1508  parser.getContext(), mlir::omp::VariableCaptureKind::This);
1509  if (mapCaptureKey == "ByRef")
1510  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1511  parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1512  if (mapCaptureKey == "ByCopy")
1513  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1514  parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1515  if (mapCaptureKey == "VLAType")
1516  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1517  parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1518 
1519  return success();
1520 }
1521 
1522 static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
1525 
1526  for (auto mapOp : mapVars) {
1527  if (!mapOp.getDefiningOp())
1528  emitError(op->getLoc(), "missing map operation");
1529 
1530  if (auto mapInfoOp =
1531  mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
1532  if (!mapInfoOp.getMapType().has_value())
1533  emitError(op->getLoc(), "missing map type for map operand");
1534 
1535  if (!mapInfoOp.getMapCaptureType().has_value())
1536  emitError(op->getLoc(), "missing map capture type for map operand");
1537 
1538  uint64_t mapTypeBits = mapInfoOp.getMapType().value();
1539 
1540  bool to = mapTypeToBitFlag(
1541  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1542  bool from = mapTypeToBitFlag(
1543  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1544  bool del = mapTypeToBitFlag(
1545  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1546 
1547  bool always = mapTypeToBitFlag(
1548  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1549  bool close = mapTypeToBitFlag(
1550  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1551  bool implicit = mapTypeToBitFlag(
1552  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1553 
1554  if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1555  return emitError(op->getLoc(),
1556  "to, from, tofrom and alloc map types are permitted");
1557 
1558  if (isa<TargetEnterDataOp>(op) && (from || del))
1559  return emitError(op->getLoc(), "to and alloc map types are permitted");
1560 
1561  if (isa<TargetExitDataOp>(op) && to)
1562  return emitError(op->getLoc(),
1563  "from, release and delete map types are permitted");
1564 
1565  if (isa<TargetUpdateOp>(op)) {
1566  if (del) {
1567  return emitError(op->getLoc(),
1568  "at least one of to or from map types must be "
1569  "specified, other map types are not permitted");
1570  }
1571 
1572  if (!to && !from) {
1573  return emitError(op->getLoc(),
1574  "at least one of to or from map types must be "
1575  "specified, other map types are not permitted");
1576  }
1577 
1578  auto updateVar = mapInfoOp.getVarPtr();
1579 
1580  if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1581  (from && updateToVars.contains(updateVar))) {
1582  return emitError(
1583  op->getLoc(),
1584  "either to or from map types can be specified, not both");
1585  }
1586 
1587  if (always || close || implicit) {
1588  return emitError(
1589  op->getLoc(),
1590  "present, mapper and iterator map type modifiers are permitted");
1591  }
1592 
1593  to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1594  }
1595  } else {
1596  emitError(op->getLoc(), "map argument is not a map entry operation");
1597  }
1598  }
1599 
1600  return success();
1601 }
1602 
1603 static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
1604  std::optional<DenseI64ArrayAttr> privateMapIndices =
1605  targetOp.getPrivateMapsAttr();
1606 
1607  // None of the private operands are mapped.
1608  if (!privateMapIndices.has_value() || !privateMapIndices.value())
1609  return success();
1610 
1611  OperandRange privateVars = targetOp.getPrivateVars();
1612 
1613  if (privateMapIndices.value().size() !=
1614  static_cast<int64_t>(privateVars.size()))
1615  return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
1616  "`private_maps` attribute mismatch");
1617 
1618  return success();
1619 }
1620 
1621 //===----------------------------------------------------------------------===//
1622 // TargetDataOp
1623 //===----------------------------------------------------------------------===//
1624 
1625 void TargetDataOp::build(OpBuilder &builder, OperationState &state,
1626  const TargetDataOperands &clauses) {
1627  TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
1628  clauses.mapVars, clauses.useDeviceAddrVars,
1629  clauses.useDevicePtrVars);
1630 }
1631 
1632 LogicalResult TargetDataOp::verify() {
1633  if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
1634  getUseDeviceAddrVars().empty()) {
1635  return ::emitError(this->getLoc(),
1636  "At least one of map, use_device_ptr_vars, or "
1637  "use_device_addr_vars operand must be present");
1638  }
1639  return verifyMapClause(*this, getMapVars());
1640 }
1641 
1642 //===----------------------------------------------------------------------===//
1643 // TargetEnterDataOp
1644 //===----------------------------------------------------------------------===//
1645 
1646 void TargetEnterDataOp::build(
1647  OpBuilder &builder, OperationState &state,
1648  const TargetEnterExitUpdateDataOperands &clauses) {
1649  MLIRContext *ctx = builder.getContext();
1650  TargetEnterDataOp::build(builder, state,
1651  makeArrayAttr(ctx, clauses.dependKinds),
1652  clauses.dependVars, clauses.device, clauses.ifExpr,
1653  clauses.mapVars, clauses.nowait);
1654 }
1655 
1656 LogicalResult TargetEnterDataOp::verify() {
1657  LogicalResult verifyDependVars =
1658  verifyDependVarList(*this, getDependKinds(), getDependVars());
1659  return failed(verifyDependVars) ? verifyDependVars
1660  : verifyMapClause(*this, getMapVars());
1661 }
1662 
1663 //===----------------------------------------------------------------------===//
1664 // TargetExitDataOp
1665 //===----------------------------------------------------------------------===//
1666 
1667 void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
1668  const TargetEnterExitUpdateDataOperands &clauses) {
1669  MLIRContext *ctx = builder.getContext();
1670  TargetExitDataOp::build(builder, state,
1671  makeArrayAttr(ctx, clauses.dependKinds),
1672  clauses.dependVars, clauses.device, clauses.ifExpr,
1673  clauses.mapVars, clauses.nowait);
1674 }
1675 
1676 LogicalResult TargetExitDataOp::verify() {
1677  LogicalResult verifyDependVars =
1678  verifyDependVarList(*this, getDependKinds(), getDependVars());
1679  return failed(verifyDependVars) ? verifyDependVars
1680  : verifyMapClause(*this, getMapVars());
1681 }
1682 
1683 //===----------------------------------------------------------------------===//
1684 // TargetUpdateOp
1685 //===----------------------------------------------------------------------===//
1686 
1687 void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
1688  const TargetEnterExitUpdateDataOperands &clauses) {
1689  MLIRContext *ctx = builder.getContext();
1690  TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
1691  clauses.dependVars, clauses.device, clauses.ifExpr,
1692  clauses.mapVars, clauses.nowait);
1693 }
1694 
1695 LogicalResult TargetUpdateOp::verify() {
1696  LogicalResult verifyDependVars =
1697  verifyDependVarList(*this, getDependKinds(), getDependVars());
1698  return failed(verifyDependVars) ? verifyDependVars
1699  : verifyMapClause(*this, getMapVars());
1700 }
1701 
1702 //===----------------------------------------------------------------------===//
1703 // TargetOp
1704 //===----------------------------------------------------------------------===//
1705 
1706 void TargetOp::build(OpBuilder &builder, OperationState &state,
1707  const TargetOperands &clauses) {
1708  MLIRContext *ctx = builder.getContext();
1709  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
1710  // inReductionByref, inReductionSyms.
1711  TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
1712  clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
1713  clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
1714  clauses.ifExpr, /*in_reduction_vars=*/{},
1715  /*in_reduction_byref=*/nullptr, /*in_reduction_syms=*/nullptr,
1716  clauses.isDevicePtrVars, clauses.mapVars, clauses.nowait,
1717  clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
1718  clauses.threadLimit, /*private_maps=*/nullptr);
1719 }
1720 
1721 LogicalResult TargetOp::verify() {
1722  LogicalResult verifyDependVars =
1723  verifyDependVarList(*this, getDependKinds(), getDependVars());
1724 
1725  if (failed(verifyDependVars))
1726  return verifyDependVars;
1727 
1728  LogicalResult verifyMapVars = verifyMapClause(*this, getMapVars());
1729 
1730  if (failed(verifyMapVars))
1731  return verifyMapVars;
1732 
1733  return verifyPrivateVarsMapping(*this);
1734 }
1735 
1736 //===----------------------------------------------------------------------===//
1737 // ParallelOp
1738 //===----------------------------------------------------------------------===//
1739 
1740 void ParallelOp::build(OpBuilder &builder, OperationState &state,
1741  ArrayRef<NamedAttribute> attributes) {
1742  ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
1743  /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
1744  /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
1745  /*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr,
1746  /*reduction_vars=*/ValueRange(),
1747  /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
1748  state.addAttributes(attributes);
1749 }
1750 
1751 void ParallelOp::build(OpBuilder &builder, OperationState &state,
1752  const ParallelOperands &clauses) {
1753  MLIRContext *ctx = builder.getContext();
1754  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1755  clauses.ifExpr, clauses.numThreads, clauses.privateVars,
1756  makeArrayAttr(ctx, clauses.privateSyms),
1757  clauses.procBindKind, clauses.reductionVars,
1758  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
1759  makeArrayAttr(ctx, clauses.reductionSyms));
1760 }
1761 
1762 template <typename OpType>
1763 static LogicalResult verifyPrivateVarList(OpType &op) {
1764  auto privateVars = op.getPrivateVars();
1765  auto privateSyms = op.getPrivateSymsAttr();
1766 
1767  if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
1768  return success();
1769 
1770  auto numPrivateVars = privateVars.size();
1771  auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
1772 
1773  if (numPrivateVars != numPrivateSyms)
1774  return op.emitError() << "inconsistent number of private variables and "
1775  "privatizer op symbols, private vars: "
1776  << numPrivateVars
1777  << " vs. privatizer op symbols: " << numPrivateSyms;
1778 
1779  for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
1780  Type varType = std::get<0>(privateVarInfo).getType();
1781  SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
1782  PrivateClauseOp privatizerOp =
1783  SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
1784 
1785  if (privatizerOp == nullptr)
1786  return op.emitError() << "failed to lookup privatizer op with symbol: '"
1787  << privateSym << "'";
1788 
1789  Type privatizerType = privatizerOp.getType();
1790 
1791  if (varType != privatizerType)
1792  return op.emitError()
1793  << "type mismatch between a "
1794  << (privatizerOp.getDataSharingType() ==
1795  DataSharingClauseType::Private
1796  ? "private"
1797  : "firstprivate")
1798  << " variable and its privatizer op, var type: " << varType
1799  << " vs. privatizer op type: " << privatizerType;
1800  }
1801 
1802  return success();
1803 }
1804 
1805 LogicalResult ParallelOp::verify() {
1806  if (getAllocateVars().size() != getAllocatorVars().size())
1807  return emitError(
1808  "expected equal sizes for allocate and allocator variables");
1809 
1810  if (failed(verifyPrivateVarList(*this)))
1811  return failure();
1812 
1813  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
1814  getReductionByref());
1815 }
1816 
1817 LogicalResult ParallelOp::verifyRegions() {
1818  auto distributeChildOps = getOps<DistributeOp>();
1819  if (!distributeChildOps.empty()) {
1820  if (!isComposite())
1821  return emitError()
1822  << "'omp.composite' attribute missing from composite operation";
1823 
1824  auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
1825  Operation &distributeOp = **distributeChildOps.begin();
1826  for (Operation &childOp : getOps()) {
1827  if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
1828  continue;
1829 
1830  if (!childOp.hasTrait<OpTrait::IsTerminator>())
1831  return emitError() << "unexpected OpenMP operation inside of composite "
1832  "'omp.parallel'";
1833  }
1834  } else if (isComposite()) {
1835  return emitError()
1836  << "'omp.composite' attribute present in non-composite operation";
1837  }
1838  return success();
1839 }
1840 
1841 //===----------------------------------------------------------------------===//
1842 // TeamsOp
1843 //===----------------------------------------------------------------------===//
1844 
1846  while ((op = op->getParentOp()))
1847  if (isa<OpenMPDialect>(op->getDialect()))
1848  return false;
1849  return true;
1850 }
1851 
1852 void TeamsOp::build(OpBuilder &builder, OperationState &state,
1853  const TeamsOperands &clauses) {
1854  MLIRContext *ctx = builder.getContext();
1855  // TODO Store clauses in op: privateVars, privateSyms.
1856  TeamsOp::build(
1857  builder, state, clauses.allocateVars, clauses.allocatorVars,
1858  clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
1859  /*private_vars=*/{}, /*private_syms=*/nullptr, clauses.reductionVars,
1860  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
1861  makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimit);
1862 }
1863 
1864 LogicalResult TeamsOp::verify() {
1865  // Check parent region
1866  // TODO If nested inside of a target region, also check that it does not
1867  // contain any statements, declarations or directives other than this
1868  // omp.teams construct. The issue is how to support the initialization of
1869  // this operation's own arguments (allow SSA values across omp.target?).
1870  Operation *op = getOperation();
1871  if (!isa<TargetOp>(op->getParentOp()) &&
1873  return emitError("expected to be nested inside of omp.target or not nested "
1874  "in any OpenMP dialect operations");
1875 
1876  // Check for num_teams clause restrictions
1877  if (auto numTeamsLowerBound = getNumTeamsLower()) {
1878  auto numTeamsUpperBound = getNumTeamsUpper();
1879  if (!numTeamsUpperBound)
1880  return emitError("expected num_teams upper bound to be defined if the "
1881  "lower bound is defined");
1882  if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
1883  return emitError(
1884  "expected num_teams upper bound and lower bound to be the same type");
1885  }
1886 
1887  // Check for allocate clause restrictions
1888  if (getAllocateVars().size() != getAllocatorVars().size())
1889  return emitError(
1890  "expected equal sizes for allocate and allocator variables");
1891 
1892  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
1893  getReductionByref());
1894 }
1895 
1896 //===----------------------------------------------------------------------===//
1897 // SectionOp
1898 //===----------------------------------------------------------------------===//
1899 
1900 unsigned SectionOp::numPrivateBlockArgs() {
1901  return getParentOp().numPrivateBlockArgs();
1902 }
1903 
1904 unsigned SectionOp::numReductionBlockArgs() {
1905  return getParentOp().numReductionBlockArgs();
1906 }
1907 
1908 //===----------------------------------------------------------------------===//
1909 // SectionsOp
1910 //===----------------------------------------------------------------------===//
1911 
1912 void SectionsOp::build(OpBuilder &builder, OperationState &state,
1913  const SectionsOperands &clauses) {
1914  MLIRContext *ctx = builder.getContext();
1915  // TODO Store clauses in op: privateVars, privateSyms.
1916  SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1917  clauses.nowait, /*private_vars=*/{},
1918  /*private_syms=*/nullptr, clauses.reductionVars,
1919  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
1920  makeArrayAttr(ctx, clauses.reductionSyms));
1921 }
1922 
1923 LogicalResult SectionsOp::verify() {
1924  if (getAllocateVars().size() != getAllocatorVars().size())
1925  return emitError(
1926  "expected equal sizes for allocate and allocator variables");
1927 
1928  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
1929  getReductionByref());
1930 }
1931 
1932 LogicalResult SectionsOp::verifyRegions() {
1933  for (auto &inst : *getRegion().begin()) {
1934  if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
1935  return emitOpError()
1936  << "expected omp.section op or terminator op inside region";
1937  }
1938  }
1939 
1940  return success();
1941 }
1942 
1943 //===----------------------------------------------------------------------===//
1944 // SingleOp
1945 //===----------------------------------------------------------------------===//
1946 
1947 void SingleOp::build(OpBuilder &builder, OperationState &state,
1948  const SingleOperands &clauses) {
1949  MLIRContext *ctx = builder.getContext();
1950  // TODO Store clauses in op: privateVars, privateSyms.
1951  SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1952  clauses.copyprivateVars,
1953  makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
1954  /*private_vars=*/{}, /*private_syms=*/nullptr);
1955 }
1956 
1957 LogicalResult SingleOp::verify() {
1958  // Check for allocate clause restrictions
1959  if (getAllocateVars().size() != getAllocatorVars().size())
1960  return emitError(
1961  "expected equal sizes for allocate and allocator variables");
1962 
1963  return verifyCopyprivateVarList(*this, getCopyprivateVars(),
1964  getCopyprivateSyms());
1965 }
1966 
1967 //===----------------------------------------------------------------------===//
1968 // WorkshareOp
1969 //===----------------------------------------------------------------------===//
1970 
1971 void WorkshareOp::build(OpBuilder &builder, OperationState &state,
1972  const WorkshareOperands &clauses) {
1973  WorkshareOp::build(builder, state, clauses.nowait);
1974 }
1975 
1976 //===----------------------------------------------------------------------===//
1977 // WorkshareLoopWrapperOp
1978 //===----------------------------------------------------------------------===//
1979 
1980 LogicalResult WorkshareLoopWrapperOp::verify() {
1981  if (!(*this)->getParentOfType<WorkshareOp>())
1982  return emitError() << "must be nested in an omp.workshare";
1983  if (getNestedWrapper())
1984  return emitError() << "cannot be composite";
1985  return success();
1986 }
1987 
1988 //===----------------------------------------------------------------------===//
1989 // LoopWrapperInterface
1990 //===----------------------------------------------------------------------===//
1991 
1992 LogicalResult LoopWrapperInterface::verifyImpl() {
1993  Operation *op = this->getOperation();
1994  if (!op->hasTrait<OpTrait::NoTerminator>() ||
1996  return emitOpError() << "loop wrapper must also have the `NoTerminator` "
1997  "and `SingleBlock` traits";
1998 
1999  if (op->getNumRegions() != 1)
2000  return emitOpError() << "loop wrapper does not contain exactly one region";
2001 
2002  Region &region = op->getRegion(0);
2003  if (range_size(region.getOps()) != 1)
2004  return emitOpError()
2005  << "loop wrapper does not contain exactly one nested op";
2006 
2007  Operation &firstOp = *region.op_begin();
2008  if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2009  return emitOpError() << "op nested in loop wrapper is not another loop "
2010  "wrapper or `omp.loop_nest`";
2011 
2012  return success();
2013 }
2014 
2015 //===----------------------------------------------------------------------===//
2016 // LoopOp
2017 //===----------------------------------------------------------------------===//
2018 
2019 void LoopOp::build(OpBuilder &builder, OperationState &state,
2020  const LoopOperands &clauses) {
2021  MLIRContext *ctx = builder.getContext();
2022 
2023  LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2024  makeArrayAttr(ctx, clauses.privateSyms), clauses.order,
2025  clauses.orderMod, clauses.reductionVars,
2026  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2027  makeArrayAttr(ctx, clauses.reductionSyms));
2028 }
2029 
2030 LogicalResult LoopOp::verify() {
2031  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2032  getReductionByref());
2033 }
2034 
2035 LogicalResult LoopOp::verifyRegions() {
2036  if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2037  getNestedWrapper())
2038  return emitError() << "`omp.loop` expected to be a standalone loop wrapper";
2039 
2040  return success();
2041 }
2042 
2043 //===----------------------------------------------------------------------===//
2044 // WsloopOp
2045 //===----------------------------------------------------------------------===//
2046 
2047 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2048  ArrayRef<NamedAttribute> attributes) {
2049  build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2050  /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
2051  /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
2052  /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
2053  /*reduction_vars=*/ValueRange(), /*reduction_byref=*/nullptr,
2054  /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
2055  /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
2056  /*schedule_simd=*/false);
2057  state.addAttributes(attributes);
2058 }
2059 
2060 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2061  const WsloopOperands &clauses) {
2062  MLIRContext *ctx = builder.getContext();
2063  // TODO: Store clauses in op: allocateVars, allocatorVars, privateVars,
2064  // privateSyms.
2065  WsloopOp::build(
2066  builder, state,
2067  /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
2068  clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
2069  clauses.ordered, clauses.privateVars,
2070  makeArrayAttr(ctx, clauses.privateSyms), clauses.reductionVars,
2071  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2072  makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
2073  clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
2074 }
2075 
2076 LogicalResult WsloopOp::verify() {
2077  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2078  getReductionByref());
2079 }
2080 
2081 LogicalResult WsloopOp::verifyRegions() {
2082  bool isCompositeChildLeaf =
2083  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2084 
2085  if (LoopWrapperInterface nested = getNestedWrapper()) {
2086  if (!isComposite())
2087  return emitError()
2088  << "'omp.composite' attribute missing from composite wrapper";
2089 
2090  // Check for the allowed leaf constructs that may appear in a composite
2091  // construct directly after DO/FOR.
2092  if (!isa<SimdOp>(nested))
2093  return emitError() << "only supported nested wrapper is 'omp.simd'";
2094 
2095  } else if (isComposite() && !isCompositeChildLeaf) {
2096  return emitError()
2097  << "'omp.composite' attribute present in non-composite wrapper";
2098  } else if (!isComposite() && isCompositeChildLeaf) {
2099  return emitError()
2100  << "'omp.composite' attribute missing from composite wrapper";
2101  }
2102 
2103  return success();
2104 }
2105 
2106 //===----------------------------------------------------------------------===//
2107 // Simd construct [2.9.3.1]
2108 //===----------------------------------------------------------------------===//
2109 
2110 void SimdOp::build(OpBuilder &builder, OperationState &state,
2111  const SimdOperands &clauses) {
2112  MLIRContext *ctx = builder.getContext();
2113  // TODO Store clauses in op: linearVars, linearStepVars, privateVars,
2114  // privateSyms.
2115  SimdOp::build(builder, state, clauses.alignedVars,
2116  makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
2117  /*linear_vars=*/{}, /*linear_step_vars=*/{},
2118  clauses.nontemporalVars, clauses.order, clauses.orderMod,
2119  /*private_vars=*/{}, /*private_syms=*/nullptr,
2120  clauses.reductionVars,
2121  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2122  makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
2123  clauses.simdlen);
2124 }
2125 
2126 LogicalResult SimdOp::verify() {
2127  if (getSimdlen().has_value() && getSafelen().has_value() &&
2128  getSimdlen().value() > getSafelen().value())
2129  return emitOpError()
2130  << "simdlen clause and safelen clause are both present, but the "
2131  "simdlen value is not less than or equal to safelen value";
2132 
2133  if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
2134  return failure();
2135 
2136  if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
2137  return failure();
2138 
2139  bool isCompositeChildLeaf =
2140  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2141 
2142  if (!isComposite() && isCompositeChildLeaf)
2143  return emitError()
2144  << "'omp.composite' attribute missing from composite wrapper";
2145 
2146  if (isComposite() && !isCompositeChildLeaf)
2147  return emitError()
2148  << "'omp.composite' attribute present in non-composite wrapper";
2149 
2150  return success();
2151 }
2152 
2153 LogicalResult SimdOp::verifyRegions() {
2154  if (getNestedWrapper())
2155  return emitOpError() << "must wrap an 'omp.loop_nest' directly";
2156 
2157  return success();
2158 }
2159 
2160 //===----------------------------------------------------------------------===//
2161 // Distribute construct [2.9.4.1]
2162 //===----------------------------------------------------------------------===//
2163 
2164 void DistributeOp::build(OpBuilder &builder, OperationState &state,
2165  const DistributeOperands &clauses) {
2166  DistributeOp::build(builder, state, clauses.allocateVars,
2167  clauses.allocatorVars, clauses.distScheduleStatic,
2168  clauses.distScheduleChunkSize, clauses.order,
2169  clauses.orderMod, clauses.privateVars,
2170  makeArrayAttr(builder.getContext(), clauses.privateSyms));
2171 }
2172 
2173 LogicalResult DistributeOp::verify() {
2174  if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2175  return emitOpError() << "chunk size set without "
2176  "dist_schedule_static being present";
2177 
2178  if (getAllocateVars().size() != getAllocatorVars().size())
2179  return emitError(
2180  "expected equal sizes for allocate and allocator variables");
2181 
2182  return success();
2183 }
2184 
2185 LogicalResult DistributeOp::verifyRegions() {
2186  if (LoopWrapperInterface nested = getNestedWrapper()) {
2187  if (!isComposite())
2188  return emitError()
2189  << "'omp.composite' attribute missing from composite wrapper";
2190  // Check for the allowed leaf constructs that may appear in a composite
2191  // construct directly after DISTRIBUTE.
2192  if (isa<WsloopOp>(nested)) {
2193  if (!llvm::dyn_cast_if_present<ParallelOp>((*this)->getParentOp()))
2194  return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
2195  "when 'omp.parallel' is the direct parent";
2196  } else if (!isa<SimdOp>(nested))
2197  return emitError() << "only supported nested wrappers are 'omp.simd' and "
2198  "'omp.wsloop'";
2199  } else if (isComposite()) {
2200  return emitError()
2201  << "'omp.composite' attribute present in non-composite wrapper";
2202  }
2203 
2204  return success();
2205 }
2206 
2207 //===----------------------------------------------------------------------===//
2208 // DeclareReductionOp
2209 //===----------------------------------------------------------------------===//
2210 
2211 LogicalResult DeclareReductionOp::verifyRegions() {
2212  if (!getAllocRegion().empty()) {
2213  for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
2214  if (yieldOp.getResults().size() != 1 ||
2215  yieldOp.getResults().getTypes()[0] != getType())
2216  return emitOpError() << "expects alloc region to yield a value "
2217  "of the reduction type";
2218  }
2219  }
2220 
2221  if (getInitializerRegion().empty())
2222  return emitOpError() << "expects non-empty initializer region";
2223  Block &initializerEntryBlock = getInitializerRegion().front();
2224 
2225  if (initializerEntryBlock.getNumArguments() == 1) {
2226  if (!getAllocRegion().empty())
2227  return emitOpError() << "expects two arguments to the initializer region "
2228  "when an allocation region is used";
2229  } else if (initializerEntryBlock.getNumArguments() == 2) {
2230  if (getAllocRegion().empty())
2231  return emitOpError() << "expects one argument to the initializer region "
2232  "when no allocation region is used";
2233  } else {
2234  return emitOpError()
2235  << "expects one or two arguments to the initializer region";
2236  }
2237 
2238  for (mlir::Value arg : initializerEntryBlock.getArguments())
2239  if (arg.getType() != getType())
2240  return emitOpError() << "expects initializer region argument to match "
2241  "the reduction type";
2242 
2243  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
2244  if (yieldOp.getResults().size() != 1 ||
2245  yieldOp.getResults().getTypes()[0] != getType())
2246  return emitOpError() << "expects initializer region to yield a value "
2247  "of the reduction type";
2248  }
2249 
2250  if (getReductionRegion().empty())
2251  return emitOpError() << "expects non-empty reduction region";
2252  Block &reductionEntryBlock = getReductionRegion().front();
2253  if (reductionEntryBlock.getNumArguments() != 2 ||
2254  reductionEntryBlock.getArgumentTypes()[0] !=
2255  reductionEntryBlock.getArgumentTypes()[1] ||
2256  reductionEntryBlock.getArgumentTypes()[0] != getType())
2257  return emitOpError() << "expects reduction region with two arguments of "
2258  "the reduction type";
2259  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
2260  if (yieldOp.getResults().size() != 1 ||
2261  yieldOp.getResults().getTypes()[0] != getType())
2262  return emitOpError() << "expects reduction region to yield a value "
2263  "of the reduction type";
2264  }
2265 
2266  if (!getAtomicReductionRegion().empty()) {
2267  Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
2268  if (atomicReductionEntryBlock.getNumArguments() != 2 ||
2269  atomicReductionEntryBlock.getArgumentTypes()[0] !=
2270  atomicReductionEntryBlock.getArgumentTypes()[1])
2271  return emitOpError() << "expects atomic reduction region with two "
2272  "arguments of the same type";
2273  auto ptrType = llvm::dyn_cast<PointerLikeType>(
2274  atomicReductionEntryBlock.getArgumentTypes()[0]);
2275  if (!ptrType ||
2276  (ptrType.getElementType() && ptrType.getElementType() != getType()))
2277  return emitOpError() << "expects atomic reduction region arguments to "
2278  "be accumulators containing the reduction type";
2279  }
2280 
2281  if (getCleanupRegion().empty())
2282  return success();
2283  Block &cleanupEntryBlock = getCleanupRegion().front();
2284  if (cleanupEntryBlock.getNumArguments() != 1 ||
2285  cleanupEntryBlock.getArgument(0).getType() != getType())
2286  return emitOpError() << "expects cleanup region with one argument "
2287  "of the reduction type";
2288 
2289  return success();
2290 }
2291 
2292 //===----------------------------------------------------------------------===//
2293 // TaskOp
2294 //===----------------------------------------------------------------------===//
2295 
2296 void TaskOp::build(OpBuilder &builder, OperationState &state,
2297  const TaskOperands &clauses) {
2298  MLIRContext *ctx = builder.getContext();
2299  TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2300  makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2301  clauses.final, clauses.ifExpr, clauses.inReductionVars,
2302  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2303  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2304  clauses.priority, /*private_vars=*/clauses.privateVars,
2305  /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
2306  clauses.untied, clauses.eventHandle);
2307 }
2308 
2309 LogicalResult TaskOp::verify() {
2310  LogicalResult verifyDependVars =
2311  verifyDependVarList(*this, getDependKinds(), getDependVars());
2312  return failed(verifyDependVars)
2313  ? verifyDependVars
2314  : verifyReductionVarList(*this, getInReductionSyms(),
2315  getInReductionVars(),
2316  getInReductionByref());
2317 }
2318 
2319 //===----------------------------------------------------------------------===//
2320 // TaskgroupOp
2321 //===----------------------------------------------------------------------===//
2322 
2323 void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
2324  const TaskgroupOperands &clauses) {
2325  MLIRContext *ctx = builder.getContext();
2326  TaskgroupOp::build(builder, state, clauses.allocateVars,
2327  clauses.allocatorVars, clauses.taskReductionVars,
2328  makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
2329  makeArrayAttr(ctx, clauses.taskReductionSyms));
2330 }
2331 
2332 LogicalResult TaskgroupOp::verify() {
2333  return verifyReductionVarList(*this, getTaskReductionSyms(),
2334  getTaskReductionVars(),
2335  getTaskReductionByref());
2336 }
2337 
2338 //===----------------------------------------------------------------------===//
2339 // TaskloopOp
2340 //===----------------------------------------------------------------------===//
2341 
2342 void TaskloopOp::build(OpBuilder &builder, OperationState &state,
2343  const TaskloopOperands &clauses) {
2344  MLIRContext *ctx = builder.getContext();
2345  // TODO Store clauses in op: privateVars, privateSyms.
2346  TaskloopOp::build(
2347  builder, state, clauses.allocateVars, clauses.allocatorVars,
2348  clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
2349  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2350  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2351  clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
2352  /*private_syms=*/nullptr, clauses.reductionVars,
2353  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2354  makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
2355 }
2356 
2357 SmallVector<Value> TaskloopOp::getAllReductionVars() {
2358  SmallVector<Value> allReductionNvars(getInReductionVars().begin(),
2359  getInReductionVars().end());
2360  allReductionNvars.insert(allReductionNvars.end(), getReductionVars().begin(),
2361  getReductionVars().end());
2362  return allReductionNvars;
2363 }
2364 
2365 LogicalResult TaskloopOp::verify() {
2366  if (getAllocateVars().size() != getAllocatorVars().size())
2367  return emitError(
2368  "expected equal sizes for allocate and allocator variables");
2369  if (failed(verifyReductionVarList(*this, getReductionSyms(),
2370  getReductionVars(), getReductionByref())) ||
2371  failed(verifyReductionVarList(*this, getInReductionSyms(),
2372  getInReductionVars(),
2373  getInReductionByref())))
2374  return failure();
2375 
2376  if (!getReductionVars().empty() && getNogroup())
2377  return emitError("if a reduction clause is present on the taskloop "
2378  "directive, the nogroup clause must not be specified");
2379  for (auto var : getReductionVars()) {
2380  if (llvm::is_contained(getInReductionVars(), var))
2381  return emitError("the same list item cannot appear in both a reduction "
2382  "and an in_reduction clause");
2383  }
2384 
2385  if (getGrainsize() && getNumTasks()) {
2386  return emitError(
2387  "the grainsize clause and num_tasks clause are mutually exclusive and "
2388  "may not appear on the same taskloop directive");
2389  }
2390 
2391  return success();
2392 }
2393 
2394 LogicalResult TaskloopOp::verifyRegions() {
2395  if (LoopWrapperInterface nested = getNestedWrapper()) {
2396  if (!isComposite())
2397  return emitError()
2398  << "'omp.composite' attribute missing from composite wrapper";
2399 
2400  // Check for the allowed leaf constructs that may appear in a composite
2401  // construct directly after TASKLOOP.
2402  if (!isa<SimdOp>(nested))
2403  return emitError() << "only supported nested wrapper is 'omp.simd'";
2404  } else if (isComposite()) {
2405  return emitError()
2406  << "'omp.composite' attribute present in non-composite wrapper";
2407  }
2408 
2409  return success();
2410 }
2411 
2412 //===----------------------------------------------------------------------===//
2413 // LoopNestOp
2414 //===----------------------------------------------------------------------===//
2415 
2416 ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
2417  // Parse an opening `(` followed by induction variables followed by `)`
2420  Type loopVarType;
2421  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
2422  parser.parseColonType(loopVarType) ||
2423  // Parse loop bounds.
2424  parser.parseEqual() ||
2425  parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
2426  parser.parseKeyword("to") ||
2427  parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
2428  return failure();
2429 
2430  for (auto &iv : ivs)
2431  iv.type = loopVarType;
2432 
2433  // Parse "inclusive" flag.
2434  if (succeeded(parser.parseOptionalKeyword("inclusive")))
2435  result.addAttribute("loop_inclusive",
2436  UnitAttr::get(parser.getBuilder().getContext()));
2437 
2438  // Parse step values.
2440  if (parser.parseKeyword("step") ||
2441  parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
2442  return failure();
2443 
2444  // Parse the body.
2445  Region *region = result.addRegion();
2446  if (parser.parseRegion(*region, ivs))
2447  return failure();
2448 
2449  // Resolve operands.
2450  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
2451  parser.resolveOperands(ubs, loopVarType, result.operands) ||
2452  parser.resolveOperands(steps, loopVarType, result.operands))
2453  return failure();
2454 
2455  // Parse the optional attribute list.
2456  return parser.parseOptionalAttrDict(result.attributes);
2457 }
2458 
2460  Region &region = getRegion();
2461  auto args = region.getArguments();
2462  p << " (" << args << ") : " << args[0].getType() << " = ("
2463  << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
2464  if (getLoopInclusive())
2465  p << "inclusive ";
2466  p << "step (" << getLoopSteps() << ") ";
2467  p.printRegion(region, /*printEntryBlockArgs=*/false);
2468 }
2469 
2470 void LoopNestOp::build(OpBuilder &builder, OperationState &state,
2471  const LoopNestOperands &clauses) {
2472  LoopNestOp::build(builder, state, clauses.loopLowerBounds,
2473  clauses.loopUpperBounds, clauses.loopSteps,
2474  clauses.loopInclusive);
2475 }
2476 
2477 LogicalResult LoopNestOp::verify() {
2478  if (getLoopLowerBounds().empty())
2479  return emitOpError() << "must represent at least one loop";
2480 
2481  if (getLoopLowerBounds().size() != getIVs().size())
2482  return emitOpError() << "number of range arguments and IVs do not match";
2483 
2484  for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
2485  if (lb.getType() != iv.getType())
2486  return emitOpError()
2487  << "range argument type does not match corresponding IV type";
2488  }
2489 
2490  if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
2491  return emitOpError() << "expects parent op to be a loop wrapper";
2492 
2493  return success();
2494 }
2495 
2496 void LoopNestOp::gatherWrappers(
2498  Operation *parent = (*this)->getParentOp();
2499  while (auto wrapper =
2500  llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
2501  wrappers.push_back(wrapper);
2502  parent = parent->getParentOp();
2503  }
2504 }
2505 
2506 //===----------------------------------------------------------------------===//
2507 // Critical construct (2.17.1)
2508 //===----------------------------------------------------------------------===//
2509 
2510 void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
2511  const CriticalDeclareOperands &clauses) {
2512  CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
2513 }
2514 
2515 LogicalResult CriticalDeclareOp::verify() {
2516  return verifySynchronizationHint(*this, getHint());
2517 }
2518 
2519 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2520  if (getNameAttr()) {
2521  SymbolRefAttr symbolRef = getNameAttr();
2522  auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
2523  *this, symbolRef);
2524  if (!decl) {
2525  return emitOpError() << "expected symbol reference " << symbolRef
2526  << " to point to a critical declaration";
2527  }
2528  }
2529 
2530  return success();
2531 }
2532 
2533 //===----------------------------------------------------------------------===//
2534 // Ordered construct
2535 //===----------------------------------------------------------------------===//
2536 
2537 static LogicalResult verifyOrderedParent(Operation &op) {
2538  bool hasRegion = op.getNumRegions() > 0;
2539  auto loopOp = op.getParentOfType<LoopNestOp>();
2540  if (!loopOp) {
2541  if (hasRegion)
2542  return success();
2543 
2544  // TODO: Consider if this needs to be the case only for the standalone
2545  // variant of the ordered construct.
2546  return op.emitOpError() << "must be nested inside of a loop";
2547  }
2548 
2549  Operation *wrapper = loopOp->getParentOp();
2550  if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
2551  IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
2552  if (!orderedAttr)
2553  return op.emitOpError() << "the enclosing worksharing-loop region must "
2554  "have an ordered clause";
2555 
2556  if (hasRegion && orderedAttr.getInt() != 0)
2557  return op.emitOpError() << "the enclosing loop's ordered clause must not "
2558  "have a parameter present";
2559 
2560  if (!hasRegion && orderedAttr.getInt() == 0)
2561  return op.emitOpError() << "the enclosing loop's ordered clause must "
2562  "have a parameter present";
2563  } else if (!isa<SimdOp>(wrapper)) {
2564  return op.emitOpError() << "must be nested inside of a worksharing, simd "
2565  "or worksharing simd loop";
2566  }
2567  return success();
2568 }
2569 
2570 void OrderedOp::build(OpBuilder &builder, OperationState &state,
2571  const OrderedOperands &clauses) {
2572  OrderedOp::build(builder, state, clauses.doacrossDependType,
2573  clauses.doacrossNumLoops, clauses.doacrossDependVars);
2574 }
2575 
2576 LogicalResult OrderedOp::verify() {
2577  if (failed(verifyOrderedParent(**this)))
2578  return failure();
2579 
2580  auto wrapper = (*this)->getParentOfType<WsloopOp>();
2581  if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
2582  return emitOpError() << "number of variables in depend clause does not "
2583  << "match number of iteration variables in the "
2584  << "doacross loop";
2585 
2586  return success();
2587 }
2588 
2589 void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
2590  const OrderedRegionOperands &clauses) {
2591  OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
2592 }
2593 
2594 LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
2595 
2596 //===----------------------------------------------------------------------===//
2597 // TaskwaitOp
2598 //===----------------------------------------------------------------------===//
2599 
2600 void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
2601  const TaskwaitOperands &clauses) {
2602  // TODO Store clauses in op: dependKinds, dependVars, nowait.
2603  TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
2604  /*depend_vars=*/{}, /*nowait=*/nullptr);
2605 }
2606 
2607 //===----------------------------------------------------------------------===//
2608 // Verifier for AtomicReadOp
2609 //===----------------------------------------------------------------------===//
2610 
2611 LogicalResult AtomicReadOp::verify() {
2612  if (verifyCommon().failed())
2613  return mlir::failure();
2614 
2615  if (auto mo = getMemoryOrder()) {
2616  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2617  *mo == ClauseMemoryOrderKind::Release) {
2618  return emitError(
2619  "memory-order must not be acq_rel or release for atomic reads");
2620  }
2621  }
2622  return verifySynchronizationHint(*this, getHint());
2623 }
2624 
2625 //===----------------------------------------------------------------------===//
2626 // Verifier for AtomicWriteOp
2627 //===----------------------------------------------------------------------===//
2628 
2629 LogicalResult AtomicWriteOp::verify() {
2630  if (verifyCommon().failed())
2631  return mlir::failure();
2632 
2633  if (auto mo = getMemoryOrder()) {
2634  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2635  *mo == ClauseMemoryOrderKind::Acquire) {
2636  return emitError(
2637  "memory-order must not be acq_rel or acquire for atomic writes");
2638  }
2639  }
2640  return verifySynchronizationHint(*this, getHint());
2641 }
2642 
2643 //===----------------------------------------------------------------------===//
2644 // Verifier for AtomicUpdateOp
2645 //===----------------------------------------------------------------------===//
2646 
2647 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2648  PatternRewriter &rewriter) {
2649  if (op.isNoOp()) {
2650  rewriter.eraseOp(op);
2651  return success();
2652  }
2653  if (Value writeVal = op.getWriteOpVal()) {
2654  rewriter.replaceOpWithNewOp<AtomicWriteOp>(
2655  op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
2656  return success();
2657  }
2658  return failure();
2659 }
2660 
2661 LogicalResult AtomicUpdateOp::verify() {
2662  if (verifyCommon().failed())
2663  return mlir::failure();
2664 
2665  if (auto mo = getMemoryOrder()) {
2666  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2667  *mo == ClauseMemoryOrderKind::Acquire) {
2668  return emitError(
2669  "memory-order must not be acq_rel or acquire for atomic updates");
2670  }
2671  }
2672 
2673  return verifySynchronizationHint(*this, getHint());
2674 }
2675 
2676 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
2677 
2678 //===----------------------------------------------------------------------===//
2679 // Verifier for AtomicCaptureOp
2680 //===----------------------------------------------------------------------===//
2681 
2682 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2683  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2684  return op;
2685  return dyn_cast<AtomicReadOp>(getSecondOp());
2686 }
2687 
2688 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2689  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2690  return op;
2691  return dyn_cast<AtomicWriteOp>(getSecondOp());
2692 }
2693 
2694 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2695  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2696  return op;
2697  return dyn_cast<AtomicUpdateOp>(getSecondOp());
2698 }
2699 
2700 LogicalResult AtomicCaptureOp::verify() {
2701  return verifySynchronizationHint(*this, getHint());
2702 }
2703 
2704 LogicalResult AtomicCaptureOp::verifyRegions() {
2705  if (verifyRegionsCommon().failed())
2706  return mlir::failure();
2707 
2708  if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
2709  return emitOpError(
2710  "operations inside capture region must not have hint clause");
2711 
2712  if (getFirstOp()->getAttr("memory_order") ||
2713  getSecondOp()->getAttr("memory_order"))
2714  return emitOpError(
2715  "operations inside capture region must not have memory_order clause");
2716  return success();
2717 }
2718 
2719 //===----------------------------------------------------------------------===//
2720 // CancelOp
2721 //===----------------------------------------------------------------------===//
2722 
2723 void CancelOp::build(OpBuilder &builder, OperationState &state,
2724  const CancelOperands &clauses) {
2725  CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
2726 }
2727 
2728 LogicalResult CancelOp::verify() {
2729  ClauseCancellationConstructType cct = getCancelDirective();
2730  Operation *parentOp = (*this)->getParentOp();
2731 
2732  if (!parentOp) {
2733  return emitOpError() << "must be used within a region supporting "
2734  "cancel directive";
2735  }
2736 
2737  if ((cct == ClauseCancellationConstructType::Parallel) &&
2738  !isa<ParallelOp>(parentOp)) {
2739  return emitOpError() << "cancel parallel must appear "
2740  << "inside a parallel region";
2741  }
2742  if (cct == ClauseCancellationConstructType::Loop) {
2743  auto loopOp = dyn_cast<LoopNestOp>(parentOp);
2744  auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
2745  loopOp ? loopOp->getParentOp() : nullptr);
2746 
2747  if (!wsloopOp) {
2748  return emitOpError()
2749  << "cancel loop must appear inside a worksharing-loop region";
2750  }
2751  if (wsloopOp.getNowaitAttr()) {
2752  return emitError() << "A worksharing construct that is canceled "
2753  << "must not have a nowait clause";
2754  }
2755  if (wsloopOp.getOrderedAttr()) {
2756  return emitError() << "A worksharing construct that is canceled "
2757  << "must not have an ordered clause";
2758  }
2759 
2760  } else if (cct == ClauseCancellationConstructType::Sections) {
2761  if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
2762  return emitOpError() << "cancel sections must appear "
2763  << "inside a sections region";
2764  }
2765  if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
2766  cast<SectionsOp>(parentOp->getParentOp()).getNowaitAttr()) {
2767  return emitError() << "A sections construct that is canceled "
2768  << "must not have a nowait clause";
2769  }
2770  }
2771  // TODO : Add more when we support taskgroup.
2772  return success();
2773 }
2774 
2775 //===----------------------------------------------------------------------===//
2776 // CancellationPointOp
2777 //===----------------------------------------------------------------------===//
2778 
2779 void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
2780  const CancellationPointOperands &clauses) {
2781  CancellationPointOp::build(builder, state, clauses.cancelDirective);
2782 }
2783 
2784 LogicalResult CancellationPointOp::verify() {
2785  ClauseCancellationConstructType cct = getCancelDirective();
2786  Operation *parentOp = (*this)->getParentOp();
2787 
2788  if (!parentOp) {
2789  return emitOpError() << "must be used within a region supporting "
2790  "cancellation point directive";
2791  }
2792 
2793  if ((cct == ClauseCancellationConstructType::Parallel) &&
2794  !(isa<ParallelOp>(parentOp))) {
2795  return emitOpError() << "cancellation point parallel must appear "
2796  << "inside a parallel region";
2797  }
2798  if ((cct == ClauseCancellationConstructType::Loop) &&
2799  (!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->getParentOp()))) {
2800  return emitOpError() << "cancellation point loop must appear "
2801  << "inside a worksharing-loop region";
2802  }
2803  if ((cct == ClauseCancellationConstructType::Sections) &&
2804  !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
2805  return emitOpError() << "cancellation point sections must appear "
2806  << "inside a sections region";
2807  }
2808  // TODO : Add more when we support taskgroup.
2809  return success();
2810 }
2811 
2812 //===----------------------------------------------------------------------===//
2813 // MapBoundsOp
2814 //===----------------------------------------------------------------------===//
2815 
2816 LogicalResult MapBoundsOp::verify() {
2817  auto extent = getExtent();
2818  auto upperbound = getUpperBound();
2819  if (!extent && !upperbound)
2820  return emitError("expected extent or upperbound.");
2821  return success();
2822 }
2823 
2824 void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2825  TypeRange /*result_types*/, StringAttr symName,
2826  TypeAttr type) {
2827  PrivateClauseOp::build(
2828  odsBuilder, odsState, symName, type,
2830  DataSharingClauseType::Private));
2831 }
2832 
2833 LogicalResult PrivateClauseOp::verifyRegions() {
2834  Type symType = getType();
2835 
2836  auto verifyTerminator = [&](Operation *terminator,
2837  bool yieldsValue) -> LogicalResult {
2838  if (!terminator->getBlock()->getSuccessors().empty())
2839  return success();
2840 
2841  if (!llvm::isa<YieldOp>(terminator))
2842  return mlir::emitError(terminator->getLoc())
2843  << "expected exit block terminator to be an `omp.yield` op.";
2844 
2845  YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
2846  TypeRange yieldedTypes = yieldOp.getResults().getTypes();
2847 
2848  if (!yieldsValue) {
2849  if (yieldedTypes.empty())
2850  return success();
2851 
2852  return mlir::emitError(terminator->getLoc())
2853  << "Did not expect any values to be yielded.";
2854  }
2855 
2856  if (yieldedTypes.size() == 1 && yieldedTypes.front() == symType)
2857  return success();
2858 
2859  auto error = mlir::emitError(yieldOp.getLoc())
2860  << "Invalid yielded value. Expected type: " << symType
2861  << ", got: ";
2862 
2863  if (yieldedTypes.empty())
2864  error << "None";
2865  else
2866  error << yieldedTypes;
2867 
2868  return error;
2869  };
2870 
2871  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
2872  StringRef regionName,
2873  bool yieldsValue) -> LogicalResult {
2874  assert(!region.empty());
2875 
2876  if (region.getNumArguments() != expectedNumArgs)
2877  return mlir::emitError(region.getLoc())
2878  << "`" << regionName << "`: "
2879  << "expected " << expectedNumArgs
2880  << " region arguments, got: " << region.getNumArguments();
2881 
2882  for (Block &block : region) {
2883  // MLIR will verify the absence of the terminator for us.
2884  if (!block.mightHaveTerminator())
2885  continue;
2886 
2887  if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
2888  return failure();
2889  }
2890 
2891  return success();
2892  };
2893 
2894  if (failed(verifyRegion(getAllocRegion(), /*expectedNumArgs=*/1, "alloc",
2895  /*yieldsValue=*/true)))
2896  return failure();
2897 
2898  DataSharingClauseType dsType = getDataSharingType();
2899 
2900  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
2901  return emitError("`private` clauses require only an `alloc` region.");
2902 
2903  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
2904  return emitError(
2905  "`firstprivate` clauses require both `alloc` and `copy` regions.");
2906 
2907  if (dsType == DataSharingClauseType::FirstPrivate &&
2908  failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
2909  /*yieldsValue=*/true)))
2910  return failure();
2911 
2912  if (!getDeallocRegion().empty() &&
2913  failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
2914  /*yieldsValue=*/false)))
2915  return failure();
2916 
2917  return success();
2918 }
2919 
2920 //===----------------------------------------------------------------------===//
2921 // Spec 5.2: Masked construct (10.5)
2922 //===----------------------------------------------------------------------===//
2923 
2924 void MaskedOp::build(OpBuilder &builder, OperationState &state,
2925  const MaskedOperands &clauses) {
2926  MaskedOp::build(builder, state, clauses.filteredThreadId);
2927 }
2928 
2929 #define GET_ATTRDEF_CLASSES
2930 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
2931 
2932 #define GET_OP_CLASSES
2933 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
2934 
2935 #define GET_TYPEDEF_CLASSES
2936 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:722
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1191
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:63
static MLIRContext * getContext(OpFoldResult val)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static ParseResult parseInReductionMapPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, DenseI64ArrayAttr &privateMaps)
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseMembersIndex(OpAsmParser &parser, ArrayAttr &membersIdx)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static void printInReductionMapPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, DenseI64ArrayAttr privateMaps)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
SuccessorRange getSuccessors()
Definition: Block.h:267
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
IntegerType getI64Type()
Definition: Builders.cpp:109
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
MLIRContext * getContext() const
Definition: Builders.h:56
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:216
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
This class indicates that the regions associated with this op don't have terminators.
Definition: OpDefinition.h:760
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
BlockArgListType getArguments()
Definition: Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
Runtime
Potential runtimes for AMD GPU kernels.
Definition: Runtimes.h:15
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses, so we give the structure to be shared by all of them a representative name.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:425
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
Definition: OpDefinition.h:871
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.