MLIR  21.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Attributes.h"
24 
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/BitVector.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/STLForwardCompat.h"
29 #include "llvm/ADT/SmallString.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Frontend/OpenMP/OMPConstants.h"
34 #include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
35 #include <cstddef>
36 #include <iterator>
37 #include <optional>
38 #include <variant>
39 
40 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
41 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
42 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
43 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
44 
45 using namespace mlir;
46 using namespace mlir::omp;
47 
48 static ArrayAttr makeArrayAttr(MLIRContext *context,
50  return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
51 }
52 
53 static DenseBoolArrayAttr
55  return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
56 }
57 
58 namespace {
59 struct MemRefPointerLikeModel
60  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
61  MemRefType> {
62  Type getElementType(Type pointer) const {
63  return llvm::cast<MemRefType>(pointer).getElementType();
64  }
65 };
66 
67 struct LLVMPointerPointerLikeModel
68  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
69  LLVM::LLVMPointerType> {
70  Type getElementType(Type pointer) const { return Type(); }
71 };
72 } // namespace
73 
74 void OpenMPDialect::initialize() {
75  addOperations<
76 #define GET_OP_LIST
77 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
78  >();
79  addAttributes<
80 #define GET_ATTRDEF_LIST
81 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
82  >();
83  addTypes<
84 #define GET_TYPEDEF_LIST
85 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
86  >();
87 
88  declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
89 
90  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
91  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
92  *getContext());
93 
  // Attach the default offload module interface to the module op so that
  // offload functionality can be accessed through it.
96  mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
97  *getContext());
98 
  // Attach default declare target interfaces to operations which can be marked
  // as declare target (global operations and functions/subroutines in dialects
  // that Fortran, or other languages that lower to MLIR, translate to).
