MLIR  21.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Attributes.h"
24 
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/BitVector.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/STLForwardCompat.h"
29 #include "llvm/ADT/SmallString.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Frontend/OpenMP/OMPConstants.h"
34 #include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
35 #include <cstddef>
36 #include <iterator>
37 #include <optional>
38 #include <variant>
39 
40 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
41 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
42 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
43 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
44 
45 using namespace mlir;
46 using namespace mlir::omp;
47 
/// Wraps `attrs` in an ArrayAttr created in `context`. An empty `attrs` yields
/// a null ArrayAttr rather than an empty array attribute (presumably so that
/// optional array attributes stay unset when there is nothing to store).
48 static ArrayAttr makeArrayAttr(MLIRContext *context,
50  return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
51 }
52 
/// Builds a DenseBoolArrayAttr holding `boolArray` in `ctx`. An empty
/// `boolArray` yields a null attribute instead of an empty dense array.
53 static DenseBoolArrayAttr
55  return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
56 }
57 
58 namespace {
59 struct MemRefPointerLikeModel
60  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
61  MemRefType> {
62  Type getElementType(Type pointer) const {
63  return llvm::cast<MemRefType>(pointer).getElementType();
64  }
65 };
66 
67 struct LLVMPointerPointerLikeModel
68  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
69  LLVM::LLVMPointerType> {
70  Type getElementType(Type pointer) const { return Type(); }
71 };
72 } // namespace
73 
/// Registers the OpenMP dialect's operations, attributes and types. It also
/// declares that a ConvertToLLVMPatternInterface implementation is promised
/// for the dialect. Finally, it attaches the external interface models the
/// dialect relies on to builtin, LLVM dialect and func dialect types and ops.
74 void OpenMPDialect::initialize() {
75  addOperations<
76 #define GET_OP_LIST
77 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
78  >();
79  addAttributes<
80 #define GET_ATTRDEF_LIST
81 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
82  >();
83  addTypes<
84 #define GET_TYPEDEF_LIST
85 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
86  >();
87 
88  declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
89 
  // Let memrefs and LLVM dialect pointers satisfy the PointerLikeType
  // interface through the external models defined earlier in this file.
90  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
91  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
92  *getContext());
93 
94  // Attach the default offload module interface to the module op so that
95  // offload functionality can be accessed through it.
96  mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
97  *getContext());
98 
99  // Attach default declare target interfaces to operations which can be marked
100  // as declare target (global operations and functions/subroutines in dialects
101  // that Fortran, or other languages that lower to MLIR, translate to).
102  mlir::LLVM::GlobalOp::attachInterface<
104  *getContext());
105  mlir::LLVM::LLVMFuncOp::attachInterface<
107  *getContext());
108  mlir::func::FuncOp::attachInterface<
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // Parser and printer for Allocate Clause
114 //===----------------------------------------------------------------------===//
115 
116 /// Parse an allocate clause with allocators and a list of operands with types.
117 ///
118 /// allocate-operand-list ::= allocate-operand |
119 ///                           allocate-operand `,` allocate-operand-list
120 /// allocate-operand ::= ssa-id-and-type `->` ssa-id-and-type
121 /// ssa-id-and-type ::= ssa-id `:` type
///
/// In each allocate-operand, the value before `->` is the allocator. It is
/// appended to `allocatorVars`/`allocatorTypes`. The value after `->` is the
/// allocated variable, appended to `allocateVars`/`allocateTypes`. Only the
/// comma-separated list is parsed; no surrounding delimiters are consumed.
122 static ParseResult parseAllocateAndAllocator(
123  OpAsmParser &parser,
125  SmallVectorImpl<Type> &allocateTypes,
127  SmallVectorImpl<Type> &allocatorTypes) {
128 
129  return parser.parseCommaSeparatedList([&]() {
131  Type type;
  // Allocator part: `ssa-id : type` before the arrow.
132  if (parser.parseOperand(operand) || parser.parseColonType(type))
133  return failure();
134  allocatorVars.push_back(operand);
135  allocatorTypes.push_back(type);
136  if (parser.parseArrow())
137  return failure();
  // Allocated variable: `ssa-id : type` after the arrow; `operand` and
  // `type` are reused for the second pair.
138  if (parser.parseOperand(operand) || parser.parseColonType(type))
139  return failure();
140 
141  allocateVars.push_back(operand);
142  allocateTypes.push_back(type);
143  return success();
144  });
145 }
146 
147 /// Print the allocate clause as a comma-separated list of
/// `allocator : type -> var : type` entries. This is the format accepted by
/// parseAllocateAndAllocator.
/// NOTE(review): entries are indexed by position in `allocateVars`. The
/// allocator ranges are assumed to be at least as long; this is not checked here.
149  OperandRange allocateVars,
150  TypeRange allocateTypes,
151  OperandRange allocatorVars,
152  TypeRange allocatorTypes) {
153  for (unsigned i = 0; i < allocateVars.size(); ++i) {
  // No separator after the last entry.
154  std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
155  p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
156  p << allocateVars[i] << " : " << allocateTypes[i] << separator;
157  }
158 }
159 
160 //===----------------------------------------------------------------------===//
161 // Parser and printer for a clause attribute (StringEnumAttr)
162 //===----------------------------------------------------------------------===//
163 
164 template <typename ClauseAttr>
165 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
166  using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
167  StringRef enumStr;
168  SMLoc loc = parser.getCurrentLocation();
169  if (parser.parseKeyword(&enumStr))
170  return failure();
171  if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
172  attr = ClauseAttr::get(parser.getContext(), *enumValue);
173  return success();
174  }
175  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
176 }
177 
178 template <typename ClauseAttr>
179 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
180  p << stringifyEnum(attr.getValue());
181 }
182 
183 //===----------------------------------------------------------------------===//
184 // Parser and printer for Linear Clause
185 //===----------------------------------------------------------------------===//
186 
187 /// linear ::= `linear` `(` linear-list `)`
188 /// linear-list ::= linear-val | linear-val `,` linear-list
189 /// linear-val ::= ssa-id `=` ssa-id `:` type
///
/// Only linear-list is parsed here; this function does not consume the
/// `linear` keyword or the parentheses. In each linear-val, the first operand
/// is the linear variable and the second is its step. The single trailing
/// type is recorded for the variable in `linearTypes`.
190 static ParseResult parseLinearClause(
191  OpAsmParser &parser,
193  SmallVectorImpl<Type> &linearTypes,
195  return parser.parseCommaSeparatedList([&]() {
197  Type type;
199  if (parser.parseOperand(var) || parser.parseEqual() ||
200  parser.parseOperand(stepVar) || parser.parseColonType(type))
201  return failure();
202 
203  linearVars.push_back(var);
204  linearTypes.push_back(type);
205  linearStepVars.push_back(stepVar);
206  return success();
207  });
208 }
209 
210 /// Print the linear clause as `var = step : type, ...`.
/// The step is omitted for any variable with no corresponding entry in
/// `linearStepVars`. The printed type is the variable's own type, not the
/// entry in `linearTypes`.
/// NOTE(review): parseLinearClause always requires `= step`, so output with an
/// omitted step would not re-parse. Confirm that steps are always present.
212  ValueRange linearVars, TypeRange linearTypes,
213  ValueRange linearStepVars) {
214  size_t linearVarsSize = linearVars.size();
215  for (unsigned i = 0; i < linearVarsSize; ++i) {
216  std::string separator = i == linearVarsSize - 1 ? "" : ", ";
217  p << linearVars[i];
218  if (linearStepVars.size() > i)
219  p << " = " << linearStepVars[i];
220  p << " : " << linearVars[i].getType() << separator;
221  }
222 }
223 
224 //===----------------------------------------------------------------------===//
225 // Verifier for Nontemporal Clause
226 //===----------------------------------------------------------------------===//
227 
228 static LogicalResult verifyNontemporalClause(Operation *op,
229  OperandRange nontemporalVars) {
230 
231  // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
232  DenseSet<Value> nontemporalItems;
233  for (const auto &it : nontemporalVars)
234  if (!nontemporalItems.insert(it).second)
235  return op->emitOpError() << "nontemporal variable used more than once";
236 
237  return success();
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // Parser, verifier and printer for Aligned Clause
242 //===----------------------------------------------------------------------===//
243 static LogicalResult verifyAlignedClause(Operation *op,
244  std::optional<ArrayAttr> alignments,
245  OperandRange alignedVars) {
246  // Check if number of alignment values equals to number of aligned variables
247  if (!alignedVars.empty()) {
248  if (!alignments || alignments->size() != alignedVars.size())
249  return op->emitOpError()
250  << "expected as many alignment values as aligned variables";
251  } else {
252  if (alignments)
253  return op->emitOpError() << "unexpected alignment values attribute";
254  return success();
255  }
256 
257  // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
258  DenseSet<Value> alignedItems;
259  for (auto it : alignedVars)
260  if (!alignedItems.insert(it).second)
261  return op->emitOpError() << "aligned variable used more than once";
262 
263  if (!alignments)
264  return success();
265 
266  // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
267  for (unsigned i = 0; i < (*alignments).size(); ++i) {
268  if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
269  if (intAttr.getValue().sle(0))
270  return op->emitOpError() << "alignment should be greater than 0";
271  } else {
272  return op->emitOpError() << "expected integer alignment";
273  }
274  }
275 
276  return success();
277 }
278 
279 /// aligned ::= `aligned` `(` aligned-list `)`
280 /// aligned-list ::= aligned-val | aligned-val `,` aligned-list
281 /// aligned-val ::= ssa-id-and-type `->` alignment
///
/// Only aligned-list is parsed here; this function does not consume the
/// `aligned` keyword or the parentheses. Each alignment is parsed as an
/// arbitrary attribute and collected into `alignmentsAttr`. Integer-ness and
/// positivity are presumably enforced afterwards by verifyAlignedClause.
282 static ParseResult
285  SmallVectorImpl<Type> &alignedTypes,
286  ArrayAttr &alignmentsAttr) {
287  SmallVector<Attribute> alignmentVec;
  // Slots are emplaced before each sub-parse, so on failure the output
  // vectors may be left holding a partially parsed entry.
288  if (failed(parser.parseCommaSeparatedList([&]() {
289  if (parser.parseOperand(alignedVars.emplace_back()) ||
290  parser.parseColonType(alignedTypes.emplace_back()) ||
291  parser.parseArrow() ||
292  parser.parseAttribute(alignmentVec.emplace_back())) {
293  return failure();
294  }
295  return success();
296  })))
297  return failure();
  // NOTE(review): this copy is redundant; `alignmentVec` could be passed to
  // ArrayAttr::get directly.
298  SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
299  alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
300  return success();
301 }
302 
303 /// Print the aligned clause as `var : type -> alignment, ...`.
/// The printed type is each variable's own type, not the entry in
/// `alignedTypes`.
/// NOTE(review): `alignments` is dereferenced whenever there is at least one
/// aligned variable. This presumably relies on verifyAlignedClause having
/// guaranteed that it is present with one entry per variable.
305  ValueRange alignedVars, TypeRange alignedTypes,
306  std::optional<ArrayAttr> alignments) {
307  for (unsigned i = 0; i < alignedVars.size(); ++i) {
308  if (i != 0)
309  p << ", ";
310  p << alignedVars[i] << " : " << alignedVars[i].getType();
311  p << " -> " << (*alignments)[i];
312  }
313 }
314 
315 //===----------------------------------------------------------------------===//
316 // Parser, printer and verifier for Schedule Clause
317 //===----------------------------------------------------------------------===//
318 
/// Checks and normalizes the schedule modifiers collected by
/// parseScheduleClause. Despite its name, this also rewrites `modifiers`:
///  - more than two modifiers, or any string that is not a ScheduleModifier,
///    is an error;
///  - a single `simd` modifier is rewritten to the pair `none, simd`;
///  - with two modifiers, the first must not be `simd` and the second must be.
/// Diagnostics are reported at the parser's name location.
319 static ParseResult
321  SmallVectorImpl<SmallString<12>> &modifiers) {
322  if (modifiers.size() > 2)
323  return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
324  for (const auto &mod : modifiers) {
325  // Translate the string. If it has no value, then it was not a valid
326  // modifier!
327  auto symbol = symbolizeScheduleModifier(mod);
328  if (!symbol)
329  return parser.emitError(parser.getNameLoc())
330  << " unknown modifier type: " << mod;
331  }
332 
333  // If we have one modifier that is "simd", then stick a "none" modifier in
334  // index 0.
335  if (modifiers.size() == 1) {
336  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
337  modifiers.push_back(modifiers[0]);
338  modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
339  }
340  } else if (modifiers.size() == 2) {
341  // If there are two modifiers:
342  // First modifier should not be simd, second one should be simd
343  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
344  symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
345  return parser.emitError(parser.getNameLoc())
346  << " incorrect modifier order";
347  }
348  return success();
349 }
350 
351 /// schedule ::= `schedule` `(` sched-list `)`
352 /// sched-list ::= sched-val |
353 /// sched-val `,` sched-modifier
354 /// sched-val ::= sched-with-chunk | sched-wo-chunk
355 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
356 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
357 /// sched-wo-chunk ::= `auto` | `runtime`
358 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
359 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
///
/// Only sched-list is parsed here; this function does not consume the
/// `schedule` keyword or the parentheses. `chunkSize` is set only when a
/// chunk is written after a kind that accepts one. `scheduleSimd` is set
/// whenever `simd` appears among the modifiers. In that case `scheduleMod`
/// holds the other modifier, or `none` if `simd` was written alone.
360 static ParseResult
361 parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
362  ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
363  std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
364  Type &chunkType) {
365  StringRef keyword;
366  if (parser.parseKeyword(&keyword))
367  return failure();
368  std::optional<mlir::omp::ClauseScheduleKind> schedule =
369  symbolizeClauseScheduleKind(keyword);
370  if (!schedule)
371  return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
372 
373  scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
  // Only static, dynamic and guided schedules accept an optional
  // `= chunk : type`.
374  switch (*schedule) {
375  case ClauseScheduleKind::Static:
376  case ClauseScheduleKind::Dynamic:
377  case ClauseScheduleKind::Guided:
378  if (succeeded(parser.parseOptionalEqual())) {
379  chunkSize = OpAsmParser::UnresolvedOperand{};
380  if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
381  return failure();
382  } else {
383  chunkSize = std::nullopt;
384  }
385  break;
386  case ClauseScheduleKind::Auto:
388  chunkSize = std::nullopt;
389  }
390 
391  // If there is a comma, we have one or more modifiers.
392  SmallVector<SmallString<12>> modifiers;
393  while (succeeded(parser.parseOptionalComma())) {
394  StringRef mod;
395  if (parser.parseKeyword(&mod))
396  return failure();
397  modifiers.push_back(mod);
398  }
399 
400  if (verifyScheduleModifiers(parser, modifiers))
401  return failure();
402 
  // verifyScheduleModifiers has normalized `modifiers`: modifiers[0] is a
  // valid non-simd modifier and modifiers[1], if present, is `simd`.
  // NOTE(review): given that, the error branch below appears unreachable.
403  if (!modifiers.empty()) {
404  SMLoc loc = parser.getCurrentLocation();
405  if (std::optional<ScheduleModifier> mod =
406  symbolizeScheduleModifier(modifiers[0])) {
407  scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
408  } else {
409  return parser.emitError(loc, "invalid schedule modifier");
410  }
411  // Only SIMD attribute is allowed here!
412  if (modifiers.size() > 1) {
413  assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
414  scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
415  }
416  }
417 
418  return success();
419 }
420 
421 /// Print the schedule clause as `kind[ = chunk : type][, modifier][, simd]`.
/// The chunk type is taken from `scheduleChunk` itself; `scheduleChunkType`
/// is not used. A `none` modifier that parseScheduleClause synthesized for a
/// lone `simd` is printed back explicitly as `none, simd`.
423  ClauseScheduleKindAttr scheduleKind,
424  ScheduleModifierAttr scheduleMod,
425  UnitAttr scheduleSimd, Value scheduleChunk,
426  Type scheduleChunkType) {
427  p << stringifyClauseScheduleKind(scheduleKind.getValue());
428  if (scheduleChunk)
429  p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
430  if (scheduleMod)
431  p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
432  if (scheduleSimd)
433  p << ", simd";
434 }
435 
436 //===----------------------------------------------------------------------===//
437 // Parser and printer for Order Clause
438 //===----------------------------------------------------------------------===//
439 
440 // order ::= `order` `(` [order-modifier ':'] concurrent `)`
441 // order-modifier ::= reproducible | unconstrained
442 static ParseResult parseOrderClause(OpAsmParser &parser,
443  ClauseOrderKindAttr &order,
444  OrderModifierAttr &orderMod) {
445  StringRef enumStr;
446  SMLoc loc = parser.getCurrentLocation();
447  if (parser.parseKeyword(&enumStr))
448  return failure();
449  if (std::optional<OrderModifier> enumValue =
450  symbolizeOrderModifier(enumStr)) {
451  orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
452  if (parser.parseOptionalColon())
453  return failure();
454  loc = parser.getCurrentLocation();
455  if (parser.parseKeyword(&enumStr))
456  return failure();
457  }
458  if (std::optional<ClauseOrderKind> enumValue =
459  symbolizeClauseOrderKind(enumStr)) {
460  order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
461  return success();
462  }
463  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
464 }
465 
467  ClauseOrderKindAttr order,
468  OrderModifierAttr orderMod) {
  // Prints `[modifier:]kind`, the inverse of parseOrderClause. Each part is
  // skipped when its attribute is null.
469  if (orderMod)
470  p << stringifyOrderModifier(orderMod.getValue()) << ":";
471  if (order)
472  p << stringifyClauseOrderKind(order.getValue());
473 }
474 
475 //===----------------------------------------------------------------------===//
476 // Parsers for operations including clauses that define entry block arguments.
477 //===----------------------------------------------------------------------===//
478 
479 namespace {
/// Parser outputs for a clause made only of `operand -> block-arg` entries
/// plus their types. Used for host_eval, map_entries, use_device_addr and
/// use_device_ptr.
480 struct MapParseArgs {
482  SmallVectorImpl<Type> &types;
484  SmallVectorImpl<Type> &types)
485  : vars(vars), types(types) {}
486 };
/// Parser outputs for a private clause: operands, types and per-entry
/// symbols. When `mapIndices` is non-null, optional `[map_idx=N]` suffixes
/// are also accepted.
487 struct PrivateParseArgs {
490  ArrayAttr &syms;
491  DenseI64ArrayAttr *mapIndices;
493  SmallVectorImpl<Type> &types, ArrayAttr &syms,
494  DenseI64ArrayAttr *mapIndices = nullptr)
495  : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
496 };
497 
/// Parser outputs for a reduction-like clause (in_reduction, reduction,
/// task_reduction): operands, types, per-entry `byref` flags and symbols.
/// When `modifier` is non-null, a leading `mod: <kind>,` is also accepted.
498 struct ReductionParseArgs {
500  SmallVectorImpl<Type> &types;
501  DenseBoolArrayAttr &byref;
502  ArrayAttr &syms;
503  ReductionModifierAttr *modifier;
504  ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
506  ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
507  : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
508 };
509 
/// The set of block-argument clauses an operation accepts. Each engaged
/// member enables the corresponding clause keyword in parseBlockArgRegion.
/// A keyword whose member is disengaged is rejected.
510 struct AllRegionParseArgs {
511  std::optional<MapParseArgs> hostEvalArgs;
512  std::optional<ReductionParseArgs> inReductionArgs;
513  std::optional<MapParseArgs> mapArgs;
514  std::optional<PrivateParseArgs> privateArgs;
515  std::optional<ReductionParseArgs> reductionArgs;
516  std::optional<ReductionParseArgs> taskReductionArgs;
517  std::optional<MapParseArgs> useDeviceAddrArgs;
518  std::optional<MapParseArgs> useDevicePtrArgs;
519 };
520 } // namespace
521 
/// Parses one block-argument clause body:
///
///   `(` [`mod` `:` modifier `,`] entry (`,` entry)* `:` type (`,` type)* `)`
///   entry ::= [`byref`] [symbol] operand `->` block-arg
///             [`[` `map_idx` `=` integer `]`]
///
/// Operands are appended to `operands` and types to `types`. The new block
/// arguments are appended to `regionPrivateArgs`, and their types are filled
/// in from `types`. Each optional out-parameter enables its syntax only when
/// non-null:
///  - `symbols`: a per-entry symbol;
///  - `mapIndices`: `[map_idx=N]`, recorded as -1 when absent;
///  - `byref`: the `byref` keyword;
///  - `modifier`: the leading `mod:`.
522 static ParseResult parseClauseWithRegionArgs(
523  OpAsmParser &parser,
525  SmallVectorImpl<Type> &types,
526  SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
527  ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
528  DenseBoolArrayAttr *byref = nullptr,
529  ReductionModifierAttr *modifier = nullptr) {
530  SmallVector<SymbolRefAttr> symbolVec;
531  SmallVector<int64_t> mapIndicesVec;
532  SmallVector<bool> isByRefVec;
  // Index of the first block argument this clause will add.
533  unsigned regionArgOffset = regionPrivateArgs.size();
534 
535  if (parser.parseLParen())
536  return failure();
537 
538  if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
539  StringRef enumStr;
540  if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
541  parser.parseComma())
542  return failure();
543  std::optional<ReductionModifier> enumValue =
544  symbolizeReductionModifier(enumStr);
  // NOTE(review): an unknown modifier keyword fails without emitting a
  // diagnostic.
545  if (!enumValue.has_value())
546  return failure();
547  *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
  // NOTE(review): ReductionModifierAttr::get presumably never returns null,
  // so this check looks redundant.
548  if (!*modifier)
549  return failure();
550  }
551 
552  if (parser.parseCommaSeparatedList([&]() {
553  if (byref)
554  isByRefVec.push_back(
555  parser.parseOptionalKeyword("byref").succeeded());
556 
557  if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
558  return failure();
559 
560  if (parser.parseOperand(operands.emplace_back()) ||
561  parser.parseArrow() ||
562  parser.parseArgument(regionPrivateArgs.emplace_back()))
563  return failure();
564 
565  if (mapIndices) {
566  if (parser.parseOptionalLSquare().succeeded()) {
567  if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
568  parser.parseInteger(mapIndicesVec.emplace_back()) ||
569  parser.parseRSquare())
570  return failure();
571  } else
572  mapIndicesVec.push_back(-1);
573  }
574 
575  return success();
576  }))
577  return failure();
578 
579  if (parser.parseColon())
580  return failure();
581 
582  if (parser.parseCommaSeparatedList([&]() {
583  if (parser.parseType(types.emplace_back()))
584  return failure();
585 
586  return success();
587  }))
588  return failure();
589 
  // NOTE(review): a count mismatch also fails without a diagnostic. This
  // compares whole vectors, so it assumes `operands` and `types` were empty on
  // entry (each caller passes per-clause vectors).
590  if (operands.size() != types.size())
591  return failure();
592 
593  if (parser.parseRParen())
594  return failure();
595 
  // Give the block arguments added by this clause the types parsed above.
  // zip yields tuples of references, so `prv` aliases the stored argument.
596  auto *argsBegin = regionPrivateArgs.begin();
597  MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
598  argsBegin + regionArgOffset + types.size());
599  for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
600  prv.type = type;
601  }
602 
603  if (symbols) {
604  SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
605  *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
606  }
607 
  // mapIndicesVec is only populated when `mapIndices` is non-null.
608  if (!mapIndicesVec.empty())
609  *mapIndices =
610  mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
611 
612  if (byref)
613  *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
614 
615  return success();
616 }
617 
/// If the next token is `keyword`, parses a clause of `operand -> block-arg`
/// entries into `mapArgs` and appends the new block arguments to
/// `entryBlockArgs`. Fails without emitting a diagnostic when the keyword is
/// present but `mapArgs` is disengaged, i.e. the operation does not accept
/// this clause. Succeeds without consuming input when the keyword is absent.
618 static ParseResult parseBlockArgClause(
619  OpAsmParser &parser,
621  StringRef keyword, std::optional<MapParseArgs> mapArgs) {
622  if (succeeded(parser.parseOptionalKeyword(keyword))) {
623  if (!mapArgs)
624  return failure();
625 
626  if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
627  entryBlockArgs)))
628  return failure();
629  }
630  return success();
631 }
632 
/// If the next token is `keyword`, parses a private-style clause into
/// `privateArgs` and appends the new block arguments to `entryBlockArgs`.
/// Entries carry symbols, plus `[map_idx=N]` suffixes when
/// `privateArgs->mapIndices` is non-null. Fails without emitting a diagnostic
/// when the keyword is present but `privateArgs` is disengaged. Succeeds
/// without consuming input when the keyword is absent.
633 static ParseResult parseBlockArgClause(
634  OpAsmParser &parser,
636  StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
637  if (succeeded(parser.parseOptionalKeyword(keyword))) {
638  if (!privateArgs)
639  return failure();
640 
641  if (failed(parseClauseWithRegionArgs(
642  parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
643  &privateArgs->syms, privateArgs->mapIndices)))
644  return failure();
645  }
646  return success();
647 }
648 
/// If the next token is `keyword`, parses a reduction-like clause into
/// `reductionArgs` and appends the new block arguments to `entryBlockArgs`.
/// Entries carry optional `byref` flags and symbols. A leading `mod:` is
/// accepted when `reductionArgs->modifier` is non-null; map indices are never
/// accepted. Fails without emitting a diagnostic when the keyword is present
/// but `reductionArgs` is disengaged.
649 static ParseResult parseBlockArgClause(
650  OpAsmParser &parser,
652  StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
653  if (succeeded(parser.parseOptionalKeyword(keyword))) {
654  if (!reductionArgs)
655  return failure();
656  if (failed(parseClauseWithRegionArgs(
657  parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
658  &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
659  reductionArgs->modifier)))
660  return failure();
661  }
662  return success();
663 }
664 
665 static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
666  AllRegionParseArgs args) {
668 
669  if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
670  args.hostEvalArgs)))
671  return parser.emitError(parser.getCurrentLocation())
672  << "invalid `host_eval` format";
673 
674  if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
675  args.inReductionArgs)))
676  return parser.emitError(parser.getCurrentLocation())
677  << "invalid `in_reduction` format";
678 
679  if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
680  args.mapArgs)))
681  return parser.emitError(parser.getCurrentLocation())
682  << "invalid `map_entries` format";
683 
684  if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
685  args.privateArgs)))
686  return parser.emitError(parser.getCurrentLocation())
687  << "invalid `private` format";
688 
689  if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
690  args.reductionArgs)))
691  return parser.emitError(parser.getCurrentLocation())
692  << "invalid `reduction` format";
693 
694  if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
695  args.taskReductionArgs)))
696  return parser.emitError(parser.getCurrentLocation())
697  << "invalid `task_reduction` format";
698 
699  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
700  args.useDeviceAddrArgs)))
701  return parser.emitError(parser.getCurrentLocation())
702  << "invalid `use_device_addr` format";
703 
704  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
705  args.useDevicePtrArgs)))
706  return parser.emitError(parser.getCurrentLocation())
707  << "invalid `use_device_addr` format";
708 
709  return parser.parseRegion(region, entryBlockArgs);
710 }
711 
713  OpAsmParser &parser, Region &region,
715  SmallVectorImpl<Type> &hostEvalTypes,
717  SmallVectorImpl<Type> &inReductionTypes,
718  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
720  SmallVectorImpl<Type> &mapTypes,
722  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
723  DenseI64ArrayAttr &privateMaps) {
  // Accept host_eval, in_reduction, map_entries and private. For private,
  // `[map_idx=N]` suffixes are also accepted and written to `privateMaps`.
  // Every other block-argument clause is rejected by parseBlockArgRegion.
724  AllRegionParseArgs args;
725  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
726  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
727  inReductionByref, inReductionSyms);
728  args.mapArgs.emplace(mapVars, mapTypes);
729  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
730  &privateMaps);
731  return parseBlockArgRegion(parser, region, args);
732 }
733 
/// Parses the in_reduction and private clauses followed by the region. Other
/// block-argument clauses are rejected by parseBlockArgRegion. The private
/// clause here does not accept `[map_idx=N]` suffixes.
734 static ParseResult parseInReductionPrivateRegion(
735  OpAsmParser &parser, Region &region,
737  SmallVectorImpl<Type> &inReductionTypes,
738  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
740  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
741  AllRegionParseArgs args;
742  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
743  inReductionByref, inReductionSyms);
744  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
745  return parseBlockArgRegion(parser, region, args);
746 }
747 
749  OpAsmParser &parser, Region &region,
751  SmallVectorImpl<Type> &inReductionTypes,
752  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
754  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
755  ReductionModifierAttr &reductionMod,
757  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
758  ArrayAttr &reductionSyms) {
  // Accept in_reduction, private and reduction. Only the reduction clause
  // accepts a leading `mod:`, which is written to `reductionMod`.
759  AllRegionParseArgs args;
760  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
761  inReductionByref, inReductionSyms);
762  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
763  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
764  reductionSyms, &reductionMod);
765  return parseBlockArgRegion(parser, region, args);
766 }
767 
/// Parses an optional private clause (without `[map_idx=N]` suffixes)
/// followed by the region. Any other block-argument clause is rejected by
/// parseBlockArgRegion.
768 static ParseResult parsePrivateRegion(
769  OpAsmParser &parser, Region &region,
771  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
772  AllRegionParseArgs args;
773  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
774  return parseBlockArgRegion(parser, region, args);
775 }
776 
/// Parses the private and reduction clauses followed by the region. The
/// reduction clause may carry a leading `mod:`, which is written to
/// `reductionMod`.
777 static ParseResult parsePrivateReductionRegion(
778  OpAsmParser &parser, Region &region,
780  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
781  ReductionModifierAttr &reductionMod,
783  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
784  ArrayAttr &reductionSyms) {
785  AllRegionParseArgs args;
786  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
787  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
788  reductionSyms, &reductionMod);
789  return parseBlockArgRegion(parser, region, args);
790 }
791 
/// Parses an optional task_reduction clause (no `mod:` accepted) followed by
/// the region.
792 static ParseResult parseTaskReductionRegion(
793  OpAsmParser &parser, Region &region,
795  SmallVectorImpl<Type> &taskReductionTypes,
796  DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
797  AllRegionParseArgs args;
798  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
799  taskReductionByref, taskReductionSyms);
800  return parseBlockArgRegion(parser, region, args);
801 }
802 
804  OpAsmParser &parser, Region &region,
806  SmallVectorImpl<Type> &useDeviceAddrTypes,
808  SmallVectorImpl<Type> &useDevicePtrTypes) {
  // Accept only the use_device_addr and use_device_ptr clauses, in that
  // order, before the region.
809  AllRegionParseArgs args;
810  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
811  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
812  return parseBlockArgRegion(parser, region, args);
813 }
814 
815 //===----------------------------------------------------------------------===//
816 // Printers for operations including clauses that define entry block arguments.
817 //===----------------------------------------------------------------------===//
818 
819 namespace {
820 struct MapPrintArgs {
821  ValueRange vars;
822  TypeRange types;
823  MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
824 };
825 struct PrivatePrintArgs {
826  ValueRange vars;
827  TypeRange types;
828  ArrayAttr syms;
829  DenseI64ArrayAttr mapIndices;
830  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
831  DenseI64ArrayAttr mapIndices)
832  : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
833 };
834 struct ReductionPrintArgs {
835  ValueRange vars;
836  TypeRange types;
837  DenseBoolArrayAttr byref;
838  ArrayAttr syms;
839  ReductionModifierAttr modifier;
840  ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
841  ArrayAttr syms, ReductionModifierAttr mod = nullptr)
842  : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
843 };
844 struct AllRegionPrintArgs {
845  std::optional<MapPrintArgs> hostEvalArgs;
846  std::optional<ReductionPrintArgs> inReductionArgs;
847  std::optional<MapPrintArgs> mapArgs;
848  std::optional<PrivatePrintArgs> privateArgs;
849  std::optional<ReductionPrintArgs> reductionArgs;
850  std::optional<ReductionPrintArgs> taskReductionArgs;
851  std::optional<MapPrintArgs> useDeviceAddrArgs;
852  std::optional<MapPrintArgs> useDevicePtrArgs;
853 };
854 } // namespace
855 
857  OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
858  ValueRange argsSubrange, ValueRange operands, TypeRange types,
859  ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
860  DenseBoolArrayAttr byref = nullptr,
861  ReductionModifierAttr modifier = nullptr) {
  // Prints `name(` [`mod: <kind>, `] entries ` : ` types `) `, mirroring
  // parseClauseWithRegionArgs. Nothing is printed when the clause introduced
  // no block arguments.
862  if (argsSubrange.empty())
863  return;
864 
865  p << clauseName << "(";
866 
867  if (modifier)
868  p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
869 
  // Substitute neutral per-entry defaults for absent attributes (no symbol,
  // map index -1, not byref) so that all entries can be zipped uniformly.
870  if (!symbols) {
871  llvm::SmallVector<Attribute> values(operands.size(), nullptr);
872  symbols = ArrayAttr::get(ctx, values);
873  }
874 
875  if (!mapIndices) {
876  llvm::SmallVector<int64_t> values(operands.size(), -1);
877  mapIndices = DenseI64ArrayAttr::get(ctx, values);
878  }
879 
880  if (!byref) {
881  mlir::SmallVector<bool> values(operands.size(), false);
882  byref = DenseBoolArrayAttr::get(ctx, values);
883  }
884 
  // zip_equal requires all five ranges to have the same length.
885  llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
886  mapIndices.asArrayRef(),
887  byref.asArrayRef()),
888  p, [&p](auto t) {
889  auto [op, arg, sym, map, isByRef] = t;
890  if (isByRef)
891  p << "byref ";
892  if (sym)
893  p << sym << " ";
894 
895  p << op << " -> " << arg;
896 
897  if (map != -1)
898  p << " [map_idx=" << map << "]";
899  });
900  p << " : ";
901  llvm::interleaveComma(types, p);
902  p << ") ";
903 }
904 
906  StringRef clauseName, ValueRange argsSubrange,
907  std::optional<MapPrintArgs> mapArgs) {
  // A disengaged `mapArgs` means the operation has no such clause; print
  // nothing in that case.
908  if (mapArgs)
909  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
910  mapArgs->types);
911 }
912 
914  StringRef clauseName, ValueRange argsSubrange,
915  std::optional<PrivatePrintArgs> privateArgs) {
  // A disengaged `privateArgs` means the operation has no such clause; print
  // nothing in that case. Symbols and map indices are forwarded as-is (either
  // may be null).
916  if (privateArgs)
917  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
918  privateArgs->vars, privateArgs->types,
919  privateArgs->syms, privateArgs->mapIndices);
920 }
921 
922 static void
923 printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
924  ValueRange argsSubrange,
925  std::optional<ReductionPrintArgs> reductionArgs) {
926  if (reductionArgs)
927  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
928  reductionArgs->vars, reductionArgs->types,
929  reductionArgs->syms, /*mapIndices=*/nullptr,
930  reductionArgs->byref, reductionArgs->modifier);
931 }
932 
933 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
934  const AllRegionPrintArgs &args) {
935  auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
936  MLIRContext *ctx = op->getContext();
937 
938  printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
939  args.hostEvalArgs);
940  printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
941  args.inReductionArgs);
942  printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
943  args.mapArgs);
944  printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
945  args.privateArgs);
946  printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
947  args.reductionArgs);
948  printBlockArgClause(p, ctx, "task_reduction",
949  iface.getTaskReductionBlockArgs(),
950  args.taskReductionArgs);
951  printBlockArgClause(p, ctx, "use_device_addr",
952  iface.getUseDeviceAddrBlockArgs(),
953  args.useDeviceAddrArgs);
954  printBlockArgClause(p, ctx, "use_device_ptr",
955  iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
956 
957  p.printRegion(region, /*printEntryBlockArgs=*/false);
958 }
959 
961  OpAsmPrinter &p, Operation *op, Region &region, ValueRange hostEvalVars,
962  TypeRange hostEvalTypes, ValueRange inReductionVars,
963  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
964  ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
965  ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
966  DenseI64ArrayAttr privateMaps) {
967  AllRegionPrintArgs args;
968  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
969  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
970  inReductionByref, inReductionSyms);
971  args.mapArgs.emplace(mapVars, mapTypes);
972  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
973  printBlockArgRegion(p, op, region, args);
974 }
975 
977  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
978  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
979  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
980  ArrayAttr privateSyms) {
981  AllRegionPrintArgs args;
982  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
983  inReductionByref, inReductionSyms);
984  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
985  /*mapIndices=*/nullptr);
986  printBlockArgRegion(p, op, region, args);
987 }
988 
990  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
991  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
992  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
993  ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
994  ValueRange reductionVars, TypeRange reductionTypes,
995  DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
996  AllRegionPrintArgs args;
997  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
998  inReductionByref, inReductionSyms);
999  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1000  /*mapIndices=*/nullptr);
1001  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1002  reductionSyms, reductionMod);
1003  printBlockArgRegion(p, op, region, args);
1004 }
1005 
1006 static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
1007  ValueRange privateVars, TypeRange privateTypes,
1008  ArrayAttr privateSyms) {
1009  AllRegionPrintArgs args;
1010  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1011  /*mapIndices=*/nullptr);
1012  printBlockArgRegion(p, op, region, args);
1013 }
1014 
1016  OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1017  TypeRange privateTypes, ArrayAttr privateSyms,
1018  ReductionModifierAttr reductionMod, ValueRange reductionVars,
1019  TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1020  ArrayAttr reductionSyms) {
1021  AllRegionPrintArgs args;
1022  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1023  /*mapIndices=*/nullptr);
1024  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1025  reductionSyms, reductionMod);
1026  printBlockArgRegion(p, op, region, args);
1027 }
1028 
1030  Region &region,
1031  ValueRange taskReductionVars,
1032  TypeRange taskReductionTypes,
1033  DenseBoolArrayAttr taskReductionByref,
1034  ArrayAttr taskReductionSyms) {
1035  AllRegionPrintArgs args;
1036  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1037  taskReductionByref, taskReductionSyms);
1038  printBlockArgRegion(p, op, region, args);
1039 }
1040 
1042  Region &region,
1043  ValueRange useDeviceAddrVars,
1044  TypeRange useDeviceAddrTypes,
1045  ValueRange useDevicePtrVars,
1046  TypeRange useDevicePtrTypes) {
1047  AllRegionPrintArgs args;
1048  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1049  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1050  printBlockArgRegion(p, op, region, args);
1051 }
1052 
1053 /// Verifies Reduction Clause
1054 static LogicalResult
1055 verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1056  OperandRange reductionVars,
1057  std::optional<ArrayRef<bool>> reductionByref) {
1058  if (!reductionVars.empty()) {
1059  if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1060  return op->emitOpError()
1061  << "expected as many reduction symbol references "
1062  "as reduction variables";
1063  if (reductionByref && reductionByref->size() != reductionVars.size())
1064  return op->emitError() << "expected as many reduction variable by "
1065  "reference attributes as reduction variables";
1066  } else {
1067  if (reductionSyms)
1068  return op->emitOpError() << "unexpected reduction symbol references";
1069  return success();
1070  }
1071 
1072  // TODO: The followings should be done in
1073  // SymbolUserOpInterface::verifySymbolUses.
1074  DenseSet<Value> accumulators;
1075  for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1076  Value accum = std::get<0>(args);
1077 
1078  if (!accumulators.insert(accum).second)
1079  return op->emitOpError() << "accumulator variable used more than once";
1080 
1081  Type varType = accum.getType();
1082  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1083  auto decl =
1084  SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1085  if (!decl)
1086  return op->emitOpError() << "expected symbol reference " << symbolRef
1087  << " to point to a reduction declaration";
1088 
1089  if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1090  return op->emitOpError()
1091  << "expected accumulator (" << varType
1092  << ") to be the same type as reduction declaration ("
1093  << decl.getAccumulatorType() << ")";
1094  }
1095 
1096  return success();
1097 }
1098 
1099 //===----------------------------------------------------------------------===//
1100 // Parser, printer and verifier for Copyprivate
1101 //===----------------------------------------------------------------------===//
1102 
1103 /// copyprivate-entry-list ::= copyprivate-entry
1104 /// | copyprivate-entry-list `,` copyprivate-entry
1105 /// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1106 static ParseResult parseCopyprivate(
1107  OpAsmParser &parser,
1109  SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1111  if (failed(parser.parseCommaSeparatedList([&]() {
1112  if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1113  parser.parseArrow() ||
1114  parser.parseAttribute(symsVec.emplace_back()) ||
1115  parser.parseColonType(copyprivateTypes.emplace_back()))
1116  return failure();
1117  return success();
1118  })))
1119  return failure();
1120  SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1121  copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1122  return success();
1123 }
1124 
1125 /// Print Copyprivate clause
1127  OperandRange copyprivateVars,
1128  TypeRange copyprivateTypes,
1129  std::optional<ArrayAttr> copyprivateSyms) {
1130  if (!copyprivateSyms.has_value())
1131  return;
1132  llvm::interleaveComma(
1133  llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1134  [&](const auto &args) {
1135  p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1136  << std::get<2>(args);
1137  });
1138 }
1139 
1140 /// Verifies CopyPrivate Clause
1141 static LogicalResult
1143  std::optional<ArrayAttr> copyprivateSyms) {
1144  size_t copyprivateSymsSize =
1145  copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1146  if (copyprivateSymsSize != copyprivateVars.size())
1147  return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1148  << copyprivateVars.size()
1149  << ") and functions (= " << copyprivateSymsSize
1150  << "), both must be equal";
1151  if (!copyprivateSyms.has_value())
1152  return success();
1153 
1154  for (auto copyprivateVarAndSym :
1155  llvm::zip(copyprivateVars, *copyprivateSyms)) {
1156  auto symbolRef =
1157  llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1158  std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1159  funcOp;
1160  if (mlir::func::FuncOp mlirFuncOp =
1161  SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1162  symbolRef))
1163  funcOp = mlirFuncOp;
1164  else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1165  SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1166  op, symbolRef))
1167  funcOp = llvmFuncOp;
1168 
1169  auto getNumArguments = [&] {
1170  return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1171  };
1172 
1173  auto getArgumentType = [&](unsigned i) {
1174  return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1175  *funcOp);
1176  };
1177 
1178  if (!funcOp)
1179  return op->emitOpError() << "expected symbol reference " << symbolRef
1180  << " to point to a copy function";
1181 
1182  if (getNumArguments() != 2)
1183  return op->emitOpError()
1184  << "expected copy function " << symbolRef << " to have 2 operands";
1185 
1186  Type argTy = getArgumentType(0);
1187  if (argTy != getArgumentType(1))
1188  return op->emitOpError() << "expected copy function " << symbolRef
1189  << " arguments to have the same type";
1190 
1191  Type varType = std::get<0>(copyprivateVarAndSym).getType();
1192  if (argTy != varType)
1193  return op->emitOpError()
1194  << "expected copy function arguments' type (" << argTy
1195  << ") to be the same as copyprivate variable's type (" << varType
1196  << ")";
1197  }
1198 
1199  return success();
1200 }
1201 
1202 //===----------------------------------------------------------------------===//
1203 // Parser, printer and verifier for DependVarList
1204 //===----------------------------------------------------------------------===//
1205 
1206 /// depend-entry-list ::= depend-entry
1207 /// | depend-entry-list `,` depend-entry
1208 /// depend-entry ::= depend-kind `->` ssa-id `:` type
1209 static ParseResult
1212  SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
1214  if (failed(parser.parseCommaSeparatedList([&]() {
1215  StringRef keyword;
1216  if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1217  parser.parseOperand(dependVars.emplace_back()) ||
1218  parser.parseColonType(dependTypes.emplace_back()))
1219  return failure();
1220  if (std::optional<ClauseTaskDepend> keywordDepend =
1221  (symbolizeClauseTaskDepend(keyword)))
1222  kindsVec.emplace_back(
1223  ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1224  else
1225  return failure();
1226  return success();
1227  })))
1228  return failure();
1229  SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1230  dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1231  return success();
1232 }
1233 
1234 /// Print Depend clause
1236  OperandRange dependVars, TypeRange dependTypes,
1237  std::optional<ArrayAttr> dependKinds) {
1238 
1239  for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1240  if (i != 0)
1241  p << ", ";
1242  p << stringifyClauseTaskDepend(
1243  llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1244  .getValue())
1245  << " -> " << dependVars[i] << " : " << dependTypes[i];
1246  }
1247 }
1248 
1249 /// Verifies Depend clause
1250 static LogicalResult verifyDependVarList(Operation *op,
1251  std::optional<ArrayAttr> dependKinds,
1252  OperandRange dependVars) {
1253  if (!dependVars.empty()) {
1254  if (!dependKinds || dependKinds->size() != dependVars.size())
1255  return op->emitOpError() << "expected as many depend values"
1256  " as depend variables";
1257  } else {
1258  if (dependKinds && !dependKinds->empty())
1259  return op->emitOpError() << "unexpected depend values";
1260  return success();
1261  }
1262 
1263  return success();
1264 }
1265 
1266 //===----------------------------------------------------------------------===//
1267 // Parser, printer and verifier for Synchronization Hint (2.17.12)
1268 //===----------------------------------------------------------------------===//
1269 
1270 /// Parses a Synchronization Hint clause. The value of hint is an integer
1271 /// which is a combination of different hints from `omp_sync_hint_t`.
1272 ///
1273 /// hint-clause = `hint` `(` hint-value `)`
1274 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1275  IntegerAttr &hintAttr) {
1276  StringRef hintKeyword;
1277  int64_t hint = 0;
1278  if (succeeded(parser.parseOptionalKeyword("none"))) {
1279  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1280  return success();
1281  }
1282  auto parseKeyword = [&]() -> ParseResult {
1283  if (failed(parser.parseKeyword(&hintKeyword)))
1284  return failure();
1285  if (hintKeyword == "uncontended")
1286  hint |= 1;
1287  else if (hintKeyword == "contended")
1288  hint |= 2;
1289  else if (hintKeyword == "nonspeculative")
1290  hint |= 4;
1291  else if (hintKeyword == "speculative")
1292  hint |= 8;
1293  else
1294  return parser.emitError(parser.getCurrentLocation())
1295  << hintKeyword << " is not a valid hint";
1296  return success();
1297  };
1298  if (parser.parseCommaSeparatedList(parseKeyword))
1299  return failure();
1300  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1301  return success();
1302 }
1303 
1304 /// Prints a Synchronization Hint clause
1306  IntegerAttr hintAttr) {
1307  int64_t hint = hintAttr.getInt();
1308 
1309  if (hint == 0) {
1310  p << "none";
1311  return;
1312  }
1313 
1314  // Helper function to get n-th bit from the right end of `value`
1315  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1316 
1317  bool uncontended = bitn(hint, 0);
1318  bool contended = bitn(hint, 1);
1319  bool nonspeculative = bitn(hint, 2);
1320  bool speculative = bitn(hint, 3);
1321 
1322  SmallVector<StringRef> hints;
1323  if (uncontended)
1324  hints.push_back("uncontended");
1325  if (contended)
1326  hints.push_back("contended");
1327  if (nonspeculative)
1328  hints.push_back("nonspeculative");
1329  if (speculative)
1330  hints.push_back("speculative");
1331 
1332  llvm::interleaveComma(hints, p);
1333 }
1334 
1335 /// Verifies a synchronization hint clause
1336 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1337 
1338  // Helper function to get n-th bit from the right end of `value`
1339  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1340 
1341  bool uncontended = bitn(hint, 0);
1342  bool contended = bitn(hint, 1);
1343  bool nonspeculative = bitn(hint, 2);
1344  bool speculative = bitn(hint, 3);
1345 
1346  if (uncontended && contended)
1347  return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1348  "omp_sync_hint_contended cannot be combined";
1349  if (nonspeculative && speculative)
1350  return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1351  "omp_sync_hint_speculative cannot be combined.";
1352  return success();
1353 }
1354 
1355 //===----------------------------------------------------------------------===//
1356 // Parser, printer and verifier for Target
1357 //===----------------------------------------------------------------------===//
1358 
1359 // Helper function to get bitwise AND of `value` and 'flag'
1360 uint64_t mapTypeToBitFlag(uint64_t value,
1361  llvm::omp::OpenMPOffloadMappingFlags flag) {
1362  return value & llvm::to_underlying(flag);
1363 }
1364 
1365 /// Parses a map_entries map type from a string format back into its numeric
1366 /// value.
1367 ///
1368 /// map-clause = `map_clauses ( ( `(` `always, `? `close, `? `present, `? (
1369 /// `to` | `from` | `delete` `)` )+ `)` )
1370 static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
1371  llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1372  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1373 
1374  // This simply verifies the correct keyword is read in, the
1375  // keyword itself is stored inside of the operation
1376  auto parseTypeAndMod = [&]() -> ParseResult {
1377  StringRef mapTypeMod;
1378  if (parser.parseKeyword(&mapTypeMod))
1379  return failure();
1380 
1381  if (mapTypeMod == "always")
1382  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1383 
1384  if (mapTypeMod == "implicit")
1385  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1386 
1387  if (mapTypeMod == "close")
1388  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1389 
1390  if (mapTypeMod == "present")
1391  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1392 
1393  if (mapTypeMod == "to")
1394  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1395 
1396  if (mapTypeMod == "from")
1397  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1398 
1399  if (mapTypeMod == "tofrom")
1400  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1401  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1402 
1403  if (mapTypeMod == "delete")
1404  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1405 
1406  return success();
1407  };
1408 
1409  if (parser.parseCommaSeparatedList(parseTypeAndMod))
1410  return failure();
1411 
1412  mapType = parser.getBuilder().getIntegerAttr(
1413  parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
1414  llvm::to_underlying(mapTypeBits));
1415 
1416  return success();
1417 }
1418 
1419 /// Prints a map_entries map type from its numeric value out into its string
1420 /// format.
1422  IntegerAttr mapType) {
1423  uint64_t mapTypeBits = mapType.getUInt();
1424 
1425  bool emitAllocRelease = true;
1427 
1428  // handling of always, close, present placed at the beginning of the string
1429  // to aid readability
1430  if (mapTypeToBitFlag(mapTypeBits,
1431  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1432  mapTypeStrs.push_back("always");
1433  if (mapTypeToBitFlag(mapTypeBits,
1434  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1435  mapTypeStrs.push_back("implicit");
1436  if (mapTypeToBitFlag(mapTypeBits,
1437  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1438  mapTypeStrs.push_back("close");
1439  if (mapTypeToBitFlag(mapTypeBits,
1440  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1441  mapTypeStrs.push_back("present");
1442 
1443  // special handling of to/from/tofrom/delete and release/alloc, release +
1444  // alloc are the abscense of one of the other flags, whereas tofrom requires
1445  // both the to and from flag to be set.
1446  bool to = mapTypeToBitFlag(mapTypeBits,
1447  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1448  bool from = mapTypeToBitFlag(
1449  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1450  if (to && from) {
1451  emitAllocRelease = false;
1452  mapTypeStrs.push_back("tofrom");
1453  } else if (from) {
1454  emitAllocRelease = false;
1455  mapTypeStrs.push_back("from");
1456  } else if (to) {
1457  emitAllocRelease = false;
1458  mapTypeStrs.push_back("to");
1459  }
1460  if (mapTypeToBitFlag(mapTypeBits,
1461  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1462  emitAllocRelease = false;
1463  mapTypeStrs.push_back("delete");
1464  }
1465  if (emitAllocRelease)
1466  mapTypeStrs.push_back("exit_release_or_enter_alloc");
1467 
1468  for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1469  p << mapTypeStrs[i];
1470  if (i + 1 < mapTypeStrs.size()) {
1471  p << ", ";
1472  }
1473  }
1474 }
1475 
1476 static ParseResult parseMembersIndex(OpAsmParser &parser,
1477  ArrayAttr &membersIdx) {
1478  SmallVector<Attribute> values, memberIdxs;
1479 
1480  auto parseIndices = [&]() -> ParseResult {
1481  int64_t value;
1482  if (parser.parseInteger(value))
1483  return failure();
1484  values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1485  APInt(64, value, /*isSigned=*/false)));
1486  return success();
1487  };
1488 
1489  do {
1490  if (failed(parser.parseLSquare()))
1491  return failure();
1492 
1493  if (parser.parseCommaSeparatedList(parseIndices))
1494  return failure();
1495 
1496  if (failed(parser.parseRSquare()))
1497  return failure();
1498 
1499  memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1500  values.clear();
1501  } while (succeeded(parser.parseOptionalComma()));
1502 
1503  if (!memberIdxs.empty())
1504  membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
1505 
1506  return success();
1507 }
1508 
1509 static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1510  ArrayAttr membersIdx) {
1511  if (!membersIdx)
1512  return;
1513 
1514  llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
1515  p << "[";
1516  auto memberIdx = cast<ArrayAttr>(v);
1517  llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
1518  p << cast<IntegerAttr>(v2).getInt();
1519  });
1520  p << "]";
1521  });
1522 }
1523 
1525  VariableCaptureKindAttr mapCaptureType) {
1526  std::string typeCapStr;
1527  llvm::raw_string_ostream typeCap(typeCapStr);
1528  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1529  typeCap << "ByRef";
1530  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1531  typeCap << "ByCopy";
1532  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1533  typeCap << "VLAType";
1534  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1535  typeCap << "This";
1536  p << typeCapStr;
1537 }
1538 
1539 static ParseResult parseCaptureType(OpAsmParser &parser,
1540  VariableCaptureKindAttr &mapCaptureType) {
1541  StringRef mapCaptureKey;
1542  if (parser.parseKeyword(&mapCaptureKey))
1543  return failure();
1544 
1545  if (mapCaptureKey == "This")
1546  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1547  parser.getContext(), mlir::omp::VariableCaptureKind::This);
1548  if (mapCaptureKey == "ByRef")
1549  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1550  parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1551  if (mapCaptureKey == "ByCopy")
1552  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1553  parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1554  if (mapCaptureKey == "VLAType")
1555  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1556  parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1557 
1558  return success();
1559 }
1560 
1561 static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
1564 
1565  for (auto mapOp : mapVars) {
1566  if (!mapOp.getDefiningOp())
1567  emitError(op->getLoc(), "missing map operation");
1568 
1569  if (auto mapInfoOp =
1570  mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
1571  if (!mapInfoOp.getMapType().has_value())
1572  emitError(op->getLoc(), "missing map type for map operand");
1573 
1574  if (!mapInfoOp.getMapCaptureType().has_value())
1575  emitError(op->getLoc(), "missing map capture type for map operand");
1576 
1577  uint64_t mapTypeBits = mapInfoOp.getMapType().value();
1578 
1579  bool to = mapTypeToBitFlag(
1580  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1581  bool from = mapTypeToBitFlag(
1582  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1583  bool del = mapTypeToBitFlag(
1584  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1585 
1586  bool always = mapTypeToBitFlag(
1587  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1588  bool close = mapTypeToBitFlag(
1589  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1590  bool implicit = mapTypeToBitFlag(
1591  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1592 
1593  if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1594  return emitError(op->getLoc(),
1595  "to, from, tofrom and alloc map types are permitted");
1596 
1597  if (isa<TargetEnterDataOp>(op) && (from || del))
1598  return emitError(op->getLoc(), "to and alloc map types are permitted");
1599 
1600  if (isa<TargetExitDataOp>(op) && to)
1601  return emitError(op->getLoc(),
1602  "from, release and delete map types are permitted");
1603 
1604  if (isa<TargetUpdateOp>(op)) {
1605  if (del) {
1606  return emitError(op->getLoc(),
1607  "at least one of to or from map types must be "
1608  "specified, other map types are not permitted");
1609  }
1610 
1611  if (!to && !from) {
1612  return emitError(op->getLoc(),
1613  "at least one of to or from map types must be "
1614  "specified, other map types are not permitted");
1615  }
1616 
1617  auto updateVar = mapInfoOp.getVarPtr();
1618 
1619  if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1620  (from && updateToVars.contains(updateVar))) {
1621  return emitError(
1622  op->getLoc(),
1623  "either to or from map types can be specified, not both");
1624  }
1625 
1626  if (always || close || implicit) {
1627  return emitError(
1628  op->getLoc(),
1629  "present, mapper and iterator map type modifiers are permitted");
1630  }
1631 
1632  to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1633  }
1634  } else {
1635  emitError(op->getLoc(), "map argument is not a map entry operation");
1636  }
1637  }
1638 
1639  return success();
1640 }
1641 
1642 static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
1643  std::optional<DenseI64ArrayAttr> privateMapIndices =
1644  targetOp.getPrivateMapsAttr();
1645 
1646  // None of the private operands are mapped.
1647  if (!privateMapIndices.has_value() || !privateMapIndices.value())
1648  return success();
1649 
1650  OperandRange privateVars = targetOp.getPrivateVars();
1651 
1652  if (privateMapIndices.value().size() !=
1653  static_cast<int64_t>(privateVars.size()))
1654  return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
1655  "`private_maps` attribute mismatch");
1656 
1657  return success();
1658 }
1659 
1660 //===----------------------------------------------------------------------===//
1661 // TargetDataOp
1662 //===----------------------------------------------------------------------===//
1663 
1664 void TargetDataOp::build(OpBuilder &builder, OperationState &state,
1665  const TargetDataOperands &clauses) {
1666  TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
1667  clauses.mapVars, clauses.useDeviceAddrVars,
1668  clauses.useDevicePtrVars);
1669 }
1670 
1671 LogicalResult TargetDataOp::verify() {
1672  if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
1673  getUseDeviceAddrVars().empty()) {
1674  return ::emitError(this->getLoc(),
1675  "At least one of map, use_device_ptr_vars, or "
1676  "use_device_addr_vars operand must be present");
1677  }
1678  return verifyMapClause(*this, getMapVars());
1679 }
1680 
1681 //===----------------------------------------------------------------------===//
1682 // TargetEnterDataOp
1683 //===----------------------------------------------------------------------===//
1684 
1685 void TargetEnterDataOp::build(
1686  OpBuilder &builder, OperationState &state,
1687  const TargetEnterExitUpdateDataOperands &clauses) {
1688  MLIRContext *ctx = builder.getContext();
1689  TargetEnterDataOp::build(builder, state,
1690  makeArrayAttr(ctx, clauses.dependKinds),
1691  clauses.dependVars, clauses.device, clauses.ifExpr,
1692  clauses.mapVars, clauses.nowait);
1693 }
1694 
1695 LogicalResult TargetEnterDataOp::verify() {
1696  LogicalResult verifyDependVars =
1697  verifyDependVarList(*this, getDependKinds(), getDependVars());
1698  return failed(verifyDependVars) ? verifyDependVars
1699  : verifyMapClause(*this, getMapVars());
1700 }
1701 
1702 //===----------------------------------------------------------------------===//
1703 // TargetExitDataOp
1704 //===----------------------------------------------------------------------===//
1705 
1706 void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
1707  const TargetEnterExitUpdateDataOperands &clauses) {
1708  MLIRContext *ctx = builder.getContext();
1709  TargetExitDataOp::build(builder, state,
1710  makeArrayAttr(ctx, clauses.dependKinds),
1711  clauses.dependVars, clauses.device, clauses.ifExpr,
1712  clauses.mapVars, clauses.nowait);
1713 }
1714 
1715 LogicalResult TargetExitDataOp::verify() {
1716  LogicalResult verifyDependVars =
1717  verifyDependVarList(*this, getDependKinds(), getDependVars());
1718  return failed(verifyDependVars) ? verifyDependVars
1719  : verifyMapClause(*this, getMapVars());
1720 }
1721 
1722 //===----------------------------------------------------------------------===//
1723 // TargetUpdateOp
1724 //===----------------------------------------------------------------------===//
1725 
1726 void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
1727  const TargetEnterExitUpdateDataOperands &clauses) {
1728  MLIRContext *ctx = builder.getContext();
1729  TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
1730  clauses.dependVars, clauses.device, clauses.ifExpr,
1731  clauses.mapVars, clauses.nowait);
1732 }
1733 
1734 LogicalResult TargetUpdateOp::verify() {
1735  LogicalResult verifyDependVars =
1736  verifyDependVarList(*this, getDependKinds(), getDependVars());
1737  return failed(verifyDependVars) ? verifyDependVars
1738  : verifyMapClause(*this, getMapVars());
1739 }
1740 
1741 //===----------------------------------------------------------------------===//
1742 // TargetOp
1743 //===----------------------------------------------------------------------===//
1744 
1745 void TargetOp::build(OpBuilder &builder, OperationState &state,
1746  const TargetOperands &clauses) {
1747  MLIRContext *ctx = builder.getContext();
1748  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
1749  // inReductionByref, inReductionSyms.
1750  TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
1751  clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
1752  clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
1753  clauses.hostEvalVars, clauses.ifExpr,
1754  /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
1755  /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
1756  clauses.mapVars, clauses.nowait, clauses.privateVars,
1757  makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
1758  /*private_maps=*/nullptr);
1759 }
1760 
1761 LogicalResult TargetOp::verify() {
1762  LogicalResult verifyDependVars =
1763  verifyDependVarList(*this, getDependKinds(), getDependVars());
1764 
1765  if (failed(verifyDependVars))
1766  return verifyDependVars;
1767 
1768  LogicalResult verifyMapVars = verifyMapClause(*this, getMapVars());
1769 
1770  if (failed(verifyMapVars))
1771  return verifyMapVars;
1772 
1773  return verifyPrivateVarsMapping(*this);
1774 }
1775 
1776 LogicalResult TargetOp::verifyRegions() {
1777  auto teamsOps = getOps<TeamsOp>();
1778  if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
1779  return emitError("target containing multiple 'omp.teams' nested ops");
1780 
1781  // Check that host_eval values are only used in legal ways.
1782  llvm::omp::OMPTgtExecModeFlags execFlags = getKernelExecFlags();
1783  for (Value hostEvalArg :
1784  cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
1785  for (Operation *user : hostEvalArg.getUsers()) {
1786  if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
1787  if (llvm::is_contained({teamsOp.getNumTeamsLower(),
1788  teamsOp.getNumTeamsUpper(),
1789  teamsOp.getThreadLimit()},
1790  hostEvalArg))
1791  continue;
1792 
1793  return emitOpError() << "host_eval argument only legal as 'num_teams' "
1794  "and 'thread_limit' in 'omp.teams'";
1795  }
1796  if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1797  if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
1798  hostEvalArg == parallelOp.getNumThreads())
1799  continue;
1800 
1801  return emitOpError()
1802  << "host_eval argument only legal as 'num_threads' in "
1803  "'omp.parallel' when representing target SPMD";
1804  }
1805  if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1806  if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
1807  (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
1808  llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
1809  llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
1810  continue;
1811 
1812  return emitOpError() << "host_eval argument only legal as loop bounds "
1813  "and steps in 'omp.loop_nest' when "
1814  "representing target SPMD or Generic-SPMD";
1815  }
1816 
1817  return emitOpError() << "host_eval argument illegal use in '"
1818  << user->getName() << "' operation";
1819  }
1820  }
1821  return success();
1822 }
1823 
/// Returns whether `op` may appear next to an operation that
/// `TargetOp::getInnermostCapturedOmpOp()` is considering for capture.
///
/// - A null `op` is rejected.
/// - An OpenMP dialect operation is accepted only if it is a terminator.
/// - An operation implementing `MemoryEffectOpInterface` is accepted unless
///   one of its effects is a write to
///   `SideEffects::AutomaticAllocationScopeResource`.
/// - Any other operation is accepted.
///
/// NOTE(review): the previous description said only non-OpenMP ops with
/// *known* memory effects and no write effect are allowed, but ops without
/// `MemoryEffectOpInterface` are accepted and only writes to the automatic
/// allocation scope resource are rejected — confirm which is intended.
/// NOTE(review): the function header line is not visible in this view of the
/// file — confirm it is present.
  if (!op)
    return false;

  // Identify OpenMP ops by comparing against the loaded OpenMP dialect.
  bool isOmpDialect =
      op->getContext()->getLoadedDialect<omp::OpenMPDialect>() ==
      op->getDialect();

  // The only OpenMP ops tolerated next to a captured op are terminators.
  if (isOmpDialect)
    return op->hasTrait<OpTrait::IsTerminator>();

  if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
    // NOTE(review): `effects` is not declared in the visible code; presumably
    // a SmallVector of MemoryEffects::EffectInstance — confirm.
    memOp.getEffects(effects);
    // Reject the op only if it writes to the automatic allocation scope
    // resource; any other effect is tolerated.
    return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
      return isa<MemoryEffects::Write>(effect.getEffect()) &&
             isa<SideEffects::AutomaticAllocationScopeResource>(
                 effect.getResource());
    });
  }
  // No memory effect information is available: the op is accepted.
  return true;
}
1848 
1849 Operation *TargetOp::getInnermostCapturedOmpOp() {
1850  Dialect *ompDialect = (*this)->getDialect();
1851  Operation *capturedOp = nullptr;
1852  DominanceInfo domInfo;
1853 
1854  // Process in pre-order to check operations from outermost to innermost,
1855  // ensuring we only enter the region of an operation if it meets the criteria
1856  // for being captured. We stop the exploration of nested operations as soon as
1857  // we process a region holding no operations to be captured.
1858  walk<WalkOrder::PreOrder>([&](Operation *op) {
1859  if (op == *this)
1860  return WalkResult::advance();
1861 
1862  // Ignore operations of other dialects or omp operations with no regions,
1863  // because these will only be checked if they are siblings of an omp
1864  // operation that can potentially be captured.
1865  bool isOmpDialect = op->getDialect() == ompDialect;
1866  bool hasRegions = op->getNumRegions() > 0;
1867  if (!isOmpDialect || !hasRegions)
1868  return WalkResult::skip();
1869 
1870  // This operation cannot be captured if it can be executed more than once
1871  // (i.e. its block's successors can reach it) or if it's not guaranteed to
1872  // be executed before all exits of the region (i.e. it doesn't dominate all
1873  // blocks with no successors reachable from the entry block).
1874  Region *parentRegion = op->getParentRegion();
1875  Block *parentBlock = op->getBlock();
1876 
1877  for (Block *successor : parentBlock->getSuccessors())
1878  if (successor->isReachable(parentBlock))
1879  return WalkResult::interrupt();
1880 
1881  for (Block &block : *parentRegion)
1882  if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
1883  !domInfo.dominates(parentBlock, &block))
1884  return WalkResult::interrupt();
1885 
1886  // Don't capture this op if it has a not-allowed sibling, and stop recursing
1887  // into nested operations.
1888  for (Operation &sibling : op->getParentRegion()->getOps())
1889  if (&sibling != op && !siblingAllowedInCapture(&sibling))
1890  return WalkResult::interrupt();
1891 
1892  // Don't continue capturing nested operations if we reach an omp.loop_nest.
1893  // Otherwise, process the contents of this operation.
1894  capturedOp = op;
1895  return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
1896  : WalkResult::advance();
1897  });
1898 
1899  return capturedOp;
1900 }
1901 
/// Determines the kernel execution mode of this target region from its
/// innermost captured OpenMP operation and the loop wrappers around it:
///   - SPMD: target-teams-distribute-parallel-wsloop[-simd]. omp.teams is a
///     direct child of this op, omp.parallel is the parent of omp.distribute,
///     and omp.distribute is the wrapper just outside omp.wsloop.
///   - Generic-SPMD: target-teams-distribute[-simd]. omp.teams is a direct
///     child of this op and the parent of omp.distribute.
///   - Generic: anything else, including regions that do not capture an
///     omp.loop_nest.
llvm::omp::OMPTgtExecModeFlags TargetOp::getKernelExecFlags() {
  using namespace llvm::omp;

  // Make sure this region is capturing a loop. Otherwise, it's a generic
  // kernel.
  Operation *capturedOp = getInnermostCapturedOmpOp();
  if (!isa_and_present<LoopNestOp>(capturedOp))
    return OMP_TGT_EXEC_MODE_GENERIC;

  // Collect the loop wrappers of the captured omp.loop_nest. The code below
  // treats `wrappers.begin()` as the innermost wrapper.
  // NOTE(review): the declaration of `wrappers` is not visible in this view of
  // the file; presumably a SmallVector<LoopWrapperInterface> — confirm.
  cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
  assert(!wrappers.empty());

  // Ignore optional SIMD leaf construct.
  auto *innermostWrapper = wrappers.begin();
  if (isa<SimdOp>(innermostWrapper))
    innermostWrapper = std::next(innermostWrapper);

  // Number of wrappers left after skipping a leading omp.simd. A loop wrapped
  // only by omp.simd gives 0 and ends up as Generic below.
  long numWrappers = std::distance(innermostWrapper, wrappers.end());

  // Detect Generic-SPMD: target-teams-distribute[-simd].
  if (numWrappers == 1) {
    if (!isa<DistributeOp>(innermostWrapper))
      return OMP_TGT_EXEC_MODE_GENERIC;

    Operation *teamsOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<TeamsOp>(teamsOp))
      return OMP_TGT_EXEC_MODE_GENERIC;

    // omp.teams must be nested directly in this omp.target.
    if (teamsOp->getParentOp() == *this)
      return OMP_TGT_EXEC_MODE_GENERIC_SPMD;
  }

  // Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
  if (numWrappers == 2) {
    if (!isa<WsloopOp>(innermostWrapper))
      return OMP_TGT_EXEC_MODE_GENERIC;

    // The next wrapper outwards must be omp.distribute.
    innermostWrapper = std::next(innermostWrapper);
    if (!isa<DistributeOp>(innermostWrapper))
      return OMP_TGT_EXEC_MODE_GENERIC;

    Operation *parallelOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return OMP_TGT_EXEC_MODE_GENERIC;

    Operation *teamsOp = parallelOp->getParentOp();
    if (!isa_and_present<TeamsOp>(teamsOp))
      return OMP_TGT_EXEC_MODE_GENERIC;

    // omp.teams must be nested directly in this omp.target.
    if (teamsOp->getParentOp() == *this)
      return OMP_TGT_EXEC_MODE_SPMD;
  }

  return OMP_TGT_EXEC_MODE_GENERIC;
}
1958 
1959 //===----------------------------------------------------------------------===//
1960 // ParallelOp
1961 //===----------------------------------------------------------------------===//
1962 
1963 void ParallelOp::build(OpBuilder &builder, OperationState &state,
1964  ArrayRef<NamedAttribute> attributes) {
1965  ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
1966  /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
1967  /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
1968  /*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr,
1969  /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
1970  /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
1971  state.addAttributes(attributes);
1972 }
1973 
1974 void ParallelOp::build(OpBuilder &builder, OperationState &state,
1975  const ParallelOperands &clauses) {
1976  MLIRContext *ctx = builder.getContext();
1977  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1978  clauses.ifExpr, clauses.numThreads, clauses.privateVars,
1979  makeArrayAttr(ctx, clauses.privateSyms),
1980  clauses.procBindKind, clauses.reductionMod,
1981  clauses.reductionVars,
1982  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
1983  makeArrayAttr(ctx, clauses.reductionSyms));
1984 }
1985 
1986 template <typename OpType>
1987 static LogicalResult verifyPrivateVarList(OpType &op) {
1988  auto privateVars = op.getPrivateVars();
1989  auto privateSyms = op.getPrivateSymsAttr();
1990 
1991  if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
1992  return success();
1993 
1994  auto numPrivateVars = privateVars.size();
1995  auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
1996 
1997  if (numPrivateVars != numPrivateSyms)
1998  return op.emitError() << "inconsistent number of private variables and "
1999  "privatizer op symbols, private vars: "
2000  << numPrivateVars
2001  << " vs. privatizer op symbols: " << numPrivateSyms;
2002 
2003  for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2004  Type varType = std::get<0>(privateVarInfo).getType();
2005  SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2006  PrivateClauseOp privatizerOp =
2007  SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
2008 
2009  if (privatizerOp == nullptr)
2010  return op.emitError() << "failed to lookup privatizer op with symbol: '"
2011  << privateSym << "'";
2012 
2013  Type privatizerType = privatizerOp.getArgType();
2014 
2015  if (privatizerType && (varType != privatizerType))
2016  return op.emitError()
2017  << "type mismatch between a "
2018  << (privatizerOp.getDataSharingType() ==
2019  DataSharingClauseType::Private
2020  ? "private"
2021  : "firstprivate")
2022  << " variable and its privatizer op, var type: " << varType
2023  << " vs. privatizer op type: " << privatizerType;
2024  }
2025 
2026  return success();
2027 }
2028 
2029 LogicalResult ParallelOp::verify() {
2030  if (getAllocateVars().size() != getAllocatorVars().size())
2031  return emitError(
2032  "expected equal sizes for allocate and allocator variables");
2033 
2034  if (failed(verifyPrivateVarList(*this)))
2035  return failure();
2036 
2037  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2038  getReductionByref());
2039 }
2040 
2041 LogicalResult ParallelOp::verifyRegions() {
2042  auto distributeChildOps = getOps<DistributeOp>();
2043  if (!distributeChildOps.empty()) {
2044  if (!isComposite())
2045  return emitError()
2046  << "'omp.composite' attribute missing from composite operation";
2047 
2048  auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2049  Operation &distributeOp = **distributeChildOps.begin();
2050  for (Operation &childOp : getOps()) {
2051  if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2052  continue;
2053 
2054  if (!childOp.hasTrait<OpTrait::IsTerminator>())
2055  return emitError() << "unexpected OpenMP operation inside of composite "
2056  "'omp.parallel'";
2057  }
2058  } else if (isComposite()) {
2059  return emitError()
2060  << "'omp.composite' attribute present in non-composite operation";
2061  }
2062  return success();
2063 }
2064 
2065 //===----------------------------------------------------------------------===//
2066 // TeamsOp
2067 //===----------------------------------------------------------------------===//
2068 
2070  while ((op = op->getParentOp()))
2071  if (isa<OpenMPDialect>(op->getDialect()))
2072  return false;
2073  return true;
2074 }
2075 
2076 void TeamsOp::build(OpBuilder &builder, OperationState &state,
2077  const TeamsOperands &clauses) {
2078  MLIRContext *ctx = builder.getContext();
2079  // TODO Store clauses in op: privateVars, privateSyms.
2080  TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2081  clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2082  /*private_vars=*/{}, /*private_syms=*/nullptr,
2083  clauses.reductionMod, clauses.reductionVars,
2084  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2085  makeArrayAttr(ctx, clauses.reductionSyms),
2086  clauses.threadLimit);
2087 }
2088 
/// Verifies an `omp.teams` operation:
///   - placement: per the emitted diagnostic, it must be nested directly in an
///     `omp.target`, or not be nested in any OpenMP dialect operation;
///   - num_teams: a lower bound requires an upper bound of the same type;
///   - allocate: the allocate and allocator variable lists have equal sizes;
///   - reduction: checked by verifyReductionVarList.
LogicalResult TeamsOp::verify() {
  // Check parent region
  // TODO If nested inside of a target region, also check that it does not
  // contain any statements, declarations or directives other than this
  // omp.teams construct. The issue is how to support the initialization of
  // this operation's own arguments (allow SSA values across omp.target?).
  Operation *op = getOperation();
  // NOTE(review): the second operand of `&&` below is not visible in this view
  // of the file — confirm the complete condition.
  // NOTE(review): isa<> asserts on a null pointer; if an omp.teams op can be
  // verified while it has no parent op, isa_and_present<> would be safer.
  if (!isa<TargetOp>(op->getParentOp()) &&
    return emitError("expected to be nested inside of omp.target or not nested "
                     "in any OpenMP dialect operations");

  // Check for num_teams clause restrictions
  if (auto numTeamsLowerBound = getNumTeamsLower()) {
    auto numTeamsUpperBound = getNumTeamsUpper();
    if (!numTeamsUpperBound)
      return emitError("expected num_teams upper bound to be defined if the "
                       "lower bound is defined");
    if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
      return emitError(
          "expected num_teams upper bound and lower bound to be the same type");
  }

  // Check for allocate clause restrictions
  if (getAllocateVars().size() != getAllocatorVars().size())
    return emitError(
        "expected equal sizes for allocate and allocator variables");

  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                getReductionByref());
}
2120 
2121 //===----------------------------------------------------------------------===//
2122 // SectionOp
2123 //===----------------------------------------------------------------------===//
2124 
2125 unsigned SectionOp::numPrivateBlockArgs() {
2126  return getParentOp().numPrivateBlockArgs();
2127 }
2128 
2129 unsigned SectionOp::numReductionBlockArgs() {
2130  return getParentOp().numReductionBlockArgs();
2131 }
2132 
2133 //===----------------------------------------------------------------------===//
2134 // SectionsOp
2135 //===----------------------------------------------------------------------===//
2136 
2137 void SectionsOp::build(OpBuilder &builder, OperationState &state,
2138  const SectionsOperands &clauses) {
2139  MLIRContext *ctx = builder.getContext();
2140  // TODO Store clauses in op: privateVars, privateSyms.
2141  SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2142  clauses.nowait, /*private_vars=*/{},
2143  /*private_syms=*/nullptr, clauses.reductionMod,
2144  clauses.reductionVars,
2145  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2146  makeArrayAttr(ctx, clauses.reductionSyms));
2147 }
2148 
2149 LogicalResult SectionsOp::verify() {
2150  if (getAllocateVars().size() != getAllocatorVars().size())
2151  return emitError(
2152  "expected equal sizes for allocate and allocator variables");
2153 
2154  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2155  getReductionByref());
2156 }
2157 
2158 LogicalResult SectionsOp::verifyRegions() {
2159  for (auto &inst : *getRegion().begin()) {
2160  if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2161  return emitOpError()
2162  << "expected omp.section op or terminator op inside region";
2163  }
2164  }
2165 
2166  return success();
2167 }
2168 
2169 //===----------------------------------------------------------------------===//
2170 // SingleOp
2171 //===----------------------------------------------------------------------===//
2172 
2173 void SingleOp::build(OpBuilder &builder, OperationState &state,
2174  const SingleOperands &clauses) {
2175  MLIRContext *ctx = builder.getContext();
2176  // TODO Store clauses in op: privateVars, privateSyms.
2177  SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2178  clauses.copyprivateVars,
2179  makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2180  /*private_vars=*/{}, /*private_syms=*/nullptr);
2181 }
2182 
2183 LogicalResult SingleOp::verify() {
2184  // Check for allocate clause restrictions
2185  if (getAllocateVars().size() != getAllocatorVars().size())
2186  return emitError(
2187  "expected equal sizes for allocate and allocator variables");
2188 
2189  return verifyCopyprivateVarList(*this, getCopyprivateVars(),
2190  getCopyprivateSyms());
2191 }
2192 
2193 //===----------------------------------------------------------------------===//
2194 // WorkshareOp
2195 //===----------------------------------------------------------------------===//
2196 
2197 void WorkshareOp::build(OpBuilder &builder, OperationState &state,
2198  const WorkshareOperands &clauses) {
2199  WorkshareOp::build(builder, state, clauses.nowait);
2200 }
2201 
2202 //===----------------------------------------------------------------------===//
2203 // WorkshareLoopWrapperOp
2204 //===----------------------------------------------------------------------===//
2205 
2206 LogicalResult WorkshareLoopWrapperOp::verify() {
2207  if (!(*this)->getParentOfType<WorkshareOp>())
2208  return emitError() << "must be nested in an omp.workshare";
2209  if (getNestedWrapper())
2210  return emitError() << "cannot be composite";
2211  return success();
2212 }
2213 
2214 //===----------------------------------------------------------------------===//
2215 // LoopWrapperInterface
2216 //===----------------------------------------------------------------------===//
2217 
/// Verifies the structural requirements shared by all loop wrapper ops. The
/// op must carry the `NoTerminator` and `SingleBlock` traits and have exactly
/// one region. That region must contain exactly one operation, which is
/// either another loop wrapper or an `omp.loop_nest`.
LogicalResult LoopWrapperInterface::verifyImpl() {
  Operation *op = this->getOperation();
  // NOTE(review): the second operand of `||` below is not visible in this view
  // of the file; per the diagnostic it presumably checks the `SingleBlock`
  // trait — confirm.
  if (!op->hasTrait<OpTrait::NoTerminator>() ||
    return emitOpError() << "loop wrapper must also have the `NoTerminator` "
                            "and `SingleBlock` traits";

  if (op->getNumRegions() != 1)
    return emitOpError() << "loop wrapper does not contain exactly one region";

  // Count the operations in the region; exactly one is expected.
  Region &region = op->getRegion(0);
  if (range_size(region.getOps()) != 1)
    return emitOpError()
           << "loop wrapper does not contain exactly one nested op";

  // The single nested op is either another wrapper (composite constructs) or
  // the wrapped loop nest itself.
  Operation &firstOp = *region.op_begin();
  if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
    return emitOpError() << "op nested in loop wrapper is not another loop "
                            "wrapper or `omp.loop_nest`";

  return success();
}
2240 
2241 //===----------------------------------------------------------------------===//
2242 // LoopOp
2243 //===----------------------------------------------------------------------===//
2244 
2245 void LoopOp::build(OpBuilder &builder, OperationState &state,
2246  const LoopOperands &clauses) {
2247  MLIRContext *ctx = builder.getContext();
2248 
2249  LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2250  makeArrayAttr(ctx, clauses.privateSyms), clauses.order,
2251  clauses.orderMod, clauses.reductionMod, clauses.reductionVars,
2252  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2253  makeArrayAttr(ctx, clauses.reductionSyms));
2254 }
2255 
2256 LogicalResult LoopOp::verify() {
2257  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2258  getReductionByref());
2259 }
2260 
2261 LogicalResult LoopOp::verifyRegions() {
2262  if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2263  getNestedWrapper())
2264  return emitError() << "`omp.loop` expected to be a standalone loop wrapper";
2265 
2266  return success();
2267 }
2268 
2269 //===----------------------------------------------------------------------===//
2270 // WsloopOp
2271 //===----------------------------------------------------------------------===//
2272 
2273 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2274  ArrayRef<NamedAttribute> attributes) {
2275  build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2276  /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
2277  /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
2278  /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
2279  /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
2280  /*reduction_byref=*/nullptr,
2281  /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
2282  /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
2283  /*schedule_simd=*/false);
2284  state.addAttributes(attributes);
2285 }
2286 
2287 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2288  const WsloopOperands &clauses) {
2289  MLIRContext *ctx = builder.getContext();
2290  // TODO: Store clauses in op: allocateVars, allocatorVars, privateVars,
2291  // privateSyms.
2292  WsloopOp::build(builder, state,
2293  /*allocate_vars=*/{}, /*allocator_vars=*/{},
2294  clauses.linearVars, clauses.linearStepVars, clauses.nowait,
2295  clauses.order, clauses.orderMod, clauses.ordered,
2296  clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
2297  clauses.reductionMod, clauses.reductionVars,
2298  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2299  makeArrayAttr(ctx, clauses.reductionSyms),
2300  clauses.scheduleKind, clauses.scheduleChunk,
2301  clauses.scheduleMod, clauses.scheduleSimd);
2302 }
2303 
2304 LogicalResult WsloopOp::verify() {
2305  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2306  getReductionByref());
2307 }
2308 
2309 LogicalResult WsloopOp::verifyRegions() {
2310  bool isCompositeChildLeaf =
2311  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2312 
2313  if (LoopWrapperInterface nested = getNestedWrapper()) {
2314  if (!isComposite())
2315  return emitError()
2316  << "'omp.composite' attribute missing from composite wrapper";
2317 
2318  // Check for the allowed leaf constructs that may appear in a composite
2319  // construct directly after DO/FOR.
2320  if (!isa<SimdOp>(nested))
2321  return emitError() << "only supported nested wrapper is 'omp.simd'";
2322 
2323  } else if (isComposite() && !isCompositeChildLeaf) {
2324  return emitError()
2325  << "'omp.composite' attribute present in non-composite wrapper";
2326  } else if (!isComposite() && isCompositeChildLeaf) {
2327  return emitError()
2328  << "'omp.composite' attribute missing from composite wrapper";
2329  }
2330 
2331  return success();
2332 }
2333 
2334 //===----------------------------------------------------------------------===//
2335 // Simd construct [2.9.3.1]
2336 //===----------------------------------------------------------------------===//
2337 
2338 void SimdOp::build(OpBuilder &builder, OperationState &state,
2339  const SimdOperands &clauses) {
2340  MLIRContext *ctx = builder.getContext();
2341  // TODO Store clauses in op: linearVars, linearStepVars, privateVars,
2342  // privateSyms.
2343  SimdOp::build(builder, state, clauses.alignedVars,
2344  makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
2345  /*linear_vars=*/{}, /*linear_step_vars=*/{},
2346  clauses.nontemporalVars, clauses.order, clauses.orderMod,
2347  clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
2348  clauses.reductionMod, clauses.reductionVars,
2349  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2350  makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
2351  clauses.simdlen);
2352 }
2353 
2354 LogicalResult SimdOp::verify() {
2355  if (getSimdlen().has_value() && getSafelen().has_value() &&
2356  getSimdlen().value() > getSafelen().value())
2357  return emitOpError()
2358  << "simdlen clause and safelen clause are both present, but the "
2359  "simdlen value is not less than or equal to safelen value";
2360 
2361  if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
2362  return failure();
2363 
2364  if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
2365  return failure();
2366 
2367  bool isCompositeChildLeaf =
2368  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2369 
2370  if (!isComposite() && isCompositeChildLeaf)
2371  return emitError()
2372  << "'omp.composite' attribute missing from composite wrapper";
2373 
2374  if (isComposite() && !isCompositeChildLeaf)
2375  return emitError()
2376  << "'omp.composite' attribute present in non-composite wrapper";
2377 
2378  return success();
2379 }
2380 
2381 LogicalResult SimdOp::verifyRegions() {
2382  if (getNestedWrapper())
2383  return emitOpError() << "must wrap an 'omp.loop_nest' directly";
2384 
2385  return success();
2386 }
2387 
2388 //===----------------------------------------------------------------------===//
2389 // Distribute construct [2.9.4.1]
2390 //===----------------------------------------------------------------------===//
2391 
2392 void DistributeOp::build(OpBuilder &builder, OperationState &state,
2393  const DistributeOperands &clauses) {
2394  DistributeOp::build(builder, state, clauses.allocateVars,
2395  clauses.allocatorVars, clauses.distScheduleStatic,
2396  clauses.distScheduleChunkSize, clauses.order,
2397  clauses.orderMod, clauses.privateVars,
2398  makeArrayAttr(builder.getContext(), clauses.privateSyms));
2399 }
2400 
2401 LogicalResult DistributeOp::verify() {
2402  if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2403  return emitOpError() << "chunk size set without "
2404  "dist_schedule_static being present";
2405 
2406  if (getAllocateVars().size() != getAllocatorVars().size())
2407  return emitError(
2408  "expected equal sizes for allocate and allocator variables");
2409 
2410  return success();
2411 }
2412 
2413 LogicalResult DistributeOp::verifyRegions() {
2414  if (LoopWrapperInterface nested = getNestedWrapper()) {
2415  if (!isComposite())
2416  return emitError()
2417  << "'omp.composite' attribute missing from composite wrapper";
2418  // Check for the allowed leaf constructs that may appear in a composite
2419  // construct directly after DISTRIBUTE.
2420  if (isa<WsloopOp>(nested)) {
2421  if (!llvm::dyn_cast_if_present<ParallelOp>((*this)->getParentOp()))
2422  return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
2423  "when 'omp.parallel' is the direct parent";
2424  } else if (!isa<SimdOp>(nested))
2425  return emitError() << "only supported nested wrappers are 'omp.simd' and "
2426  "'omp.wsloop'";
2427  } else if (isComposite()) {
2428  return emitError()
2429  << "'omp.composite' attribute present in non-composite wrapper";
2430  }
2431 
2432  return success();
2433 }
2434 
2435 //===----------------------------------------------------------------------===//
2436 // DeclareReductionOp
2437 //===----------------------------------------------------------------------===//
2438 
2439 LogicalResult DeclareReductionOp::verifyRegions() {
2440  if (!getAllocRegion().empty()) {
2441  for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
2442  if (yieldOp.getResults().size() != 1 ||
2443  yieldOp.getResults().getTypes()[0] != getType())
2444  return emitOpError() << "expects alloc region to yield a value "
2445  "of the reduction type";
2446  }
2447  }
2448 
2449  if (getInitializerRegion().empty())
2450  return emitOpError() << "expects non-empty initializer region";
2451  Block &initializerEntryBlock = getInitializerRegion().front();
2452 
2453  if (initializerEntryBlock.getNumArguments() == 1) {
2454  if (!getAllocRegion().empty())
2455  return emitOpError() << "expects two arguments to the initializer region "
2456  "when an allocation region is used";
2457  } else if (initializerEntryBlock.getNumArguments() == 2) {
2458  if (getAllocRegion().empty())
2459  return emitOpError() << "expects one argument to the initializer region "
2460  "when no allocation region is used";
2461  } else {
2462  return emitOpError()
2463  << "expects one or two arguments to the initializer region";
2464  }
2465 
2466  for (mlir::Value arg : initializerEntryBlock.getArguments())
2467  if (arg.getType() != getType())
2468  return emitOpError() << "expects initializer region argument to match "
2469  "the reduction type";
2470 
2471  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
2472  if (yieldOp.getResults().size() != 1 ||
2473  yieldOp.getResults().getTypes()[0] != getType())
2474  return emitOpError() << "expects initializer region to yield a value "
2475  "of the reduction type";
2476  }
2477 
2478  if (getReductionRegion().empty())
2479  return emitOpError() << "expects non-empty reduction region";
2480  Block &reductionEntryBlock = getReductionRegion().front();
2481  if (reductionEntryBlock.getNumArguments() != 2 ||
2482  reductionEntryBlock.getArgumentTypes()[0] !=
2483  reductionEntryBlock.getArgumentTypes()[1] ||
2484  reductionEntryBlock.getArgumentTypes()[0] != getType())
2485  return emitOpError() << "expects reduction region with two arguments of "
2486  "the reduction type";
2487  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
2488  if (yieldOp.getResults().size() != 1 ||
2489  yieldOp.getResults().getTypes()[0] != getType())
2490  return emitOpError() << "expects reduction region to yield a value "
2491  "of the reduction type";
2492  }
2493 
2494  if (!getAtomicReductionRegion().empty()) {
2495  Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
2496  if (atomicReductionEntryBlock.getNumArguments() != 2 ||
2497  atomicReductionEntryBlock.getArgumentTypes()[0] !=
2498  atomicReductionEntryBlock.getArgumentTypes()[1])
2499  return emitOpError() << "expects atomic reduction region with two "
2500  "arguments of the same type";
2501  auto ptrType = llvm::dyn_cast<PointerLikeType>(
2502  atomicReductionEntryBlock.getArgumentTypes()[0]);
2503  if (!ptrType ||
2504  (ptrType.getElementType() && ptrType.getElementType() != getType()))
2505  return emitOpError() << "expects atomic reduction region arguments to "
2506  "be accumulators containing the reduction type";
2507  }
2508 
2509  if (getCleanupRegion().empty())
2510  return success();
2511  Block &cleanupEntryBlock = getCleanupRegion().front();
2512  if (cleanupEntryBlock.getNumArguments() != 1 ||
2513  cleanupEntryBlock.getArgument(0).getType() != getType())
2514  return emitOpError() << "expects cleanup region with one argument "
2515  "of the reduction type";
2516 
2517  return success();
2518 }
2519 
2520 //===----------------------------------------------------------------------===//
2521 // TaskOp
2522 //===----------------------------------------------------------------------===//
2523 
2524 void TaskOp::build(OpBuilder &builder, OperationState &state,
2525  const TaskOperands &clauses) {
2526  MLIRContext *ctx = builder.getContext();
2527  TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2528  makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2529  clauses.final, clauses.ifExpr, clauses.inReductionVars,
2530  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2531  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2532  clauses.priority, /*private_vars=*/clauses.privateVars,
2533  /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
2534  clauses.untied, clauses.eventHandle);
2535 }
2536 
2537 LogicalResult TaskOp::verify() {
2538  LogicalResult verifyDependVars =
2539  verifyDependVarList(*this, getDependKinds(), getDependVars());
2540  return failed(verifyDependVars)
2541  ? verifyDependVars
2542  : verifyReductionVarList(*this, getInReductionSyms(),
2543  getInReductionVars(),
2544  getInReductionByref());
2545 }
2546 
2547 //===----------------------------------------------------------------------===//
2548 // TaskgroupOp
2549 //===----------------------------------------------------------------------===//
2550 
2551 void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
2552  const TaskgroupOperands &clauses) {
2553  MLIRContext *ctx = builder.getContext();
2554  TaskgroupOp::build(builder, state, clauses.allocateVars,
2555  clauses.allocatorVars, clauses.taskReductionVars,
2556  makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
2557  makeArrayAttr(ctx, clauses.taskReductionSyms));
2558 }
2559 
2560 LogicalResult TaskgroupOp::verify() {
2561  return verifyReductionVarList(*this, getTaskReductionSyms(),
2562  getTaskReductionVars(),
2563  getTaskReductionByref());
2564 }
2565 
2566 //===----------------------------------------------------------------------===//
2567 // TaskloopOp
2568 //===----------------------------------------------------------------------===//
2569 
2570 void TaskloopOp::build(OpBuilder &builder, OperationState &state,
2571  const TaskloopOperands &clauses) {
2572  MLIRContext *ctx = builder.getContext();
2573  // TODO Store clauses in op: privateVars, privateSyms.
2574  TaskloopOp::build(
2575  builder, state, clauses.allocateVars, clauses.allocatorVars,
2576  clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
2577  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2578  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2579  clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
2580  /*private_syms=*/nullptr, clauses.reductionMod, clauses.reductionVars,
2581  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2582  makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
2583 }
2584 
2585 SmallVector<Value> TaskloopOp::getAllReductionVars() {
2586  SmallVector<Value> allReductionNvars(getInReductionVars().begin(),
2587  getInReductionVars().end());
2588  allReductionNvars.insert(allReductionNvars.end(), getReductionVars().begin(),
2589  getReductionVars().end());
2590  return allReductionNvars;
2591 }
2592 
2593 LogicalResult TaskloopOp::verify() {
2594  if (getAllocateVars().size() != getAllocatorVars().size())
2595  return emitError(
2596  "expected equal sizes for allocate and allocator variables");
2597  if (failed(verifyReductionVarList(*this, getReductionSyms(),
2598  getReductionVars(), getReductionByref())) ||
2599  failed(verifyReductionVarList(*this, getInReductionSyms(),
2600  getInReductionVars(),
2601  getInReductionByref())))
2602  return failure();
2603 
2604  if (!getReductionVars().empty() && getNogroup())
2605  return emitError("if a reduction clause is present on the taskloop "
2606  "directive, the nogroup clause must not be specified");
2607  for (auto var : getReductionVars()) {
2608  if (llvm::is_contained(getInReductionVars(), var))
2609  return emitError("the same list item cannot appear in both a reduction "
2610  "and an in_reduction clause");
2611  }
2612 
2613  if (getGrainsize() && getNumTasks()) {
2614  return emitError(
2615  "the grainsize clause and num_tasks clause are mutually exclusive and "
2616  "may not appear on the same taskloop directive");
2617  }
2618 
2619  return success();
2620 }
2621 
2622 LogicalResult TaskloopOp::verifyRegions() {
2623  if (LoopWrapperInterface nested = getNestedWrapper()) {
2624  if (!isComposite())
2625  return emitError()
2626  << "'omp.composite' attribute missing from composite wrapper";
2627 
2628  // Check for the allowed leaf constructs that may appear in a composite
2629  // construct directly after TASKLOOP.
2630  if (!isa<SimdOp>(nested))
2631  return emitError() << "only supported nested wrapper is 'omp.simd'";
2632  } else if (isComposite()) {
2633  return emitError()
2634  << "'omp.composite' attribute present in non-composite wrapper";
2635  }
2636 
2637  return success();
2638 }
2639 
2640 //===----------------------------------------------------------------------===//
2641 // LoopNestOp
2642 //===----------------------------------------------------------------------===//
2643 
2644 ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
2645  // Parse an opening `(` followed by induction variables followed by `)`
2648  Type loopVarType;
2649  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
2650  parser.parseColonType(loopVarType) ||
2651  // Parse loop bounds.
2652  parser.parseEqual() ||
2653  parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
2654  parser.parseKeyword("to") ||
2655  parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
2656  return failure();
2657 
2658  for (auto &iv : ivs)
2659  iv.type = loopVarType;
2660 
2661  // Parse "inclusive" flag.
2662  if (succeeded(parser.parseOptionalKeyword("inclusive")))
2663  result.addAttribute("loop_inclusive",
2664  UnitAttr::get(parser.getBuilder().getContext()));
2665 
2666  // Parse step values.
2668  if (parser.parseKeyword("step") ||
2669  parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
2670  return failure();
2671 
2672  // Parse the body.
2673  Region *region = result.addRegion();
2674  if (parser.parseRegion(*region, ivs))
2675  return failure();
2676 
2677  // Resolve operands.
2678  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
2679  parser.resolveOperands(ubs, loopVarType, result.operands) ||
2680  parser.resolveOperands(steps, loopVarType, result.operands))
2681  return failure();
2682 
2683  // Parse the optional attribute list.
2684  return parser.parseOptionalAttrDict(result.attributes);
2685 }
2686 
2688  Region &region = getRegion();
2689  auto args = region.getArguments();
2690  p << " (" << args << ") : " << args[0].getType() << " = ("
2691  << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
2692  if (getLoopInclusive())
2693  p << "inclusive ";
2694  p << "step (" << getLoopSteps() << ") ";
2695  p.printRegion(region, /*printEntryBlockArgs=*/false);
2696 }
2697 
2698 void LoopNestOp::build(OpBuilder &builder, OperationState &state,
2699  const LoopNestOperands &clauses) {
2700  LoopNestOp::build(builder, state, clauses.loopLowerBounds,
2701  clauses.loopUpperBounds, clauses.loopSteps,
2702  clauses.loopInclusive);
2703 }
2704 
2705 LogicalResult LoopNestOp::verify() {
2706  if (getLoopLowerBounds().empty())
2707  return emitOpError() << "must represent at least one loop";
2708 
2709  if (getLoopLowerBounds().size() != getIVs().size())
2710  return emitOpError() << "number of range arguments and IVs do not match";
2711 
2712  for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
2713  if (lb.getType() != iv.getType())
2714  return emitOpError()
2715  << "range argument type does not match corresponding IV type";
2716  }
2717 
2718  if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
2719  return emitOpError() << "expects parent op to be a loop wrapper";
2720 
2721  return success();
2722 }
2723 
2724 void LoopNestOp::gatherWrappers(
2726  Operation *parent = (*this)->getParentOp();
2727  while (auto wrapper =
2728  llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
2729  wrappers.push_back(wrapper);
2730  parent = parent->getParentOp();
2731  }
2732 }
2733 
2734 //===----------------------------------------------------------------------===//
2735 // Critical construct (2.17.1)
2736 //===----------------------------------------------------------------------===//
2737 
2738 void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
2739  const CriticalDeclareOperands &clauses) {
2740  CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
2741 }
2742 
2743 LogicalResult CriticalDeclareOp::verify() {
2744  return verifySynchronizationHint(*this, getHint());
2745 }
2746 
2747 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2748  if (getNameAttr()) {
2749  SymbolRefAttr symbolRef = getNameAttr();
2750  auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
2751  *this, symbolRef);
2752  if (!decl) {
2753  return emitOpError() << "expected symbol reference " << symbolRef
2754  << " to point to a critical declaration";
2755  }
2756  }
2757 
2758  return success();
2759 }
2760 
2761 //===----------------------------------------------------------------------===//
2762 // Ordered construct
2763 //===----------------------------------------------------------------------===//
2764 
2765 static LogicalResult verifyOrderedParent(Operation &op) {
2766  bool hasRegion = op.getNumRegions() > 0;
2767  auto loopOp = op.getParentOfType<LoopNestOp>();
2768  if (!loopOp) {
2769  if (hasRegion)
2770  return success();
2771 
2772  // TODO: Consider if this needs to be the case only for the standalone
2773  // variant of the ordered construct.
2774  return op.emitOpError() << "must be nested inside of a loop";
2775  }
2776 
2777  Operation *wrapper = loopOp->getParentOp();
2778  if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
2779  IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
2780  if (!orderedAttr)
2781  return op.emitOpError() << "the enclosing worksharing-loop region must "
2782  "have an ordered clause";
2783 
2784  if (hasRegion && orderedAttr.getInt() != 0)
2785  return op.emitOpError() << "the enclosing loop's ordered clause must not "
2786  "have a parameter present";
2787 
2788  if (!hasRegion && orderedAttr.getInt() == 0)
2789  return op.emitOpError() << "the enclosing loop's ordered clause must "
2790  "have a parameter present";
2791  } else if (!isa<SimdOp>(wrapper)) {
2792  return op.emitOpError() << "must be nested inside of a worksharing, simd "
2793  "or worksharing simd loop";
2794  }
2795  return success();
2796 }
2797 
2798 void OrderedOp::build(OpBuilder &builder, OperationState &state,
2799  const OrderedOperands &clauses) {
2800  OrderedOp::build(builder, state, clauses.doacrossDependType,
2801  clauses.doacrossNumLoops, clauses.doacrossDependVars);
2802 }
2803 
2804 LogicalResult OrderedOp::verify() {
2805  if (failed(verifyOrderedParent(**this)))
2806  return failure();
2807 
2808  auto wrapper = (*this)->getParentOfType<WsloopOp>();
2809  if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
2810  return emitOpError() << "number of variables in depend clause does not "
2811  << "match number of iteration variables in the "
2812  << "doacross loop";
2813 
2814  return success();
2815 }
2816 
2817 void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
2818  const OrderedRegionOperands &clauses) {
2819  OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
2820 }
2821 
2822 LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
2823 
2824 //===----------------------------------------------------------------------===//
2825 // TaskwaitOp
2826 //===----------------------------------------------------------------------===//
2827 
2828 void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
2829  const TaskwaitOperands &clauses) {
2830  // TODO Store clauses in op: dependKinds, dependVars, nowait.
2831  TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
2832  /*depend_vars=*/{}, /*nowait=*/nullptr);
2833 }
2834 
2835 //===----------------------------------------------------------------------===//
2836 // Verifier for AtomicReadOp
2837 //===----------------------------------------------------------------------===//
2838 
2839 LogicalResult AtomicReadOp::verify() {
2840  if (verifyCommon().failed())
2841  return mlir::failure();
2842 
2843  if (auto mo = getMemoryOrder()) {
2844  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2845  *mo == ClauseMemoryOrderKind::Release) {
2846  return emitError(
2847  "memory-order must not be acq_rel or release for atomic reads");
2848  }
2849  }
2850  return verifySynchronizationHint(*this, getHint());
2851 }
2852 
2853 //===----------------------------------------------------------------------===//
2854 // Verifier for AtomicWriteOp
2855 //===----------------------------------------------------------------------===//
2856 
2857 LogicalResult AtomicWriteOp::verify() {
2858  if (verifyCommon().failed())
2859  return mlir::failure();
2860 
2861  if (auto mo = getMemoryOrder()) {
2862  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2863  *mo == ClauseMemoryOrderKind::Acquire) {
2864  return emitError(
2865  "memory-order must not be acq_rel or acquire for atomic writes");
2866  }
2867  }
2868  return verifySynchronizationHint(*this, getHint());
2869 }
2870 
2871 //===----------------------------------------------------------------------===//
2872 // Verifier for AtomicUpdateOp
2873 //===----------------------------------------------------------------------===//
2874 
2875 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2876  PatternRewriter &rewriter) {
2877  if (op.isNoOp()) {
2878  rewriter.eraseOp(op);
2879  return success();
2880  }
2881  if (Value writeVal = op.getWriteOpVal()) {
2882  rewriter.replaceOpWithNewOp<AtomicWriteOp>(
2883  op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
2884  return success();
2885  }
2886  return failure();
2887 }
2888 
2889 LogicalResult AtomicUpdateOp::verify() {
2890  if (verifyCommon().failed())
2891  return mlir::failure();
2892 
2893  if (auto mo = getMemoryOrder()) {
2894  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2895  *mo == ClauseMemoryOrderKind::Acquire) {
2896  return emitError(
2897  "memory-order must not be acq_rel or acquire for atomic updates");
2898  }
2899  }
2900 
2901  return verifySynchronizationHint(*this, getHint());
2902 }
2903 
2904 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
2905 
2906 //===----------------------------------------------------------------------===//
2907 // Verifier for AtomicCaptureOp
2908 //===----------------------------------------------------------------------===//
2909 
2910 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2911  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2912  return op;
2913  return dyn_cast<AtomicReadOp>(getSecondOp());
2914 }
2915 
2916 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2917  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2918  return op;
2919  return dyn_cast<AtomicWriteOp>(getSecondOp());
2920 }
2921 
2922 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2923  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2924  return op;
2925  return dyn_cast<AtomicUpdateOp>(getSecondOp());
2926 }
2927 
2928 LogicalResult AtomicCaptureOp::verify() {
2929  return verifySynchronizationHint(*this, getHint());
2930 }
2931 
2932 LogicalResult AtomicCaptureOp::verifyRegions() {
2933  if (verifyRegionsCommon().failed())
2934  return mlir::failure();
2935 
2936  if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
2937  return emitOpError(
2938  "operations inside capture region must not have hint clause");
2939 
2940  if (getFirstOp()->getAttr("memory_order") ||
2941  getSecondOp()->getAttr("memory_order"))
2942  return emitOpError(
2943  "operations inside capture region must not have memory_order clause");
2944  return success();
2945 }
2946 
2947 //===----------------------------------------------------------------------===//
2948 // CancelOp
2949 //===----------------------------------------------------------------------===//
2950 
2951 void CancelOp::build(OpBuilder &builder, OperationState &state,
2952  const CancelOperands &clauses) {
2953  CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
2954 }
2955 
2956 LogicalResult CancelOp::verify() {
2957  ClauseCancellationConstructType cct = getCancelDirective();
2958  Operation *parentOp = (*this)->getParentOp();
2959 
2960  if (!parentOp) {
2961  return emitOpError() << "must be used within a region supporting "
2962  "cancel directive";
2963  }
2964 
2965  if ((cct == ClauseCancellationConstructType::Parallel) &&
2966  !isa<ParallelOp>(parentOp)) {
2967  return emitOpError() << "cancel parallel must appear "
2968  << "inside a parallel region";
2969  }
2970  if (cct == ClauseCancellationConstructType::Loop) {
2971  auto loopOp = dyn_cast<LoopNestOp>(parentOp);
2972  auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
2973  loopOp ? loopOp->getParentOp() : nullptr);
2974 
2975  if (!wsloopOp) {
2976  return emitOpError()
2977  << "cancel loop must appear inside a worksharing-loop region";
2978  }
2979  if (wsloopOp.getNowaitAttr()) {
2980  return emitError() << "A worksharing construct that is canceled "
2981  << "must not have a nowait clause";
2982  }
2983  if (wsloopOp.getOrderedAttr()) {
2984  return emitError() << "A worksharing construct that is canceled "
2985  << "must not have an ordered clause";
2986  }
2987 
2988  } else if (cct == ClauseCancellationConstructType::Sections) {
2989  if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
2990  return emitOpError() << "cancel sections must appear "
2991  << "inside a sections region";
2992  }
2993  if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
2994  cast<SectionsOp>(parentOp->getParentOp()).getNowaitAttr()) {
2995  return emitError() << "A sections construct that is canceled "
2996  << "must not have a nowait clause";
2997  }
2998  }
2999  // TODO : Add more when we support taskgroup.
3000  return success();
3001 }
3002 
3003 //===----------------------------------------------------------------------===//
3004 // CancellationPointOp
3005 //===----------------------------------------------------------------------===//
3006 
3007 void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
3008  const CancellationPointOperands &clauses) {
3009  CancellationPointOp::build(builder, state, clauses.cancelDirective);
3010 }
3011 
3012 LogicalResult CancellationPointOp::verify() {
3013  ClauseCancellationConstructType cct = getCancelDirective();
3014  Operation *parentOp = (*this)->getParentOp();
3015 
3016  if (!parentOp) {
3017  return emitOpError() << "must be used within a region supporting "
3018  "cancellation point directive";
3019  }
3020 
3021  if ((cct == ClauseCancellationConstructType::Parallel) &&
3022  !(isa<ParallelOp>(parentOp))) {
3023  return emitOpError() << "cancellation point parallel must appear "
3024  << "inside a parallel region";
3025  }
3026  if ((cct == ClauseCancellationConstructType::Loop) &&
3027  (!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->getParentOp()))) {
3028  return emitOpError() << "cancellation point loop must appear "
3029  << "inside a worksharing-loop region";
3030  }
3031  if ((cct == ClauseCancellationConstructType::Sections) &&
3032  !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
3033  return emitOpError() << "cancellation point sections must appear "
3034  << "inside a sections region";
3035  }
3036  // TODO : Add more when we support taskgroup.
3037  return success();
3038 }
3039 
3040 //===----------------------------------------------------------------------===//
3041 // MapBoundsOp
3042 //===----------------------------------------------------------------------===//
3043 
3044 LogicalResult MapBoundsOp::verify() {
3045  auto extent = getExtent();
3046  auto upperbound = getUpperBound();
3047  if (!extent && !upperbound)
3048  return emitError("expected extent or upperbound.");
3049  return success();
3050 }
3051 
3052 void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3053  TypeRange /*result_types*/, StringAttr symName,
3054  TypeAttr type) {
3055  PrivateClauseOp::build(
3056  odsBuilder, odsState, symName, type,
3058  DataSharingClauseType::Private));
3059 }
3060 
3061 LogicalResult PrivateClauseOp::verifyRegions() {
3062  Type argType = getArgType();
3063  auto verifyTerminator = [&](Operation *terminator,
3064  bool yieldsValue) -> LogicalResult {
3065  if (!terminator->getBlock()->getSuccessors().empty())
3066  return success();
3067 
3068  if (!llvm::isa<YieldOp>(terminator))
3069  return mlir::emitError(terminator->getLoc())
3070  << "expected exit block terminator to be an `omp.yield` op.";
3071 
3072  YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
3073  TypeRange yieldedTypes = yieldOp.getResults().getTypes();
3074 
3075  if (!yieldsValue) {
3076  if (yieldedTypes.empty())
3077  return success();
3078 
3079  return mlir::emitError(terminator->getLoc())
3080  << "Did not expect any values to be yielded.";
3081  }
3082 
3083  if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
3084  return success();
3085 
3086  auto error = mlir::emitError(yieldOp.getLoc())
3087  << "Invalid yielded value. Expected type: " << argType
3088  << ", got: ";
3089 
3090  if (yieldedTypes.empty())
3091  error << "None";
3092  else
3093  error << yieldedTypes;
3094 
3095  return error;
3096  };
3097 
3098  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
3099  StringRef regionName,
3100  bool yieldsValue) -> LogicalResult {
3101  assert(!region.empty());
3102 
3103  if (region.getNumArguments() != expectedNumArgs)
3104  return mlir::emitError(region.getLoc())
3105  << "`" << regionName << "`: "
3106  << "expected " << expectedNumArgs
3107  << " region arguments, got: " << region.getNumArguments();
3108 
3109  for (Block &block : region) {
3110  // MLIR will verify the absence of the terminator for us.
3111  if (!block.mightHaveTerminator())
3112  continue;
3113 
3114  if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
3115  return failure();
3116  }
3117 
3118  return success();
3119  };
3120 
3121  // Ensure all of the region arguments have the same type
3122  for (Region *region : getRegions())
3123  for (Type ty : region->getArgumentTypes())
3124  if (ty != argType)
3125  return emitError() << "Region argument type mismatch: got " << ty
3126  << " expected " << argType << ".";
3127 
3128  mlir::Region &initRegion = getInitRegion();
3129  if (!initRegion.empty() &&
3130  failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
3131  /*yieldsValue=*/true)))
3132  return failure();
3133 
3134  DataSharingClauseType dsType = getDataSharingType();
3135 
3136  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
3137  return emitError("`private` clauses do not require a `copy` region.");
3138 
3139  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
3140  return emitError(
3141  "`firstprivate` clauses require at least a `copy` region.");
3142 
3143  if (dsType == DataSharingClauseType::FirstPrivate &&
3144  failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
3145  /*yieldsValue=*/true)))
3146  return failure();
3147 
3148  if (!getDeallocRegion().empty() &&
3149  failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
3150  /*yieldsValue=*/false)))
3151  return failure();
3152 
3153  return success();
3154 }
3155 
3156 //===----------------------------------------------------------------------===//
3157 // Spec 5.2: Masked construct (10.5)
3158 //===----------------------------------------------------------------------===//
3159 
3160 void MaskedOp::build(OpBuilder &builder, OperationState &state,
3161  const MaskedOperands &clauses) {
3162  MaskedOp::build(builder, state, clauses.filteredThreadId);
3163 }
3164 
3165 //===----------------------------------------------------------------------===//
3166 // Spec 5.2: Scan construct (5.6)
3167 //===----------------------------------------------------------------------===//
3168 
3169 void ScanOp::build(OpBuilder &builder, OperationState &state,
3170  const ScanOperands &clauses) {
3171  ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
3172 }
3173 
3174 LogicalResult ScanOp::verify() {
3175  if (hasExclusiveVars() == hasInclusiveVars())
3176  return emitError(
3177  "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
3178  if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
3179  if (parentWsLoopOp.getReductionModAttr() &&
3180  parentWsLoopOp.getReductionModAttr().getValue() ==
3181  ReductionModifier::inscan)
3182  return success();
3183  }
3184  if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
3185  if (parentSimdOp.getReductionModAttr() &&
3186  parentSimdOp.getReductionModAttr().getValue() ==
3187  ReductionModifier::inscan)
3188  return success();
3189  }
3190  return emitError("SCAN directive needs to be enclosed within a parent "
3191  "worksharing loop construct or SIMD construct with INSCAN "
3192  "reduction modifier");
3193 }
3194 
3195 #define GET_ATTRDEF_CLASSES
3196 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
3197 
3198 #define GET_OP_CLASSES
3199 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
3200 
3201 #define GET_TYPEDEF_CLASSES
3202 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:726
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1191
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:63
static MLIRContext * getContext(OpFoldResult val)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr)
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static bool siblingAllowedInCapture(Operation *op)
Only allow OpenMP terminators and non-OpenMP ops that have known memory effects, but don't include a ...
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseMembersIndex(OpAsmParser &parser, ArrayAttr &membersIdx)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static void printHostEvalInReductionMapPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, DenseI64ArrayAttr privateMaps)
static ParseResult parseHostEvalInReductionMapPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, DenseI64ArrayAttr &privateMaps)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token. This always succeeds.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
SuccessorRange getSuccessors()
Definition: Block.h:267
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
MLIRContext * getContext() const
Definition: Builders.h:56
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e. if A and B are the same operation or A properly dominates B.
Definition: Dominance.h:158
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the list on success.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:205
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:765
This class indicates that the regions associated with this op don't have terminators.
Definition: OpDefinition.h:761
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
Definition: Operation.h:750
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition: Operation.h:220
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:874
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
BlockArgListType getArguments()
Definition: Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Definition: Dominance.cpp:307
Runtime
Potential runtimes for AMD GPU kernels.
Definition: Runtimes.h:15
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
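Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations. On error, this reports the error through the MLIRContext and returns failure.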
Definition: Verifier.cpp:425
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
Definition: OpDefinition.h:872
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.