102  mlir::LLVM::GlobalOp::attachInterface<
104  *getContext());
105  mlir::LLVM::LLVMFuncOp::attachInterface<
107  *getContext());
108  mlir::func::FuncOp::attachInterface<
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // Parser and printer for Allocate Clause
114 //===----------------------------------------------------------------------===//
115 
116 /// Parse an allocate clause with allocators and a list of operands with types.
117 ///
118 /// allocate-operand-list :: = allocate-operand |
119 /// allocator-operand `,` allocate-operand-list
120 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
121 /// ssa-id-and-type ::= ssa-id `:` type
122 static ParseResult parseAllocateAndAllocator(
123  OpAsmParser &parser,
125  SmallVectorImpl<Type> &allocateTypes,
127  SmallVectorImpl<Type> &allocatorTypes) {
128 
129  return parser.parseCommaSeparatedList([&]() {
131  Type type;
132  if (parser.parseOperand(operand) || parser.parseColonType(type))
133  return failure();
134  allocatorVars.push_back(operand);
135  allocatorTypes.push_back(type);
136  if (parser.parseArrow())
137  return failure();
138  if (parser.parseOperand(operand) || parser.parseColonType(type))
139  return failure();
140 
141  allocateVars.push_back(operand);
142  allocateTypes.push_back(type);
143  return success();
144  });
145 }
146 
147 /// Print allocate clause
149  OperandRange allocateVars,
150  TypeRange allocateTypes,
151  OperandRange allocatorVars,
152  TypeRange allocatorTypes) {
153  for (unsigned i = 0; i < allocateVars.size(); ++i) {
154  std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
155  p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
156  p << allocateVars[i] << " : " << allocateTypes[i] << separator;
157  }
158 }
159 
160 //===----------------------------------------------------------------------===//
161 // Parser and printer for a clause attribute (StringEnumAttr)
162 //===----------------------------------------------------------------------===//
163 
164 template <typename ClauseAttr>
165 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
166  using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
167  StringRef enumStr;
168  SMLoc loc = parser.getCurrentLocation();
169  if (parser.parseKeyword(&enumStr))
170  return failure();
171  if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
172  attr = ClauseAttr::get(parser.getContext(), *enumValue);
173  return success();
174  }
175  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
176 }
177 
178 template <typename ClauseAttr>
179 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
180  p << stringifyEnum(attr.getValue());
181 }
182 
183 //===----------------------------------------------------------------------===//
184 // Parser and printer for Linear Clause
185 //===----------------------------------------------------------------------===//
186 
/// linear ::= `linear` `(` linear-list `)`
/// linear-list := linear-val | linear-val `,` linear-list
/// linear-val := ssa-id `=` ssa-id `:` type
190 static ParseResult parseLinearClause(
191  OpAsmParser &parser,
193  SmallVectorImpl<Type> &linearTypes,
195  return parser.parseCommaSeparatedList([&]() {
197  Type type;
199  if (parser.parseOperand(var) || parser.parseEqual() ||
200  parser.parseOperand(stepVar) || parser.parseColonType(type))
201  return failure();
202 
203  linearVars.push_back(var);
204  linearTypes.push_back(type);
205  linearStepVars.push_back(stepVar);
206  return success();
207  });
208 }
209 
210 /// Print Linear Clause
212  ValueRange linearVars, TypeRange linearTypes,
213  ValueRange linearStepVars) {
214  size_t linearVarsSize = linearVars.size();
215  for (unsigned i = 0; i < linearVarsSize; ++i) {
216  std::string separator = i == linearVarsSize - 1 ? "" : ", ";
217  p << linearVars[i];
218  if (linearStepVars.size() > i)
219  p << " = " << linearStepVars[i];
220  p << " : " << linearVars[i].getType() << separator;
221  }
222 }
223 
224 //===----------------------------------------------------------------------===//
225 // Verifier for Nontemporal Clause
226 //===----------------------------------------------------------------------===//
227 
228 static LogicalResult verifyNontemporalClause(Operation *op,
229  OperandRange nontemporalVars) {
230 
231  // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
232  DenseSet<Value> nontemporalItems;
233  for (const auto &it : nontemporalVars)
234  if (!nontemporalItems.insert(it).second)
235  return op->emitOpError() << "nontemporal variable used more than once";
236 
237  return success();
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // Parser, verifier and printer for Aligned Clause
242 //===----------------------------------------------------------------------===//
243 static LogicalResult verifyAlignedClause(Operation *op,
244  std::optional<ArrayAttr> alignments,
245  OperandRange alignedVars) {
246  // Check if number of alignment values equals to number of aligned variables
247  if (!alignedVars.empty()) {
248  if (!alignments || alignments->size() != alignedVars.size())
249  return op->emitOpError()
250  << "expected as many alignment values as aligned variables";
251  } else {
252  if (alignments)
253  return op->emitOpError() << "unexpected alignment values attribute";
254  return success();
255  }
256 
257  // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
258  DenseSet<Value> alignedItems;
259  for (auto it : alignedVars)
260  if (!alignedItems.insert(it).second)
261  return op->emitOpError() << "aligned variable used more than once";
262 
263  if (!alignments)
264  return success();
265 
266  // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
267  for (unsigned i = 0; i < (*alignments).size(); ++i) {
268  if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
269  if (intAttr.getValue().sle(0))
270  return op->emitOpError() << "alignment should be greater than 0";
271  } else {
272  return op->emitOpError() << "expected integer alignment";
273  }
274  }
275 
276  return success();
277 }
278 
/// aligned ::= `aligned` `(` aligned-list `)`
/// aligned-list := aligned-val | aligned-val `,` aligned-list
/// aligned-val := ssa-id-and-type `->` alignment
282 static ParseResult
285  SmallVectorImpl<Type> &alignedTypes,
286  ArrayAttr &alignmentsAttr) {
287  SmallVector<Attribute> alignmentVec;
288  if (failed(parser.parseCommaSeparatedList([&]() {
289  if (parser.parseOperand(alignedVars.emplace_back()) ||
290  parser.parseColonType(alignedTypes.emplace_back()) ||
291  parser.parseArrow() ||
292  parser.parseAttribute(alignmentVec.emplace_back())) {
293  return failure();
294  }
295  return success();
296  })))
297  return failure();
298  SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
299  alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
300  return success();
301 }
302 
303 /// Print Aligned Clause
305  ValueRange alignedVars, TypeRange alignedTypes,
306  std::optional<ArrayAttr> alignments) {
307  for (unsigned i = 0; i < alignedVars.size(); ++i) {
308  if (i != 0)
309  p << ", ";
310  p << alignedVars[i] << " : " << alignedVars[i].getType();
311  p << " -> " << (*alignments)[i];
312  }
313 }
314 
315 //===----------------------------------------------------------------------===//
316 // Parser, printer and verifier for Schedule Clause
317 //===----------------------------------------------------------------------===//
318 
319 static ParseResult
321  SmallVectorImpl<SmallString<12>> &modifiers) {
322  if (modifiers.size() > 2)
323  return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
324  for (const auto &mod : modifiers) {
325  // Translate the string. If it has no value, then it was not a valid
326  // modifier!
327  auto symbol = symbolizeScheduleModifier(mod);
328  if (!symbol)
329  return parser.emitError(parser.getNameLoc())
330  << " unknown modifier type: " << mod;
331  }
332 
  // If we have one modifier that is "simd", then stick a "none" modifier in
  // index 0.
335  if (modifiers.size() == 1) {
336  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
337  modifiers.push_back(modifiers[0]);
338  modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
339  }
340  } else if (modifiers.size() == 2) {
341  // If there are two modifier:
342  // First modifier should not be simd, second one should be simd
343  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
344  symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
345  return parser.emitError(parser.getNameLoc())
346  << " incorrect modifier order";
347  }
348  return success();
349 }
350 
351 /// schedule ::= `schedule` `(` sched-list `)`
352 /// sched-list ::= sched-val | sched-val sched-list |
353 /// sched-val `,` sched-modifier
354 /// sched-val ::= sched-with-chunk | sched-wo-chunk
355 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
356 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
357 /// sched-wo-chunk ::= `auto` | `runtime`
358 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
359 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
360 static ParseResult
361 parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
362  ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
363  std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
364  Type &chunkType) {
365  StringRef keyword;
366  if (parser.parseKeyword(&keyword))
367  return failure();
368  std::optional<mlir::omp::ClauseScheduleKind> schedule =
369  symbolizeClauseScheduleKind(keyword);
370  if (!schedule)
371  return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
372 
373  scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
374  switch (*schedule) {
375  case ClauseScheduleKind::Static:
376  case ClauseScheduleKind::Dynamic:
377  case ClauseScheduleKind::Guided:
378  if (succeeded(parser.parseOptionalEqual())) {
379  chunkSize = OpAsmParser::UnresolvedOperand{};
380  if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
381  return failure();
382  } else {
383  chunkSize = std::nullopt;
384  }
385  break;
386  case ClauseScheduleKind::Auto:
388  chunkSize = std::nullopt;
389  }
390 
  // If there is a comma, we have one or more modifiers.
392  SmallVector<SmallString<12>> modifiers;
393  while (succeeded(parser.parseOptionalComma())) {
394  StringRef mod;
395  if (parser.parseKeyword(&mod))
396  return failure();
397  modifiers.push_back(mod);
398  }
399 
400  if (verifyScheduleModifiers(parser, modifiers))
401  return failure();
402 
403  if (!modifiers.empty()) {
404  SMLoc loc = parser.getCurrentLocation();
405  if (std::optional<ScheduleModifier> mod =
406  symbolizeScheduleModifier(modifiers[0])) {
407  scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
408  } else {
409  return parser.emitError(loc, "invalid schedule modifier");
410  }
411  // Only SIMD attribute is allowed here!
412  if (modifiers.size() > 1) {
413  assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
414  scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
415  }
416  }
417 
418  return success();
419 }
420 
421 /// Print schedule clause
423  ClauseScheduleKindAttr scheduleKind,
424  ScheduleModifierAttr scheduleMod,
425  UnitAttr scheduleSimd, Value scheduleChunk,
426  Type scheduleChunkType) {
427  p << stringifyClauseScheduleKind(scheduleKind.getValue());
428  if (scheduleChunk)
429  p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
430  if (scheduleMod)
431  p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
432  if (scheduleSimd)
433  p << ", simd";
434 }
435 
436 //===----------------------------------------------------------------------===//
437 // Parser and printer for Order Clause
438 //===----------------------------------------------------------------------===//
439 
440 // order ::= `order` `(` [order-modifier ':'] concurrent `)`
441 // order-modifier ::= reproducible | unconstrained
442 static ParseResult parseOrderClause(OpAsmParser &parser,
443  ClauseOrderKindAttr &order,
444  OrderModifierAttr &orderMod) {
445  StringRef enumStr;
446  SMLoc loc = parser.getCurrentLocation();
447  if (parser.parseKeyword(&enumStr))
448  return failure();
449  if (std::optional<OrderModifier> enumValue =
450  symbolizeOrderModifier(enumStr)) {
451  orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
452  if (parser.parseOptionalColon())
453  return failure();
454  loc = parser.getCurrentLocation();
455  if (parser.parseKeyword(&enumStr))
456  return failure();
457  }
458  if (std::optional<ClauseOrderKind> enumValue =
459  symbolizeClauseOrderKind(enumStr)) {
460  order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
461  return success();
462  }
463  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
464 }
465 
467  ClauseOrderKindAttr order,
468  OrderModifierAttr orderMod) {
469  if (orderMod)
470  p << stringifyOrderModifier(orderMod.getValue()) << ":";
471  if (order)
472  p << stringifyClauseOrderKind(order.getValue());
473 }
474 
475 template <typename ClauseTypeAttr, typename ClauseType>
476 static ParseResult
477 parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
478  std::optional<OpAsmParser::UnresolvedOperand> &operand,
479  Type &operandType,
480  std::optional<ClauseType> (*symbolizeClause)(StringRef),
481  StringRef clauseName) {
482  StringRef enumStr;
483  if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
484  if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
485  prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
486  if (parser.parseComma())
487  return failure();
488  } else {
489  return parser.emitError(parser.getCurrentLocation())
490  << "invalid " << clauseName << " modifier : '" << enumStr << "'";
491  ;
492  }
493  }
494 
496  if (succeeded(parser.parseOperand(var))) {
497  operand = var;
498  } else {
499  return parser.emitError(parser.getCurrentLocation())
500  << "expected " << clauseName << " operand";
501  }
502 
503  if (operand.has_value()) {
504  if (parser.parseColonType(operandType))
505  return failure();
506  }
507 
508  return success();
509 }
510 
511 template <typename ClauseTypeAttr, typename ClauseType>
512 static void
514  ClauseTypeAttr prescriptiveness, Value operand,
515  mlir::Type operandType,
516  StringRef (*stringifyClauseType)(ClauseType)) {
517 
518  if (prescriptiveness)
519  p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
520 
521  if (operand)
522  p << operand << ": " << operandType;
523 }
524 
525 //===----------------------------------------------------------------------===//
526 // Parser and printer for grainsize Clause
527 //===----------------------------------------------------------------------===//
528 
529 // grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
530 static ParseResult
531 parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
532  std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
533  Type &grainsizeType) {
534  return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
535  parser, grainsizeMod, grainsize, grainsizeType,
536  &symbolizeClauseGrainsizeType, "grainsize");
537 }
538 
540  ClauseGrainsizeTypeAttr grainsizeMod,
541  Value grainsize, mlir::Type grainsizeType) {
542  printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
543  p, op, grainsizeMod, grainsize, grainsizeType,
544  &stringifyClauseGrainsizeType);
545 }
546 
547 //===----------------------------------------------------------------------===//
548 // Parser and printer for num_tasks Clause
549 //===----------------------------------------------------------------------===//
550 
551 // numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
552 static ParseResult
553 parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
554  std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
555  Type &numTasksType) {
556  return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
557  parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
558  "num_tasks");
559 }
560 
562  ClauseNumTasksTypeAttr numTasksMod,
563  Value numTasks, mlir::Type numTasksType) {
564  printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
565  p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
566 }
567 
568 //===----------------------------------------------------------------------===//
569 // Parsers for operations including clauses that define entry block arguments.
570 //===----------------------------------------------------------------------===//
571 
572 namespace {
573 struct MapParseArgs {
575  SmallVectorImpl<Type> &types;
577  SmallVectorImpl<Type> &types)
578  : vars(vars), types(types) {}
579 };
580 struct PrivateParseArgs {
583  ArrayAttr &syms;
584  DenseI64ArrayAttr *mapIndices;
586  SmallVectorImpl<Type> &types, ArrayAttr &syms,
587  DenseI64ArrayAttr *mapIndices = nullptr)
588  : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
589 };
590 
591 struct ReductionParseArgs {
593  SmallVectorImpl<Type> &types;
594  DenseBoolArrayAttr &byref;
595  ArrayAttr &syms;
596  ReductionModifierAttr *modifier;
597  ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
599  ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
600  : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
601 };
602 
603 struct AllRegionParseArgs {
604  std::optional<MapParseArgs> hasDeviceAddrArgs;
605  std::optional<MapParseArgs> hostEvalArgs;
606  std::optional<ReductionParseArgs> inReductionArgs;
607  std::optional<MapParseArgs> mapArgs;
608  std::optional<PrivateParseArgs> privateArgs;
609  std::optional<ReductionParseArgs> reductionArgs;
610  std::optional<ReductionParseArgs> taskReductionArgs;
611  std::optional<MapParseArgs> useDeviceAddrArgs;
612  std::optional<MapParseArgs> useDevicePtrArgs;
613 };
614 } // namespace
615 
616 static ParseResult parseClauseWithRegionArgs(
617  OpAsmParser &parser,
619  SmallVectorImpl<Type> &types,
620  SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
621  ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
622  DenseBoolArrayAttr *byref = nullptr,
623  ReductionModifierAttr *modifier = nullptr) {
624  SmallVector<SymbolRefAttr> symbolVec;
625  SmallVector<int64_t> mapIndicesVec;
626  SmallVector<bool> isByRefVec;
627  unsigned regionArgOffset = regionPrivateArgs.size();
628 
629  if (parser.parseLParen())
630  return failure();
631 
632  if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
633  StringRef enumStr;
634  if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
635  parser.parseComma())
636  return failure();
637  std::optional<ReductionModifier> enumValue =
638  symbolizeReductionModifier(enumStr);
639  if (!enumValue.has_value())
640  return failure();
641  *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
642  if (!*modifier)
643  return failure();
644  }
645 
646  if (parser.parseCommaSeparatedList([&]() {
647  if (byref)
648  isByRefVec.push_back(
649  parser.parseOptionalKeyword("byref").succeeded());
650 
651  if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
652  return failure();
653 
654  if (parser.parseOperand(operands.emplace_back()) ||
655  parser.parseArrow() ||
656  parser.parseArgument(regionPrivateArgs.emplace_back()))
657  return failure();
658 
659  if (mapIndices) {
660  if (parser.parseOptionalLSquare().succeeded()) {
661  if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
662  parser.parseInteger(mapIndicesVec.emplace_back()) ||
663  parser.parseRSquare())
664  return failure();
665  } else
666  mapIndicesVec.push_back(-1);
667  }
668 
669  return success();
670  }))
671  return failure();
672 
673  if (parser.parseColon())
674  return failure();
675 
676  if (parser.parseCommaSeparatedList([&]() {
677  if (parser.parseType(types.emplace_back()))
678  return failure();
679 
680  return success();
681  }))
682  return failure();
683 
684  if (operands.size() != types.size())
685  return failure();
686 
687  if (parser.parseRParen())
688  return failure();
689 
690  auto *argsBegin = regionPrivateArgs.begin();
691  MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
692  argsBegin + regionArgOffset + types.size());
693  for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
694  prv.type = type;
695  }
696 
697  if (symbols) {
698  SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
699  *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
700  }
701 
702  if (!mapIndicesVec.empty())
703  *mapIndices =
704  mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
705 
706  if (byref)
707  *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
708 
709  return success();
710 }
711 
712 static ParseResult parseBlockArgClause(
713  OpAsmParser &parser,
715  StringRef keyword, std::optional<MapParseArgs> mapArgs) {
716  if (succeeded(parser.parseOptionalKeyword(keyword))) {
717  if (!mapArgs)
718  return failure();
719 
720  if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
721  entryBlockArgs)))
722  return failure();
723  }
724  return success();
725 }
726 
727 static ParseResult parseBlockArgClause(
728  OpAsmParser &parser,
730  StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
731  if (succeeded(parser.parseOptionalKeyword(keyword))) {
732  if (!privateArgs)
733  return failure();
734 
735  if (failed(parseClauseWithRegionArgs(
736  parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
737  &privateArgs->syms, privateArgs->mapIndices)))
738  return failure();
739  }
740  return success();
741 }
742 
743 static ParseResult parseBlockArgClause(
744  OpAsmParser &parser,
746  StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
747  if (succeeded(parser.parseOptionalKeyword(keyword))) {
748  if (!reductionArgs)
749  return failure();
750  if (failed(parseClauseWithRegionArgs(
751  parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
752  &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
753  reductionArgs->modifier)))
754  return failure();
755  }
756  return success();
757 }
758 
759 static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
760  AllRegionParseArgs args) {
762 
763  if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",
764  args.hasDeviceAddrArgs)))
765  return parser.emitError(parser.getCurrentLocation())
766  << "invalid `has_device_addr` format";
767 
768  if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
769  args.hostEvalArgs)))
770  return parser.emitError(parser.getCurrentLocation())
771  << "invalid `host_eval` format";
772 
773  if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
774  args.inReductionArgs)))
775  return parser.emitError(parser.getCurrentLocation())
776  << "invalid `in_reduction` format";
777 
778  if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
779  args.mapArgs)))
780  return parser.emitError(parser.getCurrentLocation())
781  << "invalid `map_entries` format";
782 
783  if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
784  args.privateArgs)))
785  return parser.emitError(parser.getCurrentLocation())
786  << "invalid `private` format";
787 
788  if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
789  args.reductionArgs)))
790  return parser.emitError(parser.getCurrentLocation())
791  << "invalid `reduction` format";
792 
793  if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
794  args.taskReductionArgs)))
795  return parser.emitError(parser.getCurrentLocation())
796  << "invalid `task_reduction` format";
797 
798  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
799  args.useDeviceAddrArgs)))
800  return parser.emitError(parser.getCurrentLocation())
801  << "invalid `use_device_addr` format";
802 
803  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
804  args.useDevicePtrArgs)))
805  return parser.emitError(parser.getCurrentLocation())
806  << "invalid `use_device_addr` format";
807 
808  return parser.parseRegion(region, entryBlockArgs);
809 }
810 
811 // These parseXyz functions correspond to the custom<Xyz> definitions
812 // in the .td file(s).
813 static ParseResult parseTargetOpRegion(
814  OpAsmParser &parser, Region &region,
816  SmallVectorImpl<Type> &hasDeviceAddrTypes,
818  SmallVectorImpl<Type> &hostEvalTypes,
820  SmallVectorImpl<Type> &inReductionTypes,
821  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
823  SmallVectorImpl<Type> &mapTypes,
825  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
826  DenseI64ArrayAttr &privateMaps) {
827  AllRegionParseArgs args;
828  args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
829  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
830  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
831  inReductionByref, inReductionSyms);
832  args.mapArgs.emplace(mapVars, mapTypes);
833  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
834  &privateMaps);
835  return parseBlockArgRegion(parser, region, args);
836 }
837 
838 static ParseResult parseInReductionPrivateRegion(
839  OpAsmParser &parser, Region &region,
841  SmallVectorImpl<Type> &inReductionTypes,
842  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
844  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
845  AllRegionParseArgs args;
846  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
847  inReductionByref, inReductionSyms);
848  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
849  return parseBlockArgRegion(parser, region, args);
850 }
851 
853  OpAsmParser &parser, Region &region,
855  SmallVectorImpl<Type> &inReductionTypes,
856  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
858  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
859  ReductionModifierAttr &reductionMod,
861  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
862  ArrayAttr &reductionSyms) {
863  AllRegionParseArgs args;
864  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
865  inReductionByref, inReductionSyms);
866  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
867  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
868  reductionSyms, &reductionMod);
869  return parseBlockArgRegion(parser, region, args);
870 }
871 
872 static ParseResult parsePrivateRegion(
873  OpAsmParser &parser, Region &region,
875  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
876  AllRegionParseArgs args;
877  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
878  return parseBlockArgRegion(parser, region, args);
879 }
880 
881 static ParseResult parsePrivateReductionRegion(
882  OpAsmParser &parser, Region &region,
884  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
885  ReductionModifierAttr &reductionMod,
887  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
888  ArrayAttr &reductionSyms) {
889  AllRegionParseArgs args;
890  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
891  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
892  reductionSyms, &reductionMod);
893  return parseBlockArgRegion(parser, region, args);
894 }
895 
896 static ParseResult parseTaskReductionRegion(
897  OpAsmParser &parser, Region &region,
899  SmallVectorImpl<Type> &taskReductionTypes,
900  DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
901  AllRegionParseArgs args;
902  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
903  taskReductionByref, taskReductionSyms);
904  return parseBlockArgRegion(parser, region, args);
905 }
906 
908  OpAsmParser &parser, Region &region,
910  SmallVectorImpl<Type> &useDeviceAddrTypes,
912  SmallVectorImpl<Type> &useDevicePtrTypes) {
913  AllRegionParseArgs args;
914  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
915  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
916  return parseBlockArgRegion(parser, region, args);
917 }
918 
919 //===----------------------------------------------------------------------===//
920 // Printers for operations including clauses that define entry block arguments.
921 //===----------------------------------------------------------------------===//
922 
923 namespace {
924 struct MapPrintArgs {
925  ValueRange vars;
926  TypeRange types;
927  MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
928 };
929 struct PrivatePrintArgs {
930  ValueRange vars;
931  TypeRange types;
932  ArrayAttr syms;
933  DenseI64ArrayAttr mapIndices;
934  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
935  DenseI64ArrayAttr mapIndices)
936  : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
937 };
938 struct ReductionPrintArgs {
939  ValueRange vars;
940  TypeRange types;
941  DenseBoolArrayAttr byref;
942  ArrayAttr syms;
943  ReductionModifierAttr modifier;
944  ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
945  ArrayAttr syms, ReductionModifierAttr mod = nullptr)
946  : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
947 };
948 struct AllRegionPrintArgs {
949  std::optional<MapPrintArgs> hasDeviceAddrArgs;
950  std::optional<MapPrintArgs> hostEvalArgs;
951  std::optional<ReductionPrintArgs> inReductionArgs;
952  std::optional<MapPrintArgs> mapArgs;
953  std::optional<PrivatePrintArgs> privateArgs;
954  std::optional<ReductionPrintArgs> reductionArgs;
955  std::optional<ReductionPrintArgs> taskReductionArgs;
956  std::optional<MapPrintArgs> useDeviceAddrArgs;
957  std::optional<MapPrintArgs> useDevicePtrArgs;
958 };
959 } // namespace
960 
962  OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
963  ValueRange argsSubrange, ValueRange operands, TypeRange types,
964  ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
965  DenseBoolArrayAttr byref = nullptr,
966  ReductionModifierAttr modifier = nullptr) {
967  if (argsSubrange.empty())
968  return;
969 
970  p << clauseName << "(";
971 
972  if (modifier)
973  p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
974 
975  if (!symbols) {
976  llvm::SmallVector<Attribute> values(operands.size(), nullptr);
977  symbols = ArrayAttr::get(ctx, values);
978  }
979 
980  if (!mapIndices) {
981  llvm::SmallVector<int64_t> values(operands.size(), -1);
982  mapIndices = DenseI64ArrayAttr::get(ctx, values);
983  }
984 
985  if (!byref) {
986  mlir::SmallVector<bool> values(operands.size(), false);
987  byref = DenseBoolArrayAttr::get(ctx, values);
988  }
989 
990  llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
991  mapIndices.asArrayRef(),
992  byref.asArrayRef()),
993  p, [&p](auto t) {
994  auto [op, arg, sym, map, isByRef] = t;
995  if (isByRef)
996  p << "byref ";
997  if (sym)
998  p << sym << " ";
999 
1000  p << op << " -> " << arg;
1001 
1002  if (map != -1)
1003  p << " [map_idx=" << map << "]";
1004  });
1005  p << " : ";
1006  llvm::interleaveComma(types, p);
1007  p << ") ";
1008 }
1009 
1011  StringRef clauseName, ValueRange argsSubrange,
1012  std::optional<MapPrintArgs> mapArgs) {
1013  if (mapArgs)
1014  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
1015  mapArgs->types);
1016 }
1017 
1019  StringRef clauseName, ValueRange argsSubrange,
1020  std::optional<PrivatePrintArgs> privateArgs) {
1021  if (privateArgs)
1022  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1023  privateArgs->vars, privateArgs->types,
1024  privateArgs->syms, privateArgs->mapIndices);
1025 }
1026 
1027 static void
1028 printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1029  ValueRange argsSubrange,
1030  std::optional<ReductionPrintArgs> reductionArgs) {
1031  if (reductionArgs)
1032  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1033  reductionArgs->vars, reductionArgs->types,
1034  reductionArgs->syms, /*mapIndices=*/nullptr,
1035  reductionArgs->byref, reductionArgs->modifier);
1036 }
1037 
1038 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
1039  const AllRegionPrintArgs &args) {
1040  auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1041  MLIRContext *ctx = op->getContext();
1042 
1043  printBlockArgClause(p, ctx, "has_device_addr",
1044  iface.getHasDeviceAddrBlockArgs(),
1045  args.hasDeviceAddrArgs);
1046  printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
1047  args.hostEvalArgs);
1048  printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
1049  args.inReductionArgs);
1050  printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
1051  args.mapArgs);
1052  printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
1053  args.privateArgs);
1054  printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
1055  args.reductionArgs);
1056  printBlockArgClause(p, ctx, "task_reduction",
1057  iface.getTaskReductionBlockArgs(),
1058  args.taskReductionArgs);
1059  printBlockArgClause(p, ctx, "use_device_addr",
1060  iface.getUseDeviceAddrBlockArgs(),
1061  args.useDeviceAddrArgs);
1062  printBlockArgClause(p, ctx, "use_device_ptr",
1063  iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1064 
1065  p.printRegion(region, /*printEntryBlockArgs=*/false);
1066 }
1067 
1068 // These parseXyz functions correspond to the custom<Xyz> definitions
1069 // in the .td file(s).
1070 static void
1072  ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
1073  ValueRange hostEvalVars, TypeRange hostEvalTypes,
1074  ValueRange inReductionVars, TypeRange inReductionTypes,
1075  DenseBoolArrayAttr inReductionByref,
1076  ArrayAttr inReductionSyms, ValueRange mapVars,
1077  TypeRange mapTypes, ValueRange privateVars,
1078  TypeRange privateTypes, ArrayAttr privateSyms,
1079  DenseI64ArrayAttr privateMaps) {
1080  AllRegionPrintArgs args;
1081  args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1082  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1083  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1084  inReductionByref, inReductionSyms);
1085  args.mapArgs.emplace(mapVars, mapTypes);
1086  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
1087  printBlockArgRegion(p, op, region, args);
1088 }
1089 
1091  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1092  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1093  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1094  ArrayAttr privateSyms) {
1095  AllRegionPrintArgs args;
1096  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1097  inReductionByref, inReductionSyms);
1098  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1099  /*mapIndices=*/nullptr);
1100  printBlockArgRegion(p, op, region, args);
1101 }
1102 
1104  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1105  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1106  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1107  ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
1108  ValueRange reductionVars, TypeRange reductionTypes,
1109  DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
1110  AllRegionPrintArgs args;
1111  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1112  inReductionByref, inReductionSyms);
1113  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1114  /*mapIndices=*/nullptr);
1115  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1116  reductionSyms, reductionMod);
1117  printBlockArgRegion(p, op, region, args);
1118 }
1119 
1120 static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
1121  ValueRange privateVars, TypeRange privateTypes,
1122  ArrayAttr privateSyms) {
1123  AllRegionPrintArgs args;
1124  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1125  /*mapIndices=*/nullptr);
1126  printBlockArgRegion(p, op, region, args);
1127 }
1128 
1130  OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1131  TypeRange privateTypes, ArrayAttr privateSyms,
1132  ReductionModifierAttr reductionMod, ValueRange reductionVars,
1133  TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1134  ArrayAttr reductionSyms) {
1135  AllRegionPrintArgs args;
1136  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1137  /*mapIndices=*/nullptr);
1138  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1139  reductionSyms, reductionMod);
1140  printBlockArgRegion(p, op, region, args);
1141 }
1142 
1144  Region &region,
1145  ValueRange taskReductionVars,
1146  TypeRange taskReductionTypes,
1147  DenseBoolArrayAttr taskReductionByref,
1148  ArrayAttr taskReductionSyms) {
1149  AllRegionPrintArgs args;
1150  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1151  taskReductionByref, taskReductionSyms);
1152  printBlockArgRegion(p, op, region, args);
1153 }
1154 
1156  Region &region,
1157  ValueRange useDeviceAddrVars,
1158  TypeRange useDeviceAddrTypes,
1159  ValueRange useDevicePtrVars,
1160  TypeRange useDevicePtrTypes) {
1161  AllRegionPrintArgs args;
1162  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1163  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1164  printBlockArgRegion(p, op, region, args);
1165 }
1166 
1167 /// Verifies Reduction Clause
1168 static LogicalResult
1169 verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1170  OperandRange reductionVars,
1171  std::optional<ArrayRef<bool>> reductionByref) {
1172  if (!reductionVars.empty()) {
1173  if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1174  return op->emitOpError()
1175  << "expected as many reduction symbol references "
1176  "as reduction variables";
1177  if (reductionByref && reductionByref->size() != reductionVars.size())
1178  return op->emitError() << "expected as many reduction variable by "
1179  "reference attributes as reduction variables";
1180  } else {
1181  if (reductionSyms)
1182  return op->emitOpError() << "unexpected reduction symbol references";
1183  return success();
1184  }
1185 
1186  // TODO: The followings should be done in
1187  // SymbolUserOpInterface::verifySymbolUses.
1188  DenseSet<Value> accumulators;
1189  for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1190  Value accum = std::get<0>(args);
1191 
1192  if (!accumulators.insert(accum).second)
1193  return op->emitOpError() << "accumulator variable used more than once";
1194 
1195  Type varType = accum.getType();
1196  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1197  auto decl =
1198  SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1199  if (!decl)
1200  return op->emitOpError() << "expected symbol reference " << symbolRef
1201  << " to point to a reduction declaration";
1202 
1203  if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1204  return op->emitOpError()
1205  << "expected accumulator (" << varType
1206  << ") to be the same type as reduction declaration ("
1207  << decl.getAccumulatorType() << ")";
1208  }
1209 
1210  return success();
1211 }
1212 
1213 //===----------------------------------------------------------------------===//
1214 // Parser, printer and verifier for Copyprivate
1215 //===----------------------------------------------------------------------===//
1216 
1217 /// copyprivate-entry-list ::= copyprivate-entry
1218 /// | copyprivate-entry-list `,` copyprivate-entry
1219 /// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1220 static ParseResult parseCopyprivate(
1221  OpAsmParser &parser,
1223  SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1225  if (failed(parser.parseCommaSeparatedList([&]() {
1226  if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1227  parser.parseArrow() ||
1228  parser.parseAttribute(symsVec.emplace_back()) ||
1229  parser.parseColonType(copyprivateTypes.emplace_back()))
1230  return failure();
1231  return success();
1232  })))
1233  return failure();
1234  SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1235  copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1236  return success();
1237 }
1238 
1239 /// Print Copyprivate clause
1241  OperandRange copyprivateVars,
1242  TypeRange copyprivateTypes,
1243  std::optional<ArrayAttr> copyprivateSyms) {
1244  if (!copyprivateSyms.has_value())
1245  return;
1246  llvm::interleaveComma(
1247  llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1248  [&](const auto &args) {
1249  p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1250  << std::get<2>(args);
1251  });
1252 }
1253 
1254 /// Verifies CopyPrivate Clause
1255 static LogicalResult
1257  std::optional<ArrayAttr> copyprivateSyms) {
1258  size_t copyprivateSymsSize =
1259  copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1260  if (copyprivateSymsSize != copyprivateVars.size())
1261  return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1262  << copyprivateVars.size()
1263  << ") and functions (= " << copyprivateSymsSize
1264  << "), both must be equal";
1265  if (!copyprivateSyms.has_value())
1266  return success();
1267 
1268  for (auto copyprivateVarAndSym :
1269  llvm::zip(copyprivateVars, *copyprivateSyms)) {
1270  auto symbolRef =
1271  llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1272  std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1273  funcOp;
1274  if (mlir::func::FuncOp mlirFuncOp =
1275  SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1276  symbolRef))
1277  funcOp = mlirFuncOp;
1278  else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1279  SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1280  op, symbolRef))
1281  funcOp = llvmFuncOp;
1282 
1283  auto getNumArguments = [&] {
1284  return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1285  };
1286 
1287  auto getArgumentType = [&](unsigned i) {
1288  return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1289  *funcOp);
1290  };
1291 
1292  if (!funcOp)
1293  return op->emitOpError() << "expected symbol reference " << symbolRef
1294  << " to point to a copy function";
1295 
1296  if (getNumArguments() != 2)
1297  return op->emitOpError()
1298  << "expected copy function " << symbolRef << " to have 2 operands";
1299 
1300  Type argTy = getArgumentType(0);
1301  if (argTy != getArgumentType(1))
1302  return op->emitOpError() << "expected copy function " << symbolRef
1303  << " arguments to have the same type";
1304 
1305  Type varType = std::get<0>(copyprivateVarAndSym).getType();
1306  if (argTy != varType)
1307  return op->emitOpError()
1308  << "expected copy function arguments' type (" << argTy
1309  << ") to be the same as copyprivate variable's type (" << varType
1310  << ")";
1311  }
1312 
1313  return success();
1314 }
1315 
1316 //===----------------------------------------------------------------------===//
1317 // Parser, printer and verifier for DependVarList
1318 //===----------------------------------------------------------------------===//
1319 
1320 /// depend-entry-list ::= depend-entry
1321 /// | depend-entry-list `,` depend-entry
1322 /// depend-entry ::= depend-kind `->` ssa-id `:` type
1323 static ParseResult
1326  SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
1328  if (failed(parser.parseCommaSeparatedList([&]() {
1329  StringRef keyword;
1330  if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1331  parser.parseOperand(dependVars.emplace_back()) ||
1332  parser.parseColonType(dependTypes.emplace_back()))
1333  return failure();
1334  if (std::optional<ClauseTaskDepend> keywordDepend =
1335  (symbolizeClauseTaskDepend(keyword)))
1336  kindsVec.emplace_back(
1337  ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1338  else
1339  return failure();
1340  return success();
1341  })))
1342  return failure();
1343  SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1344  dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1345  return success();
1346 }
1347 
1348 /// Print Depend clause
1350  OperandRange dependVars, TypeRange dependTypes,
1351  std::optional<ArrayAttr> dependKinds) {
1352 
1353  for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1354  if (i != 0)
1355  p << ", ";
1356  p << stringifyClauseTaskDepend(
1357  llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1358  .getValue())
1359  << " -> " << dependVars[i] << " : " << dependTypes[i];
1360  }
1361 }
1362 
1363 /// Verifies Depend clause
1364 static LogicalResult verifyDependVarList(Operation *op,
1365  std::optional<ArrayAttr> dependKinds,
1366  OperandRange dependVars) {
1367  if (!dependVars.empty()) {
1368  if (!dependKinds || dependKinds->size() != dependVars.size())
1369  return op->emitOpError() << "expected as many depend values"
1370  " as depend variables";
1371  } else {
1372  if (dependKinds && !dependKinds->empty())
1373  return op->emitOpError() << "unexpected depend values";
1374  return success();
1375  }
1376 
1377  return success();
1378 }
1379 
1380 //===----------------------------------------------------------------------===//
1381 // Parser, printer and verifier for Synchronization Hint (2.17.12)
1382 //===----------------------------------------------------------------------===//
1383 
1384 /// Parses a Synchronization Hint clause. The value of hint is an integer
1385 /// which is a combination of different hints from `omp_sync_hint_t`.
1386 ///
1387 /// hint-clause = `hint` `(` hint-value `)`
1388 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1389  IntegerAttr &hintAttr) {
1390  StringRef hintKeyword;
1391  int64_t hint = 0;
1392  if (succeeded(parser.parseOptionalKeyword("none"))) {
1393  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1394  return success();
1395  }
1396  auto parseKeyword = [&]() -> ParseResult {
1397  if (failed(parser.parseKeyword(&hintKeyword)))
1398  return failure();
1399  if (hintKeyword == "uncontended")
1400  hint |= 1;
1401  else if (hintKeyword == "contended")
1402  hint |= 2;
1403  else if (hintKeyword == "nonspeculative")
1404  hint |= 4;
1405  else if (hintKeyword == "speculative")
1406  hint |= 8;
1407  else
1408  return parser.emitError(parser.getCurrentLocation())
1409  << hintKeyword << " is not a valid hint";
1410  return success();
1411  };
1412  if (parser.parseCommaSeparatedList(parseKeyword))
1413  return failure();
1414  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1415  return success();
1416 }
1417 
1418 /// Prints a Synchronization Hint clause
1420  IntegerAttr hintAttr) {
1421  int64_t hint = hintAttr.getInt();
1422 
1423  if (hint == 0) {
1424  p << "none";
1425  return;
1426  }
1427 
1428  // Helper function to get n-th bit from the right end of `value`
1429  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1430 
1431  bool uncontended = bitn(hint, 0);
1432  bool contended = bitn(hint, 1);
1433  bool nonspeculative = bitn(hint, 2);
1434  bool speculative = bitn(hint, 3);
1435 
1436  SmallVector<StringRef> hints;
1437  if (uncontended)
1438  hints.push_back("uncontended");
1439  if (contended)
1440  hints.push_back("contended");
1441  if (nonspeculative)
1442  hints.push_back("nonspeculative");
1443  if (speculative)
1444  hints.push_back("speculative");
1445 
1446  llvm::interleaveComma(hints, p);
1447 }
1448 
1449 /// Verifies a synchronization hint clause
1450 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1451 
1452  // Helper function to get n-th bit from the right end of `value`
1453  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1454 
1455  bool uncontended = bitn(hint, 0);
1456  bool contended = bitn(hint, 1);
1457  bool nonspeculative = bitn(hint, 2);
1458  bool speculative = bitn(hint, 3);
1459 
1460  if (uncontended && contended)
1461  return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1462  "omp_sync_hint_contended cannot be combined";
1463  if (nonspeculative && speculative)
1464  return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1465  "omp_sync_hint_speculative cannot be combined.";
1466  return success();
1467 }
1468 
1469 //===----------------------------------------------------------------------===//
1470 // Parser, printer and verifier for Target
1471 //===----------------------------------------------------------------------===//
1472 
1473 // Helper function to get bitwise AND of `value` and 'flag'
1474 uint64_t mapTypeToBitFlag(uint64_t value,
1475  llvm::omp::OpenMPOffloadMappingFlags flag) {
1476  return value & llvm::to_underlying(flag);
1477 }
1478 
1479 /// Parses a map_entries map type from a string format back into its numeric
1480 /// value.
1481 ///
1482 /// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
1483 /// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
1484 static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
1485  llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1486  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1487 
1488  // This simply verifies the correct keyword is read in, the
1489  // keyword itself is stored inside of the operation
1490  auto parseTypeAndMod = [&]() -> ParseResult {
1491  StringRef mapTypeMod;
1492  if (parser.parseKeyword(&mapTypeMod))
1493  return failure();
1494 
1495  if (mapTypeMod == "always")
1496  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1497 
1498  if (mapTypeMod == "implicit")
1499  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1500 
1501  if (mapTypeMod == "ompx_hold")
1502  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
1503 
1504  if (mapTypeMod == "close")
1505  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1506 
1507  if (mapTypeMod == "present")
1508  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1509 
1510  if (mapTypeMod == "to")
1511  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1512 
1513  if (mapTypeMod == "from")
1514  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1515 
1516  if (mapTypeMod == "tofrom")
1517  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1518  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1519 
1520  if (mapTypeMod == "delete")
1521  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1522 
1523  return success();
1524  };
1525 
1526  if (parser.parseCommaSeparatedList(parseTypeAndMod))
1527  return failure();
1528 
1529  mapType = parser.getBuilder().getIntegerAttr(
1530  parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
1531  llvm::to_underlying(mapTypeBits));
1532 
1533  return success();
1534 }
1535 
1536 /// Prints a map_entries map type from its numeric value out into its string
1537 /// format.
1539  IntegerAttr mapType) {
1540  uint64_t mapTypeBits = mapType.getUInt();
1541 
1542  bool emitAllocRelease = true;
1544 
1545  // handling of always, close, present placed at the beginning of the string
1546  // to aid readability
1547  if (mapTypeToBitFlag(mapTypeBits,
1548  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1549  mapTypeStrs.push_back("always");
1550  if (mapTypeToBitFlag(mapTypeBits,
1551  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1552  mapTypeStrs.push_back("implicit");
1553  if (mapTypeToBitFlag(mapTypeBits,
1554  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD))
1555  mapTypeStrs.push_back("ompx_hold");
1556  if (mapTypeToBitFlag(mapTypeBits,
1557  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1558  mapTypeStrs.push_back("close");
1559  if (mapTypeToBitFlag(mapTypeBits,
1560  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1561  mapTypeStrs.push_back("present");
1562 
1563  // special handling of to/from/tofrom/delete and release/alloc, release +
1564  // alloc are the abscense of one of the other flags, whereas tofrom requires
1565  // both the to and from flag to be set.
1566  bool to = mapTypeToBitFlag(mapTypeBits,
1567  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1568  bool from = mapTypeToBitFlag(
1569  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1570  if (to && from) {
1571  emitAllocRelease = false;
1572  mapTypeStrs.push_back("tofrom");
1573  } else if (from) {
1574  emitAllocRelease = false;
1575  mapTypeStrs.push_back("from");
1576  } else if (to) {
1577  emitAllocRelease = false;
1578  mapTypeStrs.push_back("to");
1579  }
1580  if (mapTypeToBitFlag(mapTypeBits,
1581  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1582  emitAllocRelease = false;
1583  mapTypeStrs.push_back("delete");
1584  }
1585  if (emitAllocRelease)
1586  mapTypeStrs.push_back("exit_release_or_enter_alloc");
1587 
1588  for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1589  p << mapTypeStrs[i];
1590  if (i + 1 < mapTypeStrs.size()) {
1591  p << ", ";
1592  }
1593  }
1594 }
1595 
1596 static ParseResult parseMembersIndex(OpAsmParser &parser,
1597  ArrayAttr &membersIdx) {
1598  SmallVector<Attribute> values, memberIdxs;
1599 
1600  auto parseIndices = [&]() -> ParseResult {
1601  int64_t value;
1602  if (parser.parseInteger(value))
1603  return failure();
1604  values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1605  APInt(64, value, /*isSigned=*/false)));
1606  return success();
1607  };
1608 
1609  do {
1610  if (failed(parser.parseLSquare()))
1611  return failure();
1612 
1613  if (parser.parseCommaSeparatedList(parseIndices))
1614  return failure();
1615 
1616  if (failed(parser.parseRSquare()))
1617  return failure();
1618 
1619  memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1620  values.clear();
1621  } while (succeeded(parser.parseOptionalComma()));
1622 
1623  if (!memberIdxs.empty())
1624  membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
1625 
1626  return success();
1627 }
1628 
1629 static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1630  ArrayAttr membersIdx) {
1631  if (!membersIdx)
1632  return;
1633 
1634  llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
1635  p << "[";
1636  auto memberIdx = cast<ArrayAttr>(v);
1637  llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
1638  p << cast<IntegerAttr>(v2).getInt();
1639  });
1640  p << "]";
1641  });
1642 }
1643 
1645  VariableCaptureKindAttr mapCaptureType) {
1646  std::string typeCapStr;
1647  llvm::raw_string_ostream typeCap(typeCapStr);
1648  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1649  typeCap << "ByRef";
1650  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1651  typeCap << "ByCopy";
1652  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1653  typeCap << "VLAType";
1654  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1655  typeCap << "This";
1656  p << typeCapStr;
1657 }
1658 
1659 static ParseResult parseCaptureType(OpAsmParser &parser,
1660  VariableCaptureKindAttr &mapCaptureType) {
1661  StringRef mapCaptureKey;
1662  if (parser.parseKeyword(&mapCaptureKey))
1663  return failure();
1664 
1665  if (mapCaptureKey == "This")
1666  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1667  parser.getContext(), mlir::omp::VariableCaptureKind::This);
1668  if (mapCaptureKey == "ByRef")
1669  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1670  parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1671  if (mapCaptureKey == "ByCopy")
1672  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1673  parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1674  if (mapCaptureKey == "VLAType")
1675  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1676  parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1677 
1678  return success();
1679 }
1680 
1681 static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
1684 
1685  for (auto mapOp : mapVars) {
1686  if (!mapOp.getDefiningOp())
1687  return emitError(op->getLoc(), "missing map operation");
1688 
1689  if (auto mapInfoOp =
1690  mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
1691  uint64_t mapTypeBits = mapInfoOp.getMapType();
1692 
1693  bool to = mapTypeToBitFlag(
1694  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1695  bool from = mapTypeToBitFlag(
1696  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1697  bool del = mapTypeToBitFlag(
1698  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1699 
1700  bool always = mapTypeToBitFlag(
1701  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1702  bool close = mapTypeToBitFlag(
1703  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1704  bool implicit = mapTypeToBitFlag(
1705  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1706 
1707  if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1708  return emitError(op->getLoc(),
1709  "to, from, tofrom and alloc map types are permitted");
1710 
1711  if (isa<TargetEnterDataOp>(op) && (from || del))
1712  return emitError(op->getLoc(), "to and alloc map types are permitted");
1713 
1714  if (isa<TargetExitDataOp>(op) && to)
1715  return emitError(op->getLoc(),
1716  "from, release and delete map types are permitted");
1717 
1718  if (isa<TargetUpdateOp>(op)) {
1719  if (del) {
1720  return emitError(op->getLoc(),
1721  "at least one of to or from map types must be "
1722  "specified, other map types are not permitted");
1723  }
1724 
1725  if (!to && !from) {
1726  return emitError(op->getLoc(),
1727  "at least one of to or from map types must be "
1728  "specified, other map types are not permitted");
1729  }
1730 
1731  auto updateVar = mapInfoOp.getVarPtr();
1732 
1733  if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1734  (from && updateToVars.contains(updateVar))) {
1735  return emitError(
1736  op->getLoc(),
1737  "either to or from map types can be specified, not both");
1738  }
1739 
1740  if (always || close || implicit) {
1741  return emitError(
1742  op->getLoc(),
1743  "present, mapper and iterator map type modifiers are permitted");
1744  }
1745 
1746  to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1747  }
1748  } else if (!isa<DeclareMapperInfoOp>(op)) {
1749  return emitError(op->getLoc(),
1750  "map argument is not a map entry operation");
1751  }
1752  }
1753 
1754  return success();
1755 }
1756 
1757 static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
1758  std::optional<DenseI64ArrayAttr> privateMapIndices =
1759  targetOp.getPrivateMapsAttr();
1760 
1761  // None of the private operands are mapped.
1762  if (!privateMapIndices.has_value() || !privateMapIndices.value())
1763  return success();
1764 
1765  OperandRange privateVars = targetOp.getPrivateVars();
1766 
1767  if (privateMapIndices.value().size() !=
1768  static_cast<int64_t>(privateVars.size()))
1769  return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
1770  "`private_maps` attribute mismatch");
1771 
1772  return success();
1773 }
1774 
1775 //===----------------------------------------------------------------------===//
1776 // MapInfoOp
1777 //===----------------------------------------------------------------------===//
1778 
1779 LogicalResult MapInfoOp::verify() {
1780  if (getMapperId() &&
1781  !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
1782  *this, getMapperIdAttr())) {
1783  return emitError("invalid mapper id");
1784  }
1785 
1786  return success();
1787 }
1788 
1789 //===----------------------------------------------------------------------===//
1790 // TargetDataOp
1791 //===----------------------------------------------------------------------===//
1792 
1793 void TargetDataOp::build(OpBuilder &builder, OperationState &state,
1794  const TargetDataOperands &clauses) {
1795  TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
1796  clauses.mapVars, clauses.useDeviceAddrVars,
1797  clauses.useDevicePtrVars);
1798 }
1799 
1800 LogicalResult TargetDataOp::verify() {
1801  if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
1802  getUseDeviceAddrVars().empty()) {
1803  return ::emitError(this->getLoc(),
1804  "At least one of map, use_device_ptr_vars, or "
1805  "use_device_addr_vars operand must be present");
1806  }
1807  return verifyMapClause(*this, getMapVars());
1808 }
1809 
1810 //===----------------------------------------------------------------------===//
1811 // TargetEnterDataOp
1812 //===----------------------------------------------------------------------===//
1813 
1814 void TargetEnterDataOp::build(
1815  OpBuilder &builder, OperationState &state,
1816  const TargetEnterExitUpdateDataOperands &clauses) {
1817  MLIRContext *ctx = builder.getContext();
1818  TargetEnterDataOp::build(builder, state,
1819  makeArrayAttr(ctx, clauses.dependKinds),
1820  clauses.dependVars, clauses.device, clauses.ifExpr,
1821  clauses.mapVars, clauses.nowait);
1822 }
1823 
1824 LogicalResult TargetEnterDataOp::verify() {
1825  LogicalResult verifyDependVars =
1826  verifyDependVarList(*this, getDependKinds(), getDependVars());
1827  return failed(verifyDependVars) ? verifyDependVars
1828  : verifyMapClause(*this, getMapVars());
1829 }
1830 
1831 //===----------------------------------------------------------------------===//
1832 // TargetExitDataOp
1833 //===----------------------------------------------------------------------===//
1834 
1835 void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
1836  const TargetEnterExitUpdateDataOperands &clauses) {
1837  MLIRContext *ctx = builder.getContext();
1838  TargetExitDataOp::build(builder, state,
1839  makeArrayAttr(ctx, clauses.dependKinds),
1840  clauses.dependVars, clauses.device, clauses.ifExpr,
1841  clauses.mapVars, clauses.nowait);
1842 }
1843 
1844 LogicalResult TargetExitDataOp::verify() {
1845  LogicalResult verifyDependVars =
1846  verifyDependVarList(*this, getDependKinds(), getDependVars());
1847  return failed(verifyDependVars) ? verifyDependVars
1848  : verifyMapClause(*this, getMapVars());
1849 }
1850 
1851 //===----------------------------------------------------------------------===//
1852 // TargetUpdateOp
1853 //===----------------------------------------------------------------------===//
1854 
1855 void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
1856  const TargetEnterExitUpdateDataOperands &clauses) {
1857  MLIRContext *ctx = builder.getContext();
1858  TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
1859  clauses.dependVars, clauses.device, clauses.ifExpr,
1860  clauses.mapVars, clauses.nowait);
1861 }
1862 
1863 LogicalResult TargetUpdateOp::verify() {
1864  LogicalResult verifyDependVars =
1865  verifyDependVarList(*this, getDependKinds(), getDependVars());
1866  return failed(verifyDependVars) ? verifyDependVars
1867  : verifyMapClause(*this, getMapVars());
1868 }
1869 
1870 //===----------------------------------------------------------------------===//
1871 // TargetOp
1872 //===----------------------------------------------------------------------===//
1873 
1874 void TargetOp::build(OpBuilder &builder, OperationState &state,
1875  const TargetOperands &clauses) {
1876  MLIRContext *ctx = builder.getContext();
1877  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
1878  // inReductionByref, inReductionSyms.
1879  TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
1880  clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
1881  clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
1882  clauses.hostEvalVars, clauses.ifExpr,
1883  /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
1884  /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
1885  clauses.mapVars, clauses.nowait, clauses.privateVars,
1886  makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
1887  /*private_maps=*/nullptr);
1888 }
1889 
1890 LogicalResult TargetOp::verify() {
1891  LogicalResult verifyDependVars =
1892  verifyDependVarList(*this, getDependKinds(), getDependVars());
1893 
1894  if (failed(verifyDependVars))
1895  return verifyDependVars;
1896 
1897  LogicalResult verifyMapVars = verifyMapClause(*this, getMapVars());
1898 
1899  if (failed(verifyMapVars))
1900  return verifyMapVars;
1901 
1902  return verifyPrivateVarsMapping(*this);
1903 }
1904 
1905 LogicalResult TargetOp::verifyRegions() {
1906  auto teamsOps = getOps<TeamsOp>();
1907  if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
1908  return emitError("target containing multiple 'omp.teams' nested ops");
1909 
1910  // Check that host_eval values are only used in legal ways.
1911  Operation *capturedOp = getInnermostCapturedOmpOp();
1912  TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
1913  for (Value hostEvalArg :
1914  cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
1915  for (Operation *user : hostEvalArg.getUsers()) {
1916  if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
1917  if (llvm::is_contained({teamsOp.getNumTeamsLower(),
1918  teamsOp.getNumTeamsUpper(),
1919  teamsOp.getThreadLimit()},
1920  hostEvalArg))
1921  continue;
1922 
1923  return emitOpError() << "host_eval argument only legal as 'num_teams' "
1924  "and 'thread_limit' in 'omp.teams'";
1925  }
1926  if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1927  if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
1928  parallelOp->isAncestor(capturedOp) &&
1929  hostEvalArg == parallelOp.getNumThreads())
1930  continue;
1931 
1932  return emitOpError()
1933  << "host_eval argument only legal as 'num_threads' in "
1934  "'omp.parallel' when representing target SPMD";
1935  }
1936  if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1937  if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
1938  loopNestOp.getOperation() == capturedOp &&
1939  (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
1940  llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
1941  llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
1942  continue;
1943 
1944  return emitOpError() << "host_eval argument only legal as loop bounds "
1945  "and steps in 'omp.loop_nest' when trip count "
1946  "must be evaluated in the host";
1947  }
1948 
1949  return emitOpError() << "host_eval argument illegal use in '"
1950  << user->getName() << "' operation";
1951  }
1952  }
1953  return success();
1954 }
1955 
1956 static Operation *
1957 findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
1958  llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
1959  assert(rootOp && "expected valid operation");
1960 
1961  Dialect *ompDialect = rootOp->getDialect();
1962  Operation *capturedOp = nullptr;
1963  DominanceInfo domInfo;
1964 
1965  // Process in pre-order to check operations from outermost to innermost,
1966  // ensuring we only enter the region of an operation if it meets the criteria
1967  // for being captured. We stop the exploration of nested operations as soon as
1968  // we process a region holding no operations to be captured.
1969  rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
1970  if (op == rootOp)
1971  return WalkResult::advance();
1972 
1973  // Ignore operations of other dialects or omp operations with no regions,
1974  // because these will only be checked if they are siblings of an omp
1975  // operation that can potentially be captured.
1976  bool isOmpDialect = op->getDialect() == ompDialect;
1977  bool hasRegions = op->getNumRegions() > 0;
1978  if (!isOmpDialect || !hasRegions)
1979  return WalkResult::skip();
1980 
1981  // This operation cannot be captured if it can be executed more than once
1982  // (i.e. its block's successors can reach it) or if it's not guaranteed to
1983  // be executed before all exits of the region (i.e. it doesn't dominate all
1984  // blocks with no successors reachable from the entry block).
1985  if (checkSingleMandatoryExec) {
1986  Region *parentRegion = op->getParentRegion();
1987  Block *parentBlock = op->getBlock();
1988 
1989  for (Block *successor : parentBlock->getSuccessors())
1990  if (successor->isReachable(parentBlock))
1991  return WalkResult::interrupt();
1992 
1993  for (Block &block : *parentRegion)
1994  if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
1995  !domInfo.dominates(parentBlock, &block))
1996  return WalkResult::interrupt();
1997  }
1998 
1999  // Don't capture this op if it has a not-allowed sibling, and stop recursing
2000  // into nested operations.
2001  for (Operation &sibling : op->getParentRegion()->getOps())
2002  if (&sibling != op && !siblingAllowedFn(&sibling))
2003  return WalkResult::interrupt();
2004 
2005  // Don't continue capturing nested operations if we reach an omp.loop_nest.
2006  // Otherwise, process the contents of this operation.
2007  capturedOp = op;
2008  return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2009  : WalkResult::advance();
2010  });
2011 
2012  return capturedOp;
2013 }
2014 
2015 Operation *TargetOp::getInnermostCapturedOmpOp() {
2016  auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
2017 
2018  // Only allow OpenMP terminators and non-OpenMP ops that have known memory
2019  // effects, but don't include a memory write effect.
2020  return findCapturedOmpOp(
2021  *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2022  if (!sibling)
2023  return false;
2024 
2025  if (ompDialect == sibling->getDialect())
2026  return sibling->hasTrait<OpTrait::IsTerminator>();
2027 
2028  if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2030  effects;
2031  memOp.getEffects(effects);
2032  return !llvm::any_of(
2033  effects, [&](MemoryEffects::EffectInstance &effect) {
2034  return isa<MemoryEffects::Write>(effect.getEffect()) &&
2035  isa<SideEffects::AutomaticAllocationScopeResource>(
2036  effect.getResource());
2037  });
2038  }
2039  return true;
2040  });
2041 }
2042 
2043 TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2044  // A non-null captured op is only valid if it resides inside of a TargetOp
2045  // and is the result of calling getInnermostCapturedOmpOp() on it.
2046  TargetOp targetOp =
2047  capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
2048  assert((!capturedOp ||
2049  (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2050  "unexpected captured op");
2051 
2052  // If it's not capturing a loop, it's a default target region.
2053  if (!isa_and_present<LoopNestOp>(capturedOp))
2054  return TargetRegionFlags::generic;
2055 
2056  // Get the innermost non-simd loop wrapper.
2057  SmallVector<LoopWrapperInterface> loopWrappers;
2058  cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2059  assert(!loopWrappers.empty());
2060 
2061  LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2062  if (isa<SimdOp>(innermostWrapper))
2063  innermostWrapper = std::next(innermostWrapper);
2064 
2065  auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2066  if (numWrappers != 1 && numWrappers != 2)
2067  return TargetRegionFlags::generic;
2068 
2069  // Detect target-teams-distribute-parallel-wsloop[-simd].
2070  if (numWrappers == 2) {
2071  if (!isa<WsloopOp>(innermostWrapper))
2072  return TargetRegionFlags::generic;
2073 
2074  innermostWrapper = std::next(innermostWrapper);
2075  if (!isa<DistributeOp>(innermostWrapper))
2076  return TargetRegionFlags::generic;
2077 
2078  Operation *parallelOp = (*innermostWrapper)->getParentOp();
2079  if (!isa_and_present<ParallelOp>(parallelOp))
2080  return TargetRegionFlags::generic;
2081 
2082  Operation *teamsOp = parallelOp->getParentOp();
2083  if (!isa_and_present<TeamsOp>(teamsOp))
2084  return TargetRegionFlags::generic;
2085 
2086  if (teamsOp->getParentOp() == targetOp.getOperation())
2087  return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2088  }
2089  // Detect target-teams-distribute[-simd] and target-teams-loop.
2090  else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2091  Operation *teamsOp = (*innermostWrapper)->getParentOp();
2092  if (!isa_and_present<TeamsOp>(teamsOp))
2093  return TargetRegionFlags::generic;
2094 
2095  if (teamsOp->getParentOp() != targetOp.getOperation())
2096  return TargetRegionFlags::generic;
2097 
2098  if (isa<LoopOp>(innermostWrapper))
2099  return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2100 
2101  // Find single immediately nested captured omp.parallel and add spmd flag
2102  // (generic-spmd case).
2103  //
2104  // TODO: This shouldn't have to be done here, as it is too easy to break.
2105  // The openmp-opt pass should be updated to be able to promote kernels like
2106  // this from "Generic" to "Generic-SPMD". However, the use of the
2107  // `kmpc_distribute_static_loop` family of functions produced by the
2108  // OMPIRBuilder for these kernels prevents that from working.
2109  Dialect *ompDialect = targetOp->getDialect();
2110  Operation *nestedCapture = findCapturedOmpOp(
2111  capturedOp, /*checkSingleMandatoryExec=*/false,
2112  [&](Operation *sibling) {
2113  return sibling && (ompDialect != sibling->getDialect() ||
2114  sibling->hasTrait<OpTrait::IsTerminator>());
2115  });
2116 
2117  TargetRegionFlags result =
2118  TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2119 
2120  if (!nestedCapture)
2121  return result;
2122 
2123  while (nestedCapture->getParentOp() != capturedOp)
2124  nestedCapture = nestedCapture->getParentOp();
2125 
2126  return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2127  : result;
2128  }
2129  // Detect target-parallel-wsloop[-simd].
2130  else if (isa<WsloopOp>(innermostWrapper)) {
2131  Operation *parallelOp = (*innermostWrapper)->getParentOp();
2132  if (!isa_and_present<ParallelOp>(parallelOp))
2133  return TargetRegionFlags::generic;
2134 
2135  if (parallelOp->getParentOp() == targetOp.getOperation())
2136  return TargetRegionFlags::spmd;
2137  }
2138 
2139  return TargetRegionFlags::generic;
2140 }
2141 
2142 //===----------------------------------------------------------------------===//
2143 // ParallelOp
2144 //===----------------------------------------------------------------------===//
2145 
2146 void ParallelOp::build(OpBuilder &builder, OperationState &state,
2147  ArrayRef<NamedAttribute> attributes) {
2148  ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
2149  /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2150  /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
2151  /*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr,
2152  /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2153  /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2154  state.addAttributes(attributes);
2155 }
2156 
2157 void ParallelOp::build(OpBuilder &builder, OperationState &state,
2158  const ParallelOperands &clauses) {
2159  MLIRContext *ctx = builder.getContext();
2160  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2161  clauses.ifExpr, clauses.numThreads, clauses.privateVars,
2162  makeArrayAttr(ctx, clauses.privateSyms),
2163  clauses.procBindKind, clauses.reductionMod,
2164  clauses.reductionVars,
2165  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2166  makeArrayAttr(ctx, clauses.reductionSyms));
2167 }
2168 
2169 template <typename OpType>
2170 static LogicalResult verifyPrivateVarList(OpType &op) {
2171  auto privateVars = op.getPrivateVars();
2172  auto privateSyms = op.getPrivateSymsAttr();
2173 
2174  if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
2175  return success();
2176 
2177  auto numPrivateVars = privateVars.size();
2178  auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
2179 
2180  if (numPrivateVars != numPrivateSyms)
2181  return op.emitError() << "inconsistent number of private variables and "
2182  "privatizer op symbols, private vars: "
2183  << numPrivateVars
2184  << " vs. privatizer op symbols: " << numPrivateSyms;
2185 
2186  for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2187  Type varType = std::get<0>(privateVarInfo).getType();
2188  SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2189  PrivateClauseOp privatizerOp =
2190  SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
2191 
2192  if (privatizerOp == nullptr)
2193  return op.emitError() << "failed to lookup privatizer op with symbol: '"
2194  << privateSym << "'";
2195 
2196  Type privatizerType = privatizerOp.getArgType();
2197 
2198  if (privatizerType && (varType != privatizerType))
2199  return op.emitError()
2200  << "type mismatch between a "
2201  << (privatizerOp.getDataSharingType() ==
2202  DataSharingClauseType::Private
2203  ? "private"
2204  : "firstprivate")
2205  << " variable and its privatizer op, var type: " << varType
2206  << " vs. privatizer op type: " << privatizerType;
2207  }
2208 
2209  return success();
2210 }
2211 
2212 LogicalResult ParallelOp::verify() {
2213  if (getAllocateVars().size() != getAllocatorVars().size())
2214  return emitError(
2215  "expected equal sizes for allocate and allocator variables");
2216 
2217  if (failed(verifyPrivateVarList(*this)))
2218  return failure();
2219 
2220  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2221  getReductionByref());
2222 }
2223 
2224 LogicalResult ParallelOp::verifyRegions() {
2225  auto distChildOps = getOps<DistributeOp>();
2226  int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2227  if (numDistChildOps > 1)
2228  return emitError()
2229  << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2230 
2231  if (numDistChildOps == 1) {
2232  if (!isComposite())
2233  return emitError()
2234  << "'omp.composite' attribute missing from composite operation";
2235 
2236  auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2237  Operation &distributeOp = **distChildOps.begin();
2238  for (Operation &childOp : getOps()) {
2239  if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2240  continue;
2241 
2242  if (!childOp.hasTrait<OpTrait::IsTerminator>())
2243  return emitError() << "unexpected OpenMP operation inside of composite "
2244  "'omp.parallel': "
2245  << childOp.getName();
2246  }
2247  } else if (isComposite()) {
2248  return emitError()
2249  << "'omp.composite' attribute present in non-composite operation";
2250  }
2251  return success();
2252 }
2253 
2254 //===----------------------------------------------------------------------===//
2255 // TeamsOp
2256 //===----------------------------------------------------------------------===//
2257 
2259  while ((op = op->getParentOp()))
2260  if (isa<OpenMPDialect>(op->getDialect()))
2261  return false;
2262  return true;
2263 }
2264 
2265 void TeamsOp::build(OpBuilder &builder, OperationState &state,
2266  const TeamsOperands &clauses) {
2267  MLIRContext *ctx = builder.getContext();
2268  // TODO Store clauses in op: privateVars, privateSyms.
2269  TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2270  clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2271  /*private_vars=*/{}, /*private_syms=*/nullptr,
2272  clauses.reductionMod, clauses.reductionVars,
2273  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2274  makeArrayAttr(ctx, clauses.reductionSyms),
2275  clauses.threadLimit);
2276 }
2277 
2278 LogicalResult TeamsOp::verify() {
2279  // Check parent region
2280  // TODO If nested inside of a target region, also check that it does not
2281  // contain any statements, declarations or directives other than this
2282  // omp.teams construct. The issue is how to support the initialization of
2283  // this operation's own arguments (allow SSA values across omp.target?).
2284  Operation *op = getOperation();
2285  if (!isa<TargetOp>(op->getParentOp()) &&
2287  return emitError("expected to be nested inside of omp.target or not nested "
2288  "in any OpenMP dialect operations");
2289 
2290  // Check for num_teams clause restrictions
2291  if (auto numTeamsLowerBound = getNumTeamsLower()) {
2292  auto numTeamsUpperBound = getNumTeamsUpper();
2293  if (!numTeamsUpperBound)
2294  return emitError("expected num_teams upper bound to be defined if the "
2295  "lower bound is defined");
2296  if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2297  return emitError(
2298  "expected num_teams upper bound and lower bound to be the same type");
2299  }
2300 
2301  // Check for allocate clause restrictions
2302  if (getAllocateVars().size() != getAllocatorVars().size())
2303  return emitError(
2304  "expected equal sizes for allocate and allocator variables");
2305 
2306  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2307  getReductionByref());
2308 }
2309 
2310 //===----------------------------------------------------------------------===//
2311 // SectionOp
2312 //===----------------------------------------------------------------------===//
2313 
2314 OperandRange SectionOp::getPrivateVars() {
2315  return getParentOp().getPrivateVars();
2316 }
2317 
2318 OperandRange SectionOp::getReductionVars() {
2319  return getParentOp().getReductionVars();
2320 }
2321 
2322 //===----------------------------------------------------------------------===//
2323 // SectionsOp
2324 //===----------------------------------------------------------------------===//
2325 
2326 void SectionsOp::build(OpBuilder &builder, OperationState &state,
2327  const SectionsOperands &clauses) {
2328  MLIRContext *ctx = builder.getContext();
2329  // TODO Store clauses in op: privateVars, privateSyms.
2330  SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2331  clauses.nowait, /*private_vars=*/{},
2332  /*private_syms=*/nullptr, clauses.reductionMod,
2333  clauses.reductionVars,
2334  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2335  makeArrayAttr(ctx, clauses.reductionSyms));
2336 }
2337 
2338 LogicalResult SectionsOp::verify() {
2339  if (getAllocateVars().size() != getAllocatorVars().size())
2340  return emitError(
2341  "expected equal sizes for allocate and allocator variables");
2342 
2343  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2344  getReductionByref());
2345 }
2346 
2347 LogicalResult SectionsOp::verifyRegions() {
2348  for (auto &inst : *getRegion().begin()) {
2349  if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2350  return emitOpError()
2351  << "expected omp.section op or terminator op inside region";
2352  }
2353  }
2354 
2355  return success();
2356 }
2357 
2358 //===----------------------------------------------------------------------===//
2359 // SingleOp
2360 //===----------------------------------------------------------------------===//
2361 
2362 void SingleOp::build(OpBuilder &builder, OperationState &state,
2363  const SingleOperands &clauses) {
2364  MLIRContext *ctx = builder.getContext();
2365  // TODO Store clauses in op: privateVars, privateSyms.
2366  SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2367  clauses.copyprivateVars,
2368  makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2369  /*private_vars=*/{}, /*private_syms=*/nullptr);
2370 }
2371 
2372 LogicalResult SingleOp::verify() {
2373  // Check for allocate clause restrictions
2374  if (getAllocateVars().size() != getAllocatorVars().size())
2375  return emitError(
2376  "expected equal sizes for allocate and allocator variables");
2377 
2378  return verifyCopyprivateVarList(*this, getCopyprivateVars(),
2379  getCopyprivateSyms());
2380 }
2381 
2382 //===----------------------------------------------------------------------===//
2383 // WorkshareOp
2384 //===----------------------------------------------------------------------===//
2385 
2386 void WorkshareOp::build(OpBuilder &builder, OperationState &state,
2387  const WorkshareOperands &clauses) {
2388  WorkshareOp::build(builder, state, clauses.nowait);
2389 }
2390 
2391 //===----------------------------------------------------------------------===//
2392 // WorkshareLoopWrapperOp
2393 //===----------------------------------------------------------------------===//
2394 
2395 LogicalResult WorkshareLoopWrapperOp::verify() {
2396  if (!(*this)->getParentOfType<WorkshareOp>())
2397  return emitOpError() << "must be nested in an omp.workshare";
2398  return success();
2399 }
2400 
2401 LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2402  if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2403  getNestedWrapper())
2404  return emitOpError() << "expected to be a standalone loop wrapper";
2405 
2406  return success();
2407 }
2408 
2409 //===----------------------------------------------------------------------===//
2410 // LoopWrapperInterface
2411 //===----------------------------------------------------------------------===//
2412 
2413 LogicalResult LoopWrapperInterface::verifyImpl() {
2414  Operation *op = this->getOperation();
2415  if (!op->hasTrait<OpTrait::NoTerminator>() ||
2417  return emitOpError() << "loop wrapper must also have the `NoTerminator` "
2418  "and `SingleBlock` traits";
2419 
2420  if (op->getNumRegions() != 1)
2421  return emitOpError() << "loop wrapper does not contain exactly one region";
2422 
2423  Region &region = op->getRegion(0);
2424  if (range_size(region.getOps()) != 1)
2425  return emitOpError()
2426  << "loop wrapper does not contain exactly one nested op";
2427 
2428  Operation &firstOp = *region.op_begin();
2429  if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2430  return emitOpError() << "nested in loop wrapper is not another loop "
2431  "wrapper or `omp.loop_nest`";
2432 
2433  return success();
2434 }
2435 
2436 //===----------------------------------------------------------------------===//
2437 // LoopOp
2438 //===----------------------------------------------------------------------===//
2439 
2440 void LoopOp::build(OpBuilder &builder, OperationState &state,
2441  const LoopOperands &clauses) {
2442  MLIRContext *ctx = builder.getContext();
2443 
2444  LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2445  makeArrayAttr(ctx, clauses.privateSyms), clauses.order,
2446  clauses.orderMod, clauses.reductionMod, clauses.reductionVars,
2447  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2448  makeArrayAttr(ctx, clauses.reductionSyms));
2449 }
2450 
2451 LogicalResult LoopOp::verify() {
2452  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2453  getReductionByref());
2454 }
2455 
2456 LogicalResult LoopOp::verifyRegions() {
2457  if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2458  getNestedWrapper())
2459  return emitOpError() << "expected to be a standalone loop wrapper";
2460 
2461  return success();
2462 }
2463 
2464 //===----------------------------------------------------------------------===//
2465 // WsloopOp
2466 //===----------------------------------------------------------------------===//
2467 
2468 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2469  ArrayRef<NamedAttribute> attributes) {
2470  build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2471  /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
2472  /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
2473  /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
2474  /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
2475  /*reduction_byref=*/nullptr,
2476  /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
2477  /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
2478  /*schedule_simd=*/false);
2479  state.addAttributes(attributes);
2480 }
2481 
2482 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2483  const WsloopOperands &clauses) {
2484  MLIRContext *ctx = builder.getContext();
2485  // TODO: Store clauses in op: allocateVars, allocatorVars, privateVars,
2486  // privateSyms.
2487  WsloopOp::build(builder, state,
2488  /*allocate_vars=*/{}, /*allocator_vars=*/{},
2489  clauses.linearVars, clauses.linearStepVars, clauses.nowait,
2490  clauses.order, clauses.orderMod, clauses.ordered,
2491  clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
2492  clauses.reductionMod, clauses.reductionVars,
2493  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2494  makeArrayAttr(ctx, clauses.reductionSyms),
2495  clauses.scheduleKind, clauses.scheduleChunk,
2496  clauses.scheduleMod, clauses.scheduleSimd);
2497 }
2498 
2499 LogicalResult WsloopOp::verify() {
2500  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2501  getReductionByref());
2502 }
2503 
2504 LogicalResult WsloopOp::verifyRegions() {
2505  bool isCompositeChildLeaf =
2506  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2507 
2508  if (LoopWrapperInterface nested = getNestedWrapper()) {
2509  if (!isComposite())
2510  return emitError()
2511  << "'omp.composite' attribute missing from composite wrapper";
2512 
2513  // Check for the allowed leaf constructs that may appear in a composite
2514  // construct directly after DO/FOR.
2515  if (!isa<SimdOp>(nested))
2516  return emitError() << "only supported nested wrapper is 'omp.simd'";
2517 
2518  } else if (isComposite() && !isCompositeChildLeaf) {
2519  return emitError()
2520  << "'omp.composite' attribute present in non-composite wrapper";
2521  } else if (!isComposite() && isCompositeChildLeaf) {
2522  return emitError()
2523  << "'omp.composite' attribute missing from composite wrapper";
2524  }
2525 
2526  return success();
2527 }
2528 
2529 //===----------------------------------------------------------------------===//
2530 // Simd construct [2.9.3.1]
2531 //===----------------------------------------------------------------------===//
2532 
2533 void SimdOp::build(OpBuilder &builder, OperationState &state,
2534  const SimdOperands &clauses) {
2535  MLIRContext *ctx = builder.getContext();
2536  // TODO Store clauses in op: linearVars, linearStepVars, privateVars,
2537  // privateSyms.
2538  SimdOp::build(builder, state, clauses.alignedVars,
2539  makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
2540  /*linear_vars=*/{}, /*linear_step_vars=*/{},
2541  clauses.nontemporalVars, clauses.order, clauses.orderMod,
2542  clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
2543  clauses.reductionMod, clauses.reductionVars,
2544  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2545  makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
2546  clauses.simdlen);
2547 }
2548 
2549 LogicalResult SimdOp::verify() {
2550  if (getSimdlen().has_value() && getSafelen().has_value() &&
2551  getSimdlen().value() > getSafelen().value())
2552  return emitOpError()
2553  << "simdlen clause and safelen clause are both present, but the "
2554  "simdlen value is not less than or equal to safelen value";
2555 
2556  if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
2557  return failure();
2558 
2559  if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
2560  return failure();
2561 
2562  bool isCompositeChildLeaf =
2563  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2564 
2565  if (!isComposite() && isCompositeChildLeaf)
2566  return emitError()
2567  << "'omp.composite' attribute missing from composite wrapper";
2568 
2569  if (isComposite() && !isCompositeChildLeaf)
2570  return emitError()
2571  << "'omp.composite' attribute present in non-composite wrapper";
2572 
2573  return success();
2574 }
2575 
2576 LogicalResult SimdOp::verifyRegions() {
2577  if (getNestedWrapper())
2578  return emitOpError() << "must wrap an 'omp.loop_nest' directly";
2579 
2580  return success();
2581 }
2582 
2583 //===----------------------------------------------------------------------===//
2584 // Distribute construct [2.9.4.1]
2585 //===----------------------------------------------------------------------===//
2586 
2587 void DistributeOp::build(OpBuilder &builder, OperationState &state,
2588  const DistributeOperands &clauses) {
2589  DistributeOp::build(builder, state, clauses.allocateVars,
2590  clauses.allocatorVars, clauses.distScheduleStatic,
2591  clauses.distScheduleChunkSize, clauses.order,
2592  clauses.orderMod, clauses.privateVars,
2593  makeArrayAttr(builder.getContext(), clauses.privateSyms));
2594 }
2595 
2596 LogicalResult DistributeOp::verify() {
2597  if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2598  return emitOpError() << "chunk size set without "
2599  "dist_schedule_static being present";
2600 
2601  if (getAllocateVars().size() != getAllocatorVars().size())
2602  return emitError(
2603  "expected equal sizes for allocate and allocator variables");
2604 
2605  return success();
2606 }
2607 
2608 LogicalResult DistributeOp::verifyRegions() {
2609  if (LoopWrapperInterface nested = getNestedWrapper()) {
2610  if (!isComposite())
2611  return emitError()
2612  << "'omp.composite' attribute missing from composite wrapper";
2613  // Check for the allowed leaf constructs that may appear in a composite
2614  // construct directly after DISTRIBUTE.
2615  if (isa<WsloopOp>(nested)) {
2616  Operation *parentOp = (*this)->getParentOp();
2617  if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
2618  !cast<ComposableOpInterface>(parentOp).isComposite()) {
2619  return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
2620  "when a composite 'omp.parallel' is the direct "
2621  "parent";
2622  }
2623  } else if (!isa<SimdOp>(nested))
2624  return emitError() << "only supported nested wrappers are 'omp.simd' and "
2625  "'omp.wsloop'";
2626  } else if (isComposite()) {
2627  return emitError()
2628  << "'omp.composite' attribute present in non-composite wrapper";
2629  }
2630 
2631  return success();
2632 }
2633 
2634 //===----------------------------------------------------------------------===//
2635 // DeclareMapperOp / DeclareMapperInfoOp
2636 //===----------------------------------------------------------------------===//
2637 
2638 LogicalResult DeclareMapperInfoOp::verify() {
2639  return verifyMapClause(*this, getMapVars());
2640 }
2641 
2642 LogicalResult DeclareMapperOp::verifyRegions() {
2643  if (!llvm::isa_and_present<DeclareMapperInfoOp>(
2644  getRegion().getBlocks().front().getTerminator()))
2645  return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
2646 
2647  return success();
2648 }
2649 
2650 //===----------------------------------------------------------------------===//
2651 // DeclareReductionOp
2652 //===----------------------------------------------------------------------===//
2653 
2654 LogicalResult DeclareReductionOp::verifyRegions() {
2655  if (!getAllocRegion().empty()) {
2656  for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
2657  if (yieldOp.getResults().size() != 1 ||
2658  yieldOp.getResults().getTypes()[0] != getType())
2659  return emitOpError() << "expects alloc region to yield a value "
2660  "of the reduction type";
2661  }
2662  }
2663 
2664  if (getInitializerRegion().empty())
2665  return emitOpError() << "expects non-empty initializer region";
2666  Block &initializerEntryBlock = getInitializerRegion().front();
2667 
2668  if (initializerEntryBlock.getNumArguments() == 1) {
2669  if (!getAllocRegion().empty())
2670  return emitOpError() << "expects two arguments to the initializer region "
2671  "when an allocation region is used";
2672  } else if (initializerEntryBlock.getNumArguments() == 2) {
2673  if (getAllocRegion().empty())
2674  return emitOpError() << "expects one argument to the initializer region "
2675  "when no allocation region is used";
2676  } else {
2677  return emitOpError()
2678  << "expects one or two arguments to the initializer region";
2679  }
2680 
2681  for (mlir::Value arg : initializerEntryBlock.getArguments())
2682  if (arg.getType() != getType())
2683  return emitOpError() << "expects initializer region argument to match "
2684  "the reduction type";
2685 
2686  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
2687  if (yieldOp.getResults().size() != 1 ||
2688  yieldOp.getResults().getTypes()[0] != getType())
2689  return emitOpError() << "expects initializer region to yield a value "
2690  "of the reduction type";
2691  }
2692 
2693  if (getReductionRegion().empty())
2694  return emitOpError() << "expects non-empty reduction region";
2695  Block &reductionEntryBlock = getReductionRegion().front();
2696  if (reductionEntryBlock.getNumArguments() != 2 ||
2697  reductionEntryBlock.getArgumentTypes()[0] !=
2698  reductionEntryBlock.getArgumentTypes()[1] ||
2699  reductionEntryBlock.getArgumentTypes()[0] != getType())
2700  return emitOpError() << "expects reduction region with two arguments of "
2701  "the reduction type";
2702  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
2703  if (yieldOp.getResults().size() != 1 ||
2704  yieldOp.getResults().getTypes()[0] != getType())
2705  return emitOpError() << "expects reduction region to yield a value "
2706  "of the reduction type";
2707  }
2708 
2709  if (!getAtomicReductionRegion().empty()) {
2710  Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
2711  if (atomicReductionEntryBlock.getNumArguments() != 2 ||
2712  atomicReductionEntryBlock.getArgumentTypes()[0] !=
2713  atomicReductionEntryBlock.getArgumentTypes()[1])
2714  return emitOpError() << "expects atomic reduction region with two "
2715  "arguments of the same type";
2716  auto ptrType = llvm::dyn_cast<PointerLikeType>(
2717  atomicReductionEntryBlock.getArgumentTypes()[0]);
2718  if (!ptrType ||
2719  (ptrType.getElementType() && ptrType.getElementType() != getType()))
2720  return emitOpError() << "expects atomic reduction region arguments to "
2721  "be accumulators containing the reduction type";
2722  }
2723 
2724  if (getCleanupRegion().empty())
2725  return success();
2726  Block &cleanupEntryBlock = getCleanupRegion().front();
2727  if (cleanupEntryBlock.getNumArguments() != 1 ||
2728  cleanupEntryBlock.getArgument(0).getType() != getType())
2729  return emitOpError() << "expects cleanup region with one argument "
2730  "of the reduction type";
2731 
2732  return success();
2733 }
2734 
2735 //===----------------------------------------------------------------------===//
2736 // TaskOp
2737 //===----------------------------------------------------------------------===//
2738 
2739 void TaskOp::build(OpBuilder &builder, OperationState &state,
2740  const TaskOperands &clauses) {
2741  MLIRContext *ctx = builder.getContext();
2742  TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2743  makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2744  clauses.final, clauses.ifExpr, clauses.inReductionVars,
2745  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2746  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2747  clauses.priority, /*private_vars=*/clauses.privateVars,
2748  /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
2749  clauses.untied, clauses.eventHandle);
2750 }
2751 
2752 LogicalResult TaskOp::verify() {
2753  LogicalResult verifyDependVars =
2754  verifyDependVarList(*this, getDependKinds(), getDependVars());
2755  return failed(verifyDependVars)
2756  ? verifyDependVars
2757  : verifyReductionVarList(*this, getInReductionSyms(),
2758  getInReductionVars(),
2759  getInReductionByref());
2760 }
2761 
2762 //===----------------------------------------------------------------------===//
2763 // TaskgroupOp
2764 //===----------------------------------------------------------------------===//
2765 
2766 void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
2767  const TaskgroupOperands &clauses) {
2768  MLIRContext *ctx = builder.getContext();
2769  TaskgroupOp::build(builder, state, clauses.allocateVars,
2770  clauses.allocatorVars, clauses.taskReductionVars,
2771  makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
2772  makeArrayAttr(ctx, clauses.taskReductionSyms));
2773 }
2774 
2775 LogicalResult TaskgroupOp::verify() {
2776  return verifyReductionVarList(*this, getTaskReductionSyms(),
2777  getTaskReductionVars(),
2778  getTaskReductionByref());
2779 }
2780 
2781 //===----------------------------------------------------------------------===//
2782 // TaskloopOp
2783 //===----------------------------------------------------------------------===//
2784 
2785 void TaskloopOp::build(OpBuilder &builder, OperationState &state,
2786  const TaskloopOperands &clauses) {
2787  MLIRContext *ctx = builder.getContext();
2788  // TODO Store clauses in op: privateVars, privateSyms.
2789  TaskloopOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2790  clauses.final, clauses.grainsizeMod, clauses.grainsize,
2791  clauses.ifExpr, clauses.inReductionVars,
2792  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2793  makeArrayAttr(ctx, clauses.inReductionSyms),
2794  clauses.mergeable, clauses.nogroup, clauses.numTasksMod,
2795  clauses.numTasks, clauses.priority, /*private_vars=*/{},
2796  /*private_syms=*/nullptr, clauses.reductionMod,
2797  clauses.reductionVars,
2798  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2799  makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
2800 }
2801 
2802 LogicalResult TaskloopOp::verify() {
2803  if (getAllocateVars().size() != getAllocatorVars().size())
2804  return emitError(
2805  "expected equal sizes for allocate and allocator variables");
2806  if (failed(verifyReductionVarList(*this, getReductionSyms(),
2807  getReductionVars(), getReductionByref())) ||
2808  failed(verifyReductionVarList(*this, getInReductionSyms(),
2809  getInReductionVars(),
2810  getInReductionByref())))
2811  return failure();
2812 
2813  if (!getReductionVars().empty() && getNogroup())
2814  return emitError("if a reduction clause is present on the taskloop "
2815  "directive, the nogroup clause must not be specified");
2816  for (auto var : getReductionVars()) {
2817  if (llvm::is_contained(getInReductionVars(), var))
2818  return emitError("the same list item cannot appear in both a reduction "
2819  "and an in_reduction clause");
2820  }
2821 
2822  if (getGrainsize() && getNumTasks()) {
2823  return emitError(
2824  "the grainsize clause and num_tasks clause are mutually exclusive and "
2825  "may not appear on the same taskloop directive");
2826  }
2827 
2828  return success();
2829 }
2830 
2831 LogicalResult TaskloopOp::verifyRegions() {
2832  if (LoopWrapperInterface nested = getNestedWrapper()) {
2833  if (!isComposite())
2834  return emitError()
2835  << "'omp.composite' attribute missing from composite wrapper";
2836 
2837  // Check for the allowed leaf constructs that may appear in a composite
2838  // construct directly after TASKLOOP.
2839  if (!isa<SimdOp>(nested))
2840  return emitError() << "only supported nested wrapper is 'omp.simd'";
2841  } else if (isComposite()) {
2842  return emitError()
2843  << "'omp.composite' attribute present in non-composite wrapper";
2844  }
2845 
2846  return success();
2847 }
2848 
2849 //===----------------------------------------------------------------------===//
2850 // LoopNestOp
2851 //===----------------------------------------------------------------------===//
2852 
/// Parses the custom assembly form of `omp.loop_nest`:
///   `(` ivs `)` `:` type `=` `(` lbs `)` `to` `(` ubs `)` [`inclusive`]
///   `step` `(` steps `)` region attr-dict
/// Each bound/step list must contain exactly one operand per IV, and all IVs,
/// bounds and steps share the single type parsed after `:`. The IVs become the
/// entry-block arguments of the body region.
2853 ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
2854  // Parse an opening `(` followed by induction variables followed by `)`
// NOTE(review): the declarations of `ivs`, `lbs` and `ubs` are not visible
// here; presumably SmallVector<OpAsmParser::Argument> and
// SmallVector<OpAsmParser::UnresolvedOperand> — confirm.
2857  Type loopVarType;
2858  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
2859  parser.parseColonType(loopVarType) ||
2860  // Parse loop bounds.
2861  parser.parseEqual() ||
2862  parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
2863  parser.parseKeyword("to") ||
2864  parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
2865  return failure();
2866 
// All IVs take the one type written after `:`.
2867  for (auto &iv : ivs)
2868  iv.type = loopVarType;
2869 
2870  // Parse "inclusive" flag.
// Its presence is recorded as the `loop_inclusive` unit attribute.
2871  if (succeeded(parser.parseOptionalKeyword("inclusive")))
2872  result.addAttribute("loop_inclusive",
2873  UnitAttr::get(parser.getBuilder().getContext()));
2874 
2875  // Parse step values.
// NOTE(review): the declaration of `steps` is not visible here — confirm.
2877  if (parser.parseKeyword("step") ||
2878  parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
2879  return failure();
2880 
2881  // Parse the body.
// The parsed IVs are used as the body's entry-block arguments.
2882  Region *region = result.addRegion();
2883  if (parser.parseRegion(*region, ivs))
2884  return failure();
2885 
2886  // Resolve operands.
// Operands are appended in lower bounds, upper bounds, steps order.
2887  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
2888  parser.resolveOperands(ubs, loopVarType, result.operands) ||
2889  parser.resolveOperands(steps, loopVarType, result.operands))
2890  return failure();
2891 
2892  // Parse the optional attribute list.
2893  return parser.parseOptionalAttrDict(result.attributes);
2894 }
2895 
// Prints the custom assembly form read by LoopNestOp::parse. The IVs (the
// region's entry-block arguments) are printed once, typed by the first IV, so
// the region is printed without re-printing its entry-block arguments.
// Indexing args[0] relies on LoopNestOp::verify requiring at least one loop
// with one IV per lower bound.
2897  Region &region = getRegion();
2898  auto args = region.getArguments();
2899  p << " (" << args << ") : " << args[0].getType() << " = ("
2900  << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
2901  if (getLoopInclusive())
2902  p << "inclusive ";
2903  p << "step (" << getLoopSteps() << ") ";
2904  p.printRegion(region, /*printEntryBlockArgs=*/false);
2905 }
2906 
2907 void LoopNestOp::build(OpBuilder &builder, OperationState &state,
2908  const LoopNestOperands &clauses) {
2909  LoopNestOp::build(builder, state, clauses.loopLowerBounds,
2910  clauses.loopUpperBounds, clauses.loopSteps,
2911  clauses.loopInclusive);
2912 }
2913 
2914 LogicalResult LoopNestOp::verify() {
2915  if (getLoopLowerBounds().empty())
2916  return emitOpError() << "must represent at least one loop";
2917 
2918  if (getLoopLowerBounds().size() != getIVs().size())
2919  return emitOpError() << "number of range arguments and IVs do not match";
2920 
2921  for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
2922  if (lb.getType() != iv.getType())
2923  return emitOpError()
2924  << "range argument type does not match corresponding IV type";
2925  }
2926 
2927  if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
2928  return emitOpError() << "expects parent op to be a loop wrapper";
2929 
2930  return success();
2931 }
2932 
/// Appends to `wrappers` each consecutive enclosing loop wrapper of this loop
/// nest, starting with the direct parent (innermost first). The walk stops at
/// the first ancestor that is not a LoopWrapperInterface, or when no parent
/// remains.
2933 void LoopNestOp::gatherWrappers(
2935  Operation *parent = (*this)->getParentOp();
2936  while (auto wrapper =
2937  llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
2938  wrappers.push_back(wrapper);
2939  parent = parent->getParentOp();
2940  }
2941 }
2942 
2943 //===----------------------------------------------------------------------===//
2944 // Critical construct (2.17.1)
2945 //===----------------------------------------------------------------------===//
2946 
2947 void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
2948  const CriticalDeclareOperands &clauses) {
2949  CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
2950 }
2951 
2952 LogicalResult CriticalDeclareOp::verify() {
2953  return verifySynchronizationHint(*this, getHint());
2954 }
2955 
2956 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2957  if (getNameAttr()) {
2958  SymbolRefAttr symbolRef = getNameAttr();
2959  auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
2960  *this, symbolRef);
2961  if (!decl) {
2962  return emitOpError() << "expected symbol reference " << symbolRef
2963  << " to point to a critical declaration";
2964  }
2965  }
2966 
2967  return success();
2968 }
2969 
2970 //===----------------------------------------------------------------------===//
2971 // Ordered construct
2972 //===----------------------------------------------------------------------===//
2973 
2974 static LogicalResult verifyOrderedParent(Operation &op) {
2975  bool hasRegion = op.getNumRegions() > 0;
2976  auto loopOp = op.getParentOfType<LoopNestOp>();
2977  if (!loopOp) {
2978  if (hasRegion)
2979  return success();
2980 
2981  // TODO: Consider if this needs to be the case only for the standalone
2982  // variant of the ordered construct.
2983  return op.emitOpError() << "must be nested inside of a loop";
2984  }
2985 
2986  Operation *wrapper = loopOp->getParentOp();
2987  if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
2988  IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
2989  if (!orderedAttr)
2990  return op.emitOpError() << "the enclosing worksharing-loop region must "
2991  "have an ordered clause";
2992 
2993  if (hasRegion && orderedAttr.getInt() != 0)
2994  return op.emitOpError() << "the enclosing loop's ordered clause must not "
2995  "have a parameter present";
2996 
2997  if (!hasRegion && orderedAttr.getInt() == 0)
2998  return op.emitOpError() << "the enclosing loop's ordered clause must "
2999  "have a parameter present";
3000  } else if (!isa<SimdOp>(wrapper)) {
3001  return op.emitOpError() << "must be nested inside of a worksharing, simd "
3002  "or worksharing simd loop";
3003  }
3004  return success();
3005 }
3006 
3007 void OrderedOp::build(OpBuilder &builder, OperationState &state,
3008  const OrderedOperands &clauses) {
3009  OrderedOp::build(builder, state, clauses.doacrossDependType,
3010  clauses.doacrossNumLoops, clauses.doacrossDependVars);
3011 }
3012 
3013 LogicalResult OrderedOp::verify() {
3014  if (failed(verifyOrderedParent(**this)))
3015  return failure();
3016 
3017  auto wrapper = (*this)->getParentOfType<WsloopOp>();
3018  if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
3019  return emitOpError() << "number of variables in depend clause does not "
3020  << "match number of iteration variables in the "
3021  << "doacross loop";
3022 
3023  return success();
3024 }
3025 
3026 void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
3027  const OrderedRegionOperands &clauses) {
3028  OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
3029 }
3030 
3031 LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
3032 
3033 //===----------------------------------------------------------------------===//
3034 // TaskwaitOp
3035 //===----------------------------------------------------------------------===//
3036 
3037 void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
3038  const TaskwaitOperands &clauses) {
3039  // TODO Store clauses in op: dependKinds, dependVars, nowait.
3040  TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
3041  /*depend_vars=*/{}, /*nowait=*/nullptr);
3042 }
3043 
3044 //===----------------------------------------------------------------------===//
3045 // Verifier for AtomicReadOp
3046 //===----------------------------------------------------------------------===//
3047 
3048 LogicalResult AtomicReadOp::verify() {
3049  if (verifyCommon().failed())
3050  return mlir::failure();
3051 
3052  if (auto mo = getMemoryOrder()) {
3053  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3054  *mo == ClauseMemoryOrderKind::Release) {
3055  return emitError(
3056  "memory-order must not be acq_rel or release for atomic reads");
3057  }
3058  }
3059  return verifySynchronizationHint(*this, getHint());
3060 }
3061 
3062 //===----------------------------------------------------------------------===//
3063 // Verifier for AtomicWriteOp
3064 //===----------------------------------------------------------------------===//
3065 
3066 LogicalResult AtomicWriteOp::verify() {
3067  if (verifyCommon().failed())
3068  return mlir::failure();
3069 
3070  if (auto mo = getMemoryOrder()) {
3071  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3072  *mo == ClauseMemoryOrderKind::Acquire) {
3073  return emitError(
3074  "memory-order must not be acq_rel or acquire for atomic writes");
3075  }
3076  }
3077  return verifySynchronizationHint(*this, getHint());
3078 }
3079 
3080 //===----------------------------------------------------------------------===//
3081 // Verifier for AtomicUpdateOp
3082 //===----------------------------------------------------------------------===//
3083 
3084 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3085  PatternRewriter &rewriter) {
3086  if (op.isNoOp()) {
3087  rewriter.eraseOp(op);
3088  return success();
3089  }
3090  if (Value writeVal = op.getWriteOpVal()) {
3091  rewriter.replaceOpWithNewOp<AtomicWriteOp>(
3092  op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
3093  return success();
3094  }
3095  return failure();
3096 }
3097 
3098 LogicalResult AtomicUpdateOp::verify() {
3099  if (verifyCommon().failed())
3100  return mlir::failure();
3101 
3102  if (auto mo = getMemoryOrder()) {
3103  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3104  *mo == ClauseMemoryOrderKind::Acquire) {
3105  return emitError(
3106  "memory-order must not be acq_rel or acquire for atomic updates");
3107  }
3108  }
3109 
3110  return verifySynchronizationHint(*this, getHint());
3111 }
3112 
3113 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3114 
3115 //===----------------------------------------------------------------------===//
3116 // Verifier for AtomicCaptureOp
3117 //===----------------------------------------------------------------------===//
3118 
3119 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3120  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3121  return op;
3122  return dyn_cast<AtomicReadOp>(getSecondOp());
3123 }
3124 
3125 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3126  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3127  return op;
3128  return dyn_cast<AtomicWriteOp>(getSecondOp());
3129 }
3130 
3131 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3132  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3133  return op;
3134  return dyn_cast<AtomicUpdateOp>(getSecondOp());
3135 }
3136 
3137 LogicalResult AtomicCaptureOp::verify() {
3138  return verifySynchronizationHint(*this, getHint());
3139 }
3140 
3141 LogicalResult AtomicCaptureOp::verifyRegions() {
3142  if (verifyRegionsCommon().failed())
3143  return mlir::failure();
3144 
3145  if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
3146  return emitOpError(
3147  "operations inside capture region must not have hint clause");
3148 
3149  if (getFirstOp()->getAttr("memory_order") ||
3150  getSecondOp()->getAttr("memory_order"))
3151  return emitOpError(
3152  "operations inside capture region must not have memory_order clause");
3153  return success();
3154 }
3155 
3156 //===----------------------------------------------------------------------===//
3157 // CancelOp
3158 //===----------------------------------------------------------------------===//
3159 
3160 void CancelOp::build(OpBuilder &builder, OperationState &state,
3161  const CancelOperands &clauses) {
3162  CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
3163 }
3164 
// Walks up from `thisOp` and returns the closest ancestor whose dialect is the
// same as `thisOp`'s, or nullptr if no ancestor qualifies. The cancel and
// cancellation-point verifiers use it to find the enclosing OpenMP construct.
3166  Operation *parent = thisOp->getParentOp();
3167  while (parent) {
3168  if (parent->getDialect() == thisOp->getDialect())
3169  return parent;
3170  parent = parent->getParentOp();
3171  }
3172  return nullptr;
3173 }
3174 
3175 LogicalResult CancelOp::verify() {
3176  ClauseCancellationConstructType cct = getCancelDirective();
3177  // The next OpenMP operation in the chain of parents
3178  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
3179  if (!structuralParent)
3180  return emitOpError() << "Orphaned cancel construct";
3181 
3182  if ((cct == ClauseCancellationConstructType::Parallel) &&
3183  !mlir::isa<ParallelOp>(structuralParent)) {
3184  return emitOpError() << "cancel parallel must appear "
3185  << "inside a parallel region";
3186  }
3187  if (cct == ClauseCancellationConstructType::Loop) {
3188  // structural parent will be omp.loop_nest, directly nested inside
3189  // omp.wsloop
3190  auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
3191 
3192  if (!wsloopOp) {
3193  return emitOpError()
3194  << "cancel loop must appear inside a worksharing-loop region";
3195  }
3196  if (wsloopOp.getNowaitAttr()) {
3197  return emitError() << "A worksharing construct that is canceled "
3198  << "must not have a nowait clause";
3199  }
3200  if (wsloopOp.getOrderedAttr()) {
3201  return emitError() << "A worksharing construct that is canceled "
3202  << "must not have an ordered clause";
3203  }
3204 
3205  } else if (cct == ClauseCancellationConstructType::Sections) {
3206  // structural parent will be an omp.section, directly nested inside
3207  // omp.sections
3208  auto sectionsOp =
3209  mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
3210  if (!sectionsOp) {
3211  return emitOpError() << "cancel sections must appear "
3212  << "inside a sections region";
3213  }
3214  if (sectionsOp.getNowait()) {
3215  return emitError() << "A sections construct that is canceled "
3216  << "must not have a nowait clause";
3217  }
3218  }
3219  // TODO : Add more when we support taskgroup.
3220  return success();
3221 }
3222 
3223 //===----------------------------------------------------------------------===//
3224 // CancellationPointOp
3225 //===----------------------------------------------------------------------===//
3226 
3227 void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
3228  const CancellationPointOperands &clauses) {
3229  CancellationPointOp::build(builder, state, clauses.cancelDirective);
3230 }
3231 
3232 LogicalResult CancellationPointOp::verify() {
3233  ClauseCancellationConstructType cct = getCancelDirective();
3234  // The next OpenMP operation in the chain of parents
3235  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
3236  if (!structuralParent)
3237  return emitOpError() << "Orphaned cancellation point";
3238 
3239  if ((cct == ClauseCancellationConstructType::Parallel) &&
3240  !mlir::isa<ParallelOp>(structuralParent)) {
3241  return emitOpError() << "cancellation point parallel must appear "
3242  << "inside a parallel region";
3243  }
3244  // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
3245  // find the wsloop
3246  if ((cct == ClauseCancellationConstructType::Loop) &&
3247  !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
3248  return emitOpError() << "cancellation point loop must appear "
3249  << "inside a worksharing-loop region";
3250  }
3251  if ((cct == ClauseCancellationConstructType::Sections) &&
3252  !mlir::isa<omp::SectionOp>(structuralParent)) {
3253  return emitOpError() << "cancellation point sections must appear "
3254  << "inside a sections region";
3255  }
3256  // TODO : Add more when we support taskgroup.
3257  return success();
3258 }
3259 
3260 //===----------------------------------------------------------------------===//
3261 // MapBoundsOp
3262 //===----------------------------------------------------------------------===//
3263 
3264 LogicalResult MapBoundsOp::verify() {
3265  auto extent = getExtent();
3266  auto upperbound = getUpperBound();
3267  if (!extent && !upperbound)
3268  return emitError("expected extent or upperbound.");
3269  return success();
3270 }
3271 
/// Convenience builder that ignores the result types and forwards to the
/// generated builder with the `private` data-sharing type (not
/// `firstprivate`).
3272 void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3273  TypeRange /*result_types*/, StringAttr symName,
3274  TypeAttr type) {
3275  PrivateClauseOp::build(
3276  odsBuilder, odsState, symName, type,
// NOTE(review): the expression wrapping DataSharingClauseType::Private into
// the data-sharing attribute is not visible here — confirm.
3278  DataSharingClauseType::Private));
3279 }
3280 
3281 LogicalResult PrivateClauseOp::verifyRegions() {
3282  Type argType = getArgType();
3283  auto verifyTerminator = [&](Operation *terminator,
3284  bool yieldsValue) -> LogicalResult {
3285  if (!terminator->getBlock()->getSuccessors().empty())
3286  return success();
3287 
3288  if (!llvm::isa<YieldOp>(terminator))
3289  return mlir::emitError(terminator->getLoc())
3290  << "expected exit block terminator to be an `omp.yield` op.";
3291 
3292  YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
3293  TypeRange yieldedTypes = yieldOp.getResults().getTypes();
3294 
3295  if (!yieldsValue) {
3296  if (yieldedTypes.empty())
3297  return success();
3298 
3299  return mlir::emitError(terminator->getLoc())
3300  << "Did not expect any values to be yielded.";
3301  }
3302 
3303  if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
3304  return success();
3305 
3306  auto error = mlir::emitError(yieldOp.getLoc())
3307  << "Invalid yielded value. Expected type: " << argType
3308  << ", got: ";
3309 
3310  if (yieldedTypes.empty())
3311  error << "None";
3312  else
3313  error << yieldedTypes;
3314 
3315  return error;
3316  };
3317 
3318  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
3319  StringRef regionName,
3320  bool yieldsValue) -> LogicalResult {
3321  assert(!region.empty());
3322 
3323  if (region.getNumArguments() != expectedNumArgs)
3324  return mlir::emitError(region.getLoc())
3325  << "`" << regionName << "`: "
3326  << "expected " << expectedNumArgs
3327  << " region arguments, got: " << region.getNumArguments();
3328 
3329  for (Block &block : region) {
3330  // MLIR will verify the absence of the terminator for us.
3331  if (!block.mightHaveTerminator())
3332  continue;
3333 
3334  if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
3335  return failure();
3336  }
3337 
3338  return success();
3339  };
3340 
3341  // Ensure all of the region arguments have the same type
3342  for (Region *region : getRegions())
3343  for (Type ty : region->getArgumentTypes())
3344  if (ty != argType)
3345  return emitError() << "Region argument type mismatch: got " << ty
3346  << " expected " << argType << ".";
3347 
3348  mlir::Region &initRegion = getInitRegion();
3349  if (!initRegion.empty() &&
3350  failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
3351  /*yieldsValue=*/true)))
3352  return failure();
3353 
3354  DataSharingClauseType dsType = getDataSharingType();
3355 
3356  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
3357  return emitError("`private` clauses do not require a `copy` region.");
3358 
3359  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
3360  return emitError(
3361  "`firstprivate` clauses require at least a `copy` region.");
3362 
3363  if (dsType == DataSharingClauseType::FirstPrivate &&
3364  failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
3365  /*yieldsValue=*/true)))
3366  return failure();
3367 
3368  if (!getDeallocRegion().empty() &&
3369  failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
3370  /*yieldsValue=*/false)))
3371  return failure();
3372 
3373  return success();
3374 }
3375 
3376 //===----------------------------------------------------------------------===//
3377 // Spec 5.2: Masked construct (10.5)
3378 //===----------------------------------------------------------------------===//
3379 
3380 void MaskedOp::build(OpBuilder &builder, OperationState &state,
3381  const MaskedOperands &clauses) {
3382  MaskedOp::build(builder, state, clauses.filteredThreadId);
3383 }
3384 
3385 //===----------------------------------------------------------------------===//
3386 // Spec 5.2: Scan construct (5.6)
3387 //===----------------------------------------------------------------------===//
3388 
3389 void ScanOp::build(OpBuilder &builder, OperationState &state,
3390  const ScanOperands &clauses) {
3391  ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
3392 }
3393 
3394 LogicalResult ScanOp::verify() {
3395  if (hasExclusiveVars() == hasInclusiveVars())
3396  return emitError(
3397  "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
3398  if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
3399  if (parentWsLoopOp.getReductionModAttr() &&
3400  parentWsLoopOp.getReductionModAttr().getValue() ==
3401  ReductionModifier::inscan)
3402  return success();
3403  }
3404  if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
3405  if (parentSimdOp.getReductionModAttr() &&
3406  parentSimdOp.getReductionModAttr().getValue() ==
3407  ReductionModifier::inscan)
3408  return success();
3409  }
3410  return emitError("SCAN directive needs to be enclosed within a parent "
3411  "worksharing loop construct or SIMD construct with INSCAN "
3412  "reduction modifier");
3413 }
3414 
3415 #define GET_ATTRDEF_CLASSES
3416 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
3417 
3418 #define GET_OP_CLASSES
3419 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
3420 
3421 #define GET_TYPEDEF_CLASSES
3422 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:736
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1286
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:63
static MLIRContext * getContext(OpFoldResult val)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, DenseI64ArrayAttr privateMaps)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr)
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static Operation * getParentInSameDialect(Operation *thisOp)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, DenseI64ArrayAttr &privateMaps)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr)
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseMembersIndex(OpAsmParser &parser, ArrayAttr &membersIdx)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:187
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
SuccessorRange getSuccessors()
Definition: Block.h:267
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
MLIRContext * getContext() const
Definition: Builders.h:56
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:205
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:772
This class indicates that the regions associated with this op don't have terminators.
Definition: OpDefinition.h:768
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:874
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:753
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
BlockArgListType getArguments()
Definition: Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:504
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Definition: Dominance.cpp:307
Runtime
Potential runtimes for AMD GPU kernels.
Definition: Runtimes.h:15
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:424
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
Definition: OpDefinition.h:880
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.