MLIR  22.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Attributes.h"
23 #include "mlir/IR/SymbolTable.h"
25 
26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/PostOrderIterator.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/STLForwardCompat.h"
30 #include "llvm/ADT/SmallString.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/StringRef.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/ADT/bit.h"
35 #include "llvm/Frontend/OpenMP/OMPConstants.h"
36 #include <cstddef>
37 #include <iterator>
38 #include <optional>
39 #include <variant>
40 
41 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
42 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
43 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
44 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
45 
46 using namespace mlir;
47 using namespace mlir::omp;
48 
/// Returns an ArrayAttr holding `attrs`, or a null ArrayAttr when `attrs` is
/// empty.
49 static ArrayAttr makeArrayAttr(MLIRContext *context,
51  return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
52 }
53 
/// Returns a DenseBoolArrayAttr holding `boolArray`, or a null attribute when
/// `boolArray` is empty.
54 static DenseBoolArrayAttr
56  return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
57 }
58 
/// Returns a DenseI64ArrayAttr holding `intArray`, or a null attribute when
/// `intArray` is empty.
59 static DenseI64ArrayAttr
61  return intArray.empty() ? nullptr : DenseI64ArrayAttr::get(ctx, intArray);
62 }
63 
64 namespace {
65 struct MemRefPointerLikeModel
66  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
67  MemRefType> {
68  Type getElementType(Type pointer) const {
69  return llvm::cast<MemRefType>(pointer).getElementType();
70  }
71 };
72 
73 struct LLVMPointerPointerLikeModel
74  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
75  LLVM::LLVMPointerType> {
76  Type getElementType(Type pointer) const { return Type(); }
77 };
78 } // namespace
79 
/// Registers the dialect's operations, attributes and types, declares the
/// promised ConvertToLLVMPatternInterface, and attaches external interface
/// models to builtin, LLVM-dialect and func-dialect types/operations.
80 void OpenMPDialect::initialize() {
81  addOperations<
82 #define GET_OP_LIST
83 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
84  >();
85  addAttributes<
86 #define GET_ATTRDEF_LIST
87 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
88  >();
89  addTypes<
90 #define GET_TYPEDEF_LIST
91 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
92  >();
93 
94  declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
95 
  // Let memref and LLVM pointer types be used where the OpenMP dialect
  // expects a PointerLikeType.
96  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
97  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
98  *getContext());
99 
100  // Attach the default offload module interface to the module op so that
101  // offload functionality can be accessed through it.
102  mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
103  *getContext());
104 
105  // Attach default declare target interfaces to operations which can be marked
106  // as declare target (Global Operations and Functions/Subroutines in dialects
107  // that Fortran (or other languages that lower to MLIR) translates to).
108  mlir::LLVM::GlobalOp::attachInterface<
110  *getContext());
111  mlir::LLVM::LLVMFuncOp::attachInterface<
113  *getContext());
114  mlir::func::FuncOp::attachInterface<
116 }
117 
118 //===----------------------------------------------------------------------===//
119 // Parser and printer for Allocate Clause
120 //===----------------------------------------------------------------------===//
121 
122 /// Parse an allocate clause with allocators and a list of operands with types.
123 ///
124 /// allocate-operand-list ::= allocate-operand |
125 ///                           allocate-operand `,` allocate-operand-list
126 /// allocate-operand ::= ssa-id-and-type `->` ssa-id-and-type
127 /// ssa-id-and-type ::= ssa-id `:` type
128 static ParseResult parseAllocateAndAllocator(
129  OpAsmParser &parser,
131  SmallVectorImpl<Type> &allocateTypes,
133  SmallVectorImpl<Type> &allocatorTypes) {
134 
  // Each list entry is `allocator : type -> var : type`. The pair before the
  // arrow goes to the allocator lists, the pair after it to the allocate
  // lists, so both list pairs grow in lockstep.
135  return parser.parseCommaSeparatedList([&]() {
137  Type type;
138  if (parser.parseOperand(operand) || parser.parseColonType(type))
139  return failure();
140  allocatorVars.push_back(operand);
141  allocatorTypes.push_back(type);
142  if (parser.parseArrow())
143  return failure();
  // `operand` and `type` are reused for the allocated variable.
144  if (parser.parseOperand(operand) || parser.parseColonType(type))
145  return failure();
146 
147  allocateVars.push_back(operand);
148  allocateTypes.push_back(type);
149  return success();
150  });
151 }
152 
153 /// Print allocate clause
155  OperandRange allocateVars,
156  TypeRange allocateTypes,
157  OperandRange allocatorVars,
158  TypeRange allocatorTypes) {
  // Prints each entry as `allocator : type -> var : type`, separated by ", "
  // (the form accepted by parseAllocateAndAllocator).
  // NOTE(review): allocatorVars/allocatorTypes are indexed with allocateVars'
  // indices, so they presumably have at least as many entries; this is not
  // checked here.
159  for (unsigned i = 0; i < allocateVars.size(); ++i) {
160  std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
161  p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
162  p << allocateVars[i] << " : " << allocateTypes[i] << separator;
163  }
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // Parser and printer for a clause attribute (StringEnumAttr)
168 //===----------------------------------------------------------------------===//
169 
170 template <typename ClauseAttr>
171 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
172  using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
173  StringRef enumStr;
174  SMLoc loc = parser.getCurrentLocation();
175  if (parser.parseKeyword(&enumStr))
176  return failure();
177  if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
178  attr = ClauseAttr::get(parser.getContext(), *enumValue);
179  return success();
180  }
181  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
182 }
183 
184 template <typename ClauseAttr>
185 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
186  p << stringifyEnum(attr.getValue());
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // Parser and printer for Linear Clause
191 //===----------------------------------------------------------------------===//
192 
193 /// linear ::= `linear` `(` linear-list `)`
194 /// linear-list := linear-val | linear-val `,` linear-list
195 /// linear-val := ssa-id `=` ssa-id `:` type
196 static ParseResult parseLinearClause(
197  OpAsmParser &parser,
199  SmallVectorImpl<Type> &linearTypes,
  // Each entry is `var = step : type`. The single type is recorded for `var`;
  // the step operand gets no type of its own here.
201  return parser.parseCommaSeparatedList([&]() {
203  Type type;
205  if (parser.parseOperand(var) || parser.parseEqual() ||
206  parser.parseOperand(stepVar) || parser.parseColonType(type))
207  return failure();
208 
209  linearVars.push_back(var);
210  linearTypes.push_back(type);
211  linearStepVars.push_back(stepVar);
212  return success();
213  });
214 }
215 
216 /// Print Linear Clause
218  ValueRange linearVars, TypeRange linearTypes,
219  ValueRange linearStepVars) {
  // Prints `var = step : type` per entry, separated by ", ". The `= step` part
  // is omitted for entries beyond the end of linearStepVars.
  // NOTE(review): the printed type comes from linearVars[i].getType(), and
  // `linearTypes` is not read in this function.
220  size_t linearVarsSize = linearVars.size();
221  for (unsigned i = 0; i < linearVarsSize; ++i) {
222  std::string separator = i == linearVarsSize - 1 ? "" : ", ";
223  p << linearVars[i];
224  if (linearStepVars.size() > i)
225  p << " = " << linearStepVars[i];
226  p << " : " << linearVars[i].getType() << separator;
227  }
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // Verifier for Nontemporal Clause
232 //===----------------------------------------------------------------------===//
233 
234 static LogicalResult verifyNontemporalClause(Operation *op,
235  OperandRange nontemporalVars) {
236 
237  // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
238  DenseSet<Value> nontemporalItems;
239  for (const auto &it : nontemporalVars)
240  if (!nontemporalItems.insert(it).second)
241  return op->emitOpError() << "nontemporal variable used more than once";
242 
243  return success();
244 }
245 
246 //===----------------------------------------------------------------------===//
247 // Parser, verifier and printer for Aligned Clause
248 //===----------------------------------------------------------------------===//
249 static LogicalResult verifyAlignedClause(Operation *op,
250  std::optional<ArrayAttr> alignments,
251  OperandRange alignedVars) {
252  // Check if number of alignment values equals to number of aligned variables
253  if (!alignedVars.empty()) {
254  if (!alignments || alignments->size() != alignedVars.size())
255  return op->emitOpError()
256  << "expected as many alignment values as aligned variables";
257  } else {
258  if (alignments)
259  return op->emitOpError() << "unexpected alignment values attribute";
260  return success();
261  }
262 
263  // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
264  DenseSet<Value> alignedItems;
265  for (auto it : alignedVars)
266  if (!alignedItems.insert(it).second)
267  return op->emitOpError() << "aligned variable used more than once";
268 
269  if (!alignments)
270  return success();
271 
272  // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
273  for (unsigned i = 0; i < (*alignments).size(); ++i) {
274  if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
275  if (intAttr.getValue().sle(0))
276  return op->emitOpError() << "alignment should be greater than 0";
277  } else {
278  return op->emitOpError() << "expected integer alignment";
279  }
280  }
281 
282  return success();
283 }
284 
285 /// aligned ::= `aligned` `(` aligned-list `)`
286 /// aligned-list := aligned-val | aligned-val `,` aligned-list
287 /// aligned-val := ssa-id-and-type `->` alignment
288 static ParseResult
291  SmallVectorImpl<Type> &alignedTypes,
292  ArrayAttr &alignmentsAttr) {
  // Each entry is `var : type -> alignment`. Operands, types and alignment
  // attributes are appended to three parallel lists.
293  SmallVector<Attribute> alignmentVec;
294  if (failed(parser.parseCommaSeparatedList([&]() {
295  if (parser.parseOperand(alignedVars.emplace_back()) ||
296  parser.parseColonType(alignedTypes.emplace_back()) ||
297  parser.parseArrow() ||
298  parser.parseAttribute(alignmentVec.emplace_back())) {
299  return failure();
300  }
301  return success();
302  })))
303  return failure();
  // NOTE(review): this copy of alignmentVec looks redundant; alignmentVec
  // could be passed to ArrayAttr::get directly.
304  SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
305  alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
306  return success();
307 }
308 
309 /// Print Aligned Clause
311  ValueRange alignedVars, TypeRange alignedTypes,
312  std::optional<ArrayAttr> alignments) {
  // Prints `var : type -> alignment` per entry, separated by ", ". The type
  // printed is the value's own type.
  // `alignments` is only dereferenced inside the loop, i.e. when alignedVars
  // is non-empty; verifyAlignedClause requires the attribute in that case.
313  for (unsigned i = 0; i < alignedVars.size(); ++i) {
314  if (i != 0)
315  p << ", ";
316  p << alignedVars[i] << " : " << alignedVars[i].getType();
317  p << " -> " << (*alignments)[i];
318  }
319 }
320 
321 //===----------------------------------------------------------------------===//
322 // Parser, printer and verifier for Schedule Clause
323 //===----------------------------------------------------------------------===//
324 
/// Validates the schedule modifiers collected by parseScheduleClause: at most
/// two, each a valid ScheduleModifier, and with two modifiers the first must
/// not be `simd` while the second must be. A lone `simd` is normalized in
/// place to `none, simd`, so `modifiers` may be modified.
325 static ParseResult
327  SmallVectorImpl<SmallString<12>> &modifiers) {
328  if (modifiers.size() > 2)
329  return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
330  for (const auto &mod : modifiers) {
331  // Translate the string. If it has no value, then it was not a valid
332  // modifier!
333  auto symbol = symbolizeScheduleModifier(mod);
334  if (!symbol)
335  return parser.emitError(parser.getNameLoc())
336  << " unknown modifier type: " << mod;
337  }
338 
339  // If we have one modifier that is "simd", then stick a "none" modifier in
340  // index 0.
341  if (modifiers.size() == 1) {
342  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
343  modifiers.push_back(modifiers[0]);
344  modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
345  }
346  } else if (modifiers.size() == 2) {
347  // If there are two modifiers:
348  // First modifier should not be simd, second one should be simd
349  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
350  symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
351  return parser.emitError(parser.getNameLoc())
352  << " incorrect modifier order";
353  }
354  return success();
355 }
356 
357 /// schedule ::= `schedule` `(` sched-list `)`
358 /// sched-list ::= sched-val | sched-val sched-list |
359 /// sched-val `,` sched-modifier
360 /// sched-val ::= sched-with-chunk | sched-wo-chunk
361 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
362 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
363 /// sched-wo-chunk ::= `auto` | `runtime`
364 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
365 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
366 static ParseResult
367 parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
368  ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
369  std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
370  Type &chunkType) {
371  StringRef keyword;
372  if (parser.parseKeyword(&keyword))
373  return failure();
374  std::optional<mlir::omp::ClauseScheduleKind> schedule =
375  symbolizeClauseScheduleKind(keyword);
  // NOTE(review): errors in this function are reported at getNameLoc() (the
  // operation name), not at the offending keyword.
376  if (!schedule)
377  return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
378 
379  scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
380  switch (*schedule) {
  // These kinds accept an optional `= chunk : type` suffix.
381  case ClauseScheduleKind::Static:
382  case ClauseScheduleKind::Dynamic:
383  case ClauseScheduleKind::Guided:
384  if (succeeded(parser.parseOptionalEqual())) {
385  chunkSize = OpAsmParser::UnresolvedOperand{};
386  if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
387  return failure();
388  } else {
389  chunkSize = std::nullopt;
390  }
391  break;
  // No chunk size is parsed for the remaining kinds.
392  case ClauseScheduleKind::Auto:
394  chunkSize = std::nullopt;
395  }
396 
397  // If there is a comma, we have one or more modifiers.
398  SmallVector<SmallString<12>> modifiers;
399  while (succeeded(parser.parseOptionalComma())) {
400  StringRef mod;
401  if (parser.parseKeyword(&mod))
402  return failure();
403  modifiers.push_back(mod);
404  }
405 
  // Validates the modifiers and may rewrite a lone `simd` into `none, simd`.
406  if (verifyScheduleModifiers(parser, modifiers))
407  return failure();
408 
409  if (!modifiers.empty()) {
410  SMLoc loc = parser.getCurrentLocation();
411  if (std::optional<ScheduleModifier> mod =
412  symbolizeScheduleModifier(modifiers[0])) {
413  scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
414  } else {
415  return parser.emitError(loc, "invalid schedule modifier");
416  }
417  // Only SIMD attribute is allowed here!
  // verifyScheduleModifiers has already rejected any other second modifier.
418  if (modifiers.size() > 1) {
419  assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
420  scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
421  }
422  }
423 
424  return success();
425 }
426 
427 /// Print schedule clause
429  ClauseScheduleKindAttr scheduleKind,
430  ScheduleModifierAttr scheduleMod,
431  UnitAttr scheduleSimd, Value scheduleChunk,
432  Type scheduleChunkType) {
  // Prints `kind[ = chunk : type][, modifier][, simd]`. The chunk type printed
  // is the chunk value's own type; `scheduleChunkType` is not read here.
433  p << stringifyClauseScheduleKind(scheduleKind.getValue());
434  if (scheduleChunk)
435  p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
436  if (scheduleMod)
437  p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
438  if (scheduleSimd)
439  p << ", simd";
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // Parser and printer for Order Clause
444 //===----------------------------------------------------------------------===//
445 
446 // order ::= `order` `(` [order-modifier ':'] concurrent `)`
447 // order-modifier ::= reproducible | unconstrained
448 static ParseResult parseOrderClause(OpAsmParser &parser,
449  ClauseOrderKindAttr &order,
450  OrderModifierAttr &orderMod) {
451  StringRef enumStr;
452  SMLoc loc = parser.getCurrentLocation();
453  if (parser.parseKeyword(&enumStr))
454  return failure();
455  if (std::optional<OrderModifier> enumValue =
456  symbolizeOrderModifier(enumStr)) {
457  orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
458  if (parser.parseOptionalColon())
459  return failure();
460  loc = parser.getCurrentLocation();
461  if (parser.parseKeyword(&enumStr))
462  return failure();
463  }
464  if (std::optional<ClauseOrderKind> enumValue =
465  symbolizeClauseOrderKind(enumStr)) {
466  order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
467  return success();
468  }
469  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
470 }
471 
473  ClauseOrderKindAttr order,
474  OrderModifierAttr orderMod) {
  // Prints `[modifier:]kind`; each part is omitted when its attribute is null.
475  if (orderMod)
476  p << stringifyOrderModifier(orderMod.getValue()) << ":";
477  if (order)
478  p << stringifyClauseOrderKind(order.getValue());
479 }
480 
/// Shared parser for the grainsize and num_tasks clause bodies:
///   [modifier `,`] operand `:` type
/// A leading keyword must be recognized by `symbolizeClause` and is stored in
/// `prescriptiveness`; otherwise an "invalid <clauseName> modifier" error is
/// emitted. `clauseName` is used only in diagnostics.
481 template <typename ClauseTypeAttr, typename ClauseType>
482 static ParseResult
483 parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
484  std::optional<OpAsmParser::UnresolvedOperand> &operand,
485  Type &operandType,
486  std::optional<ClauseType> (*symbolizeClause)(StringRef),
487  StringRef clauseName) {
488  StringRef enumStr;
489  if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
490  if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
491  prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
492  if (parser.parseComma())
493  return failure();
494  } else {
  // NOTE(review): the location is taken after the keyword was consumed, and
  // the `;` on the following line is a stray empty statement.
495  return parser.emitError(parser.getCurrentLocation())
496  << "invalid " << clauseName << " modifier : '" << enumStr << "'";
497  ;
498  }
499  }
500 
  // The operand itself is mandatory.
  // NOTE(review): parseOperand presumably emits its own diagnostic on failure,
  // so this may report two errors.
502  if (succeeded(parser.parseOperand(var))) {
503  operand = var;
504  } else {
505  return parser.emitError(parser.getCurrentLocation())
506  << "expected " << clauseName << " operand";
507  }
508 
  // Always true at this point; every path that leaves `operand` empty has
  // already returned.
509  if (operand.has_value()) {
510  if (parser.parseColonType(operandType))
511  return failure();
512  }
513 
514  return success();
515 }
516 
/// Shared printer for the grainsize and num_tasks clause bodies. Prints
/// `[modifier, ]operand: type`; each part is omitted when its attribute or
/// value is null.
517 template <typename ClauseTypeAttr, typename ClauseType>
518 static void
520  ClauseTypeAttr prescriptiveness, Value operand,
521  mlir::Type operandType,
522  StringRef (*stringifyClauseType)(ClauseType)) {
523 
524  if (prescriptiveness)
525  p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
526 
527  if (operand)
528  p << operand << ": " << operandType;
529 }
530 
531 //===----------------------------------------------------------------------===//
532 // Parser and printer for grainsize Clause
533 //===----------------------------------------------------------------------===//
534 
535 // grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
536 static ParseResult
537 parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
538  std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
539  Type &grainsizeType) {
540  return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
541  parser, grainsizeMod, grainsize, grainsizeType,
542  &symbolizeClauseGrainsizeType, "grainsize");
543 }
544 
546  ClauseGrainsizeTypeAttr grainsizeMod,
547  Value grainsize, mlir::Type grainsizeType) {
  // Prints `[modifier, ]grainsize: type` via the shared granularity printer.
548  printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
549  p, op, grainsizeMod, grainsize, grainsizeType,
550  &stringifyClauseGrainsizeType);
551 }
552 
553 //===----------------------------------------------------------------------===//
554 // Parser and printer for num_tasks Clause
555 //===----------------------------------------------------------------------===//
556 
557 // numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
558 static ParseResult
559 parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
560  std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
561  Type &numTasksType) {
562  return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
563  parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
564  "num_tasks");
565 }
566 
568  ClauseNumTasksTypeAttr numTasksMod,
569  Value numTasks, mlir::Type numTasksType) {
  // Prints `[modifier, ]num_tasks: type` via the shared granularity printer.
570  printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
571  p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
572 }
573 
574 //===----------------------------------------------------------------------===//
575 // Parsers for operations including clauses that define entry block arguments.
576 //===----------------------------------------------------------------------===//
577 
578 namespace {
// Each *ParseArgs struct bundles references to the output storage that
// parseBlockArgClause fills for one clause kind.
// Operands and types only (map-like clauses).
579 struct MapParseArgs {
581  SmallVectorImpl<Type> &types;
583  SmallVectorImpl<Type> &types)
584  : vars(vars), types(types) {}
585 };
// Private clause outputs. `mapIndices` is null when the op's private clause
// does not accept `[map_idx = N]` annotations.
586 struct PrivateParseArgs {
589  ArrayAttr &syms;
590  UnitAttr &needsBarrier;
591  DenseI64ArrayAttr *mapIndices;
593  SmallVectorImpl<Type> &types, ArrayAttr &syms,
594  UnitAttr &needsBarrier,
595  DenseI64ArrayAttr *mapIndices = nullptr)
596  : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
597  mapIndices(mapIndices) {}
598 };
599 
// Reduction-like clause outputs. `modifier` is null when the clause does not
// accept a `mod: <modifier>,` prefix.
600 struct ReductionParseArgs {
602  SmallVectorImpl<Type> &types;
603  DenseBoolArrayAttr &byref;
604  ArrayAttr &syms;
605  ReductionModifierAttr *modifier;
606  ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
608  ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
609  : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
610 };
611 
// One optional slot per supported block-argument clause. A parse wrapper sets
// only the clauses its op accepts; parseBlockArgClause fails if a clause
// keyword appears whose slot is unset.
612 struct AllRegionParseArgs {
613  std::optional<MapParseArgs> hasDeviceAddrArgs;
614  std::optional<MapParseArgs> hostEvalArgs;
615  std::optional<ReductionParseArgs> inReductionArgs;
616  std::optional<MapParseArgs> mapArgs;
617  std::optional<PrivateParseArgs> privateArgs;
618  std::optional<ReductionParseArgs> reductionArgs;
619  std::optional<ReductionParseArgs> taskReductionArgs;
620  std::optional<MapParseArgs> useDeviceAddrArgs;
621  std::optional<MapParseArgs> useDevicePtrArgs;
622 };
623 } // namespace
624 
625 static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
626  return "private_barrier";
627 }
628 
/// Parses one parenthesized block-argument clause:
///   `(` [`mod` `:` modifier `,`] entry (`,` entry)* `:` type (`,` type)* `)`
/// where each entry is `[byref] [symbol-ref] %operand -> %arg [[map_idx = N]]`.
/// Operands go to `operands` and types to `types`. The new entry-block
/// arguments are appended to `regionPrivateArgs`, and their types are filled
/// in from `types`. Each optional pointer parameter enables the matching
/// syntax and receives its result. When `needsBarrier` is non-null, a trailing
/// keyword after `)` is checked to set it.
/// NOTE(review): several failure paths (unknown `mod` value, operand/type
/// count mismatch) return failure() without a diagnostic of their own; the
/// caller parseBlockArgRegion reports a generic "invalid `<clause>` format".
629 static ParseResult parseClauseWithRegionArgs(
630  OpAsmParser &parser,
632  SmallVectorImpl<Type> &types,
633  SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
634  ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
635  DenseBoolArrayAttr *byref = nullptr,
636  ReductionModifierAttr *modifier = nullptr,
637  UnitAttr *needsBarrier = nullptr) {
638  SmallVector<SymbolRefAttr> symbolVec;
639  SmallVector<int64_t> mapIndicesVec;
640  SmallVector<bool> isByRefVec;
  // Remember where this clause's region arguments start so that only they
  // receive types below.
641  unsigned regionArgOffset = regionPrivateArgs.size();
642 
643  if (parser.parseLParen())
644  return failure();
645 
646  if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
647  StringRef enumStr;
648  if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
649  parser.parseComma())
650  return failure();
651  std::optional<ReductionModifier> enumValue =
652  symbolizeReductionModifier(enumStr);
653  if (!enumValue.has_value())
654  return failure();
655  *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
  // NOTE(review): get() presumably never returns null, so this check looks
  // redundant.
656  if (!*modifier)
657  return failure();
658  }
659 
660  if (parser.parseCommaSeparatedList([&]() {
661  if (byref)
662  isByRefVec.push_back(
663  parser.parseOptionalKeyword("byref").succeeded());
664 
665  if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
666  return failure();
667 
668  if (parser.parseOperand(operands.emplace_back()) ||
669  parser.parseArrow() ||
670  parser.parseArgument(regionPrivateArgs.emplace_back()))
671  return failure();
672 
  // Entries without an explicit `[map_idx = N]` record -1.
673  if (mapIndices) {
674  if (parser.parseOptionalLSquare().succeeded()) {
675  if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
676  parser.parseInteger(mapIndicesVec.emplace_back()) ||
677  parser.parseRSquare())
678  return failure();
679  } else {
680  mapIndicesVec.push_back(-1);
681  }
682  }
683 
684  return success();
685  }))
686  return failure();
687 
688  if (parser.parseColon())
689  return failure();
690 
691  if (parser.parseCommaSeparatedList([&]() {
692  if (parser.parseType(types.emplace_back()))
693  return failure();
694 
695  return success();
696  }))
697  return failure();
698 
699  if (operands.size() != types.size())
700  return failure();
701 
702  if (parser.parseRParen())
703  return failure();
704 
705  if (needsBarrier) {
707  .succeeded())
708  *needsBarrier = mlir::UnitAttr::get(parser.getContext());
709  }
710 
  // Assign the parsed types to the block arguments added by this clause.
711  auto *argsBegin = regionPrivateArgs.begin();
712  MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
713  argsBegin + regionArgOffset + types.size());
714  for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
715  prv.type = type;
716  }
717 
718  if (symbols) {
719  SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
720  *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
721  }
722 
723  if (!mapIndicesVec.empty())
724  *mapIndices =
725  mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
726 
727  if (byref)
728  *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
729 
730  return success();
731 }
732 
/// If `keyword` is present, parses a map-like clause (operands and types only)
/// into `mapArgs` and appends its region arguments to `entryBlockArgs`. Fails
/// if the keyword appears but `mapArgs` is unset (the op does not accept this
/// clause). Succeeds without consuming input when the keyword is absent.
733 static ParseResult parseBlockArgClause(
734  OpAsmParser &parser,
736  StringRef keyword, std::optional<MapParseArgs> mapArgs) {
737  if (succeeded(parser.parseOptionalKeyword(keyword))) {
738  if (!mapArgs)
739  return failure();
740 
741  if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
742  entryBlockArgs)))
743  return failure();
744  }
745  return success();
746 }
747 
/// If `keyword` is present, parses a private clause (symbols, optional map
/// indices and the needs-barrier flag) into `privateArgs` and appends its
/// region arguments to `entryBlockArgs`. Fails if the keyword appears but
/// `privateArgs` is unset.
748 static ParseResult parseBlockArgClause(
749  OpAsmParser &parser,
751  StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
752  if (succeeded(parser.parseOptionalKeyword(keyword))) {
753  if (!privateArgs)
754  return failure();
755 
757  parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
758  &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
759  /*modifier=*/nullptr, &privateArgs->needsBarrier)))
760  return failure();
761  }
762  return success();
763 }
764 
/// If `keyword` is present, parses a reduction-like clause (symbols, byref
/// flags and, when `reductionArgs->modifier` is set, a `mod:` prefix) into
/// `reductionArgs` and appends its region arguments to `entryBlockArgs`.
/// Fails if the keyword appears but `reductionArgs` is unset.
765 static ParseResult parseBlockArgClause(
766  OpAsmParser &parser,
768  StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
769  if (succeeded(parser.parseOptionalKeyword(keyword))) {
770  if (!reductionArgs)
771  return failure();
773  parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
774  &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
775  reductionArgs->modifier)))
776  return failure();
777  }
778  return success();
779 }
780 
781 static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
782  AllRegionParseArgs args) {
784 
785  if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",
786  args.hasDeviceAddrArgs)))
787  return parser.emitError(parser.getCurrentLocation())
788  << "invalid `has_device_addr` format";
789 
790  if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
791  args.hostEvalArgs)))
792  return parser.emitError(parser.getCurrentLocation())
793  << "invalid `host_eval` format";
794 
795  if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
796  args.inReductionArgs)))
797  return parser.emitError(parser.getCurrentLocation())
798  << "invalid `in_reduction` format";
799 
800  if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
801  args.mapArgs)))
802  return parser.emitError(parser.getCurrentLocation())
803  << "invalid `map_entries` format";
804 
805  if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
806  args.privateArgs)))
807  return parser.emitError(parser.getCurrentLocation())
808  << "invalid `private` format";
809 
810  if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
811  args.reductionArgs)))
812  return parser.emitError(parser.getCurrentLocation())
813  << "invalid `reduction` format";
814 
815  if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
816  args.taskReductionArgs)))
817  return parser.emitError(parser.getCurrentLocation())
818  << "invalid `task_reduction` format";
819 
820  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
821  args.useDeviceAddrArgs)))
822  return parser.emitError(parser.getCurrentLocation())
823  << "invalid `use_device_addr` format";
824 
825  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
826  args.useDevicePtrArgs)))
827  return parser.emitError(parser.getCurrentLocation())
828  << "invalid `use_device_addr` format";
829 
830  return parser.parseRegion(region, entryBlockArgs);
831 }
832 
833 // These parseXyz functions correspond to the custom<Xyz> definitions
834 // in the .td file(s).
/// Parses a region preceded by the has_device_addr, host_eval, in_reduction,
/// map_entries and private (with map indices) block-argument clauses.
835 static ParseResult parseTargetOpRegion(
836  OpAsmParser &parser, Region &region,
838  SmallVectorImpl<Type> &hasDeviceAddrTypes,
840  SmallVectorImpl<Type> &hostEvalTypes,
842  SmallVectorImpl<Type> &inReductionTypes,
843  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
845  SmallVectorImpl<Type> &mapTypes,
847  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
848  UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
849  AllRegionParseArgs args;
850  args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
851  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
852  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
853  inReductionByref, inReductionSyms);
854  args.mapArgs.emplace(mapVars, mapTypes);
855  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
856  privateNeedsBarrier, &privateMaps);
857  return parseBlockArgRegion(parser, region, args);
858 }
859 
/// Parses a region preceded by the in_reduction and private block-argument
/// clauses.
860 static ParseResult parseInReductionPrivateRegion(
861  OpAsmParser &parser, Region &region,
863  SmallVectorImpl<Type> &inReductionTypes,
864  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
866  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
867  UnitAttr &privateNeedsBarrier) {
868  AllRegionParseArgs args;
869  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
870  inReductionByref, inReductionSyms);
871  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
872  privateNeedsBarrier);
873  return parseBlockArgRegion(parser, region, args);
874 }
875 
877  OpAsmParser &parser, Region &region,
879  SmallVectorImpl<Type> &inReductionTypes,
880  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
882  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
883  UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
885  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
886  ArrayAttr &reductionSyms) {
  // Accepts the in_reduction, private and reduction block-argument clauses;
  // only the reduction clause accepts a `mod:` prefix.
887  AllRegionParseArgs args;
888  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
889  inReductionByref, inReductionSyms);
890  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
891  privateNeedsBarrier);
892  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
893  reductionSyms, &reductionMod);
894  return parseBlockArgRegion(parser, region, args);
895 }
896 
/// Parses a region preceded only by the private block-argument clause.
897 static ParseResult parsePrivateRegion(
898  OpAsmParser &parser, Region &region,
900  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
901  UnitAttr &privateNeedsBarrier) {
902  AllRegionParseArgs args;
903  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
904  privateNeedsBarrier);
905  return parseBlockArgRegion(parser, region, args);
906 }
907 
/// Parses a region preceded by the private and reduction block-argument
/// clauses; the reduction clause accepts a `mod:` prefix.
908 static ParseResult parsePrivateReductionRegion(
909  OpAsmParser &parser, Region &region,
911  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
912  UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
914  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
915  ArrayAttr &reductionSyms) {
916  AllRegionParseArgs args;
917  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
918  privateNeedsBarrier);
919  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
920  reductionSyms, &reductionMod);
921  return parseBlockArgRegion(parser, region, args);
922 }
923 
924 static ParseResult parseTaskReductionRegion(
925  OpAsmParser &parser, Region &region,
927  SmallVectorImpl<Type> &taskReductionTypes,
928  DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
929  AllRegionParseArgs args;
930  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
931  taskReductionByref, taskReductionSyms);
932  return parseBlockArgRegion(parser, region, args);
933 }
934 
936  OpAsmParser &parser, Region &region,
938  SmallVectorImpl<Type> &useDeviceAddrTypes,
940  SmallVectorImpl<Type> &useDevicePtrTypes) {
941  AllRegionParseArgs args;
942  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
943  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
944  return parseBlockArgRegion(parser, region, args);
945 }
946 
947 //===----------------------------------------------------------------------===//
948 // Printers for operations including clauses that define entry block arguments.
949 //===----------------------------------------------------------------------===//
950 
951 namespace {
952 struct MapPrintArgs {
953  ValueRange vars;
954  TypeRange types;
955  MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
956 };
957 struct PrivatePrintArgs {
958  ValueRange vars;
959  TypeRange types;
960  ArrayAttr syms;
961  UnitAttr needsBarrier;
962  DenseI64ArrayAttr mapIndices;
963  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
964  UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
965  : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
966  mapIndices(mapIndices) {}
967 };
968 struct ReductionPrintArgs {
969  ValueRange vars;
970  TypeRange types;
971  DenseBoolArrayAttr byref;
972  ArrayAttr syms;
973  ReductionModifierAttr modifier;
974  ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
975  ArrayAttr syms, ReductionModifierAttr mod = nullptr)
976  : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
977 };
978 struct AllRegionPrintArgs {
979  std::optional<MapPrintArgs> hasDeviceAddrArgs;
980  std::optional<MapPrintArgs> hostEvalArgs;
981  std::optional<ReductionPrintArgs> inReductionArgs;
982  std::optional<MapPrintArgs> mapArgs;
983  std::optional<PrivatePrintArgs> privateArgs;
984  std::optional<ReductionPrintArgs> reductionArgs;
985  std::optional<ReductionPrintArgs> taskReductionArgs;
986  std::optional<MapPrintArgs> useDeviceAddrArgs;
987  std::optional<MapPrintArgs> useDevicePtrArgs;
988 };
989 } // namespace
990 
992  OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
993  ValueRange argsSubrange, ValueRange operands, TypeRange types,
994  ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
995  DenseBoolArrayAttr byref = nullptr,
996  ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
997  if (argsSubrange.empty())
998  return;
999 
1000  p << clauseName << "(";
1001 
1002  if (modifier)
1003  p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
1004 
1005  if (!symbols) {
1006  llvm::SmallVector<Attribute> values(operands.size(), nullptr);
1007  symbols = ArrayAttr::get(ctx, values);
1008  }
1009 
1010  if (!mapIndices) {
1011  llvm::SmallVector<int64_t> values(operands.size(), -1);
1012  mapIndices = DenseI64ArrayAttr::get(ctx, values);
1013  }
1014 
1015  if (!byref) {
1016  mlir::SmallVector<bool> values(operands.size(), false);
1017  byref = DenseBoolArrayAttr::get(ctx, values);
1018  }
1019 
1020  llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1021  mapIndices.asArrayRef(),
1022  byref.asArrayRef()),
1023  p, [&p](auto t) {
1024  auto [op, arg, sym, map, isByRef] = t;
1025  if (isByRef)
1026  p << "byref ";
1027  if (sym)
1028  p << sym << " ";
1029 
1030  p << op << " -> " << arg;
1031 
1032  if (map != -1)
1033  p << " [map_idx=" << map << "]";
1034  });
1035  p << " : ";
1036  llvm::interleaveComma(types, p);
1037  p << ") ";
1038 
1039  if (needsBarrier)
1040  p << getPrivateNeedsBarrierSpelling() << " ";
1041 }
1042 
1044  StringRef clauseName, ValueRange argsSubrange,
1045  std::optional<MapPrintArgs> mapArgs) {
1046  if (mapArgs)
1047  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
1048  mapArgs->types);
1049 }
1050 
1052  StringRef clauseName, ValueRange argsSubrange,
1053  std::optional<PrivatePrintArgs> privateArgs) {
1054  if (privateArgs)
1056  p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1057  privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1058  /*modifier=*/nullptr, privateArgs->needsBarrier);
1059 }
1060 
1061 static void
1062 printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1063  ValueRange argsSubrange,
1064  std::optional<ReductionPrintArgs> reductionArgs) {
1065  if (reductionArgs)
1066  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1067  reductionArgs->vars, reductionArgs->types,
1068  reductionArgs->syms, /*mapIndices=*/nullptr,
1069  reductionArgs->byref, reductionArgs->modifier);
1070 }
1071 
1072 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
1073  const AllRegionPrintArgs &args) {
1074  auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1075  MLIRContext *ctx = op->getContext();
1076 
1077  printBlockArgClause(p, ctx, "has_device_addr",
1078  iface.getHasDeviceAddrBlockArgs(),
1079  args.hasDeviceAddrArgs);
1080  printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
1081  args.hostEvalArgs);
1082  printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
1083  args.inReductionArgs);
1084  printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
1085  args.mapArgs);
1086  printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
1087  args.privateArgs);
1088  printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
1089  args.reductionArgs);
1090  printBlockArgClause(p, ctx, "task_reduction",
1091  iface.getTaskReductionBlockArgs(),
1092  args.taskReductionArgs);
1093  printBlockArgClause(p, ctx, "use_device_addr",
1094  iface.getUseDeviceAddrBlockArgs(),
1095  args.useDeviceAddrArgs);
1096  printBlockArgClause(p, ctx, "use_device_ptr",
1097  iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1098 
1099  p.printRegion(region, /*printEntryBlockArgs=*/false);
1100 }
1101 
1102 // These parseXyz functions correspond to the custom<Xyz> definitions
1103 // in the .td file(s).
1105  OpAsmPrinter &p, Operation *op, Region &region,
1106  ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
1107  ValueRange hostEvalVars, TypeRange hostEvalTypes,
1108  ValueRange inReductionVars, TypeRange inReductionTypes,
1109  DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
1110  ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
1111  TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1112  DenseI64ArrayAttr privateMaps) {
1113  AllRegionPrintArgs args;
1114  args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1115  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1116  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1117  inReductionByref, inReductionSyms);
1118  args.mapArgs.emplace(mapVars, mapTypes);
1119  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1120  privateNeedsBarrier, privateMaps);
1121  printBlockArgRegion(p, op, region, args);
1122 }
1123 
1125  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1126  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1127  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1128  ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1129  AllRegionPrintArgs args;
1130  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1131  inReductionByref, inReductionSyms);
1132  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1133  privateNeedsBarrier,
1134  /*mapIndices=*/nullptr);
1135  printBlockArgRegion(p, op, region, args);
1136 }
1137 
1139  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1140  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1141  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1142  ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1143  ReductionModifierAttr reductionMod, ValueRange reductionVars,
1144  TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1145  ArrayAttr reductionSyms) {
1146  AllRegionPrintArgs args;
1147  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1148  inReductionByref, inReductionSyms);
1149  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1150  privateNeedsBarrier,
1151  /*mapIndices=*/nullptr);
1152  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1153  reductionSyms, reductionMod);
1154  printBlockArgRegion(p, op, region, args);
1155 }
1156 
1157 static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
1158  ValueRange privateVars, TypeRange privateTypes,
1159  ArrayAttr privateSyms,
1160  UnitAttr privateNeedsBarrier) {
1161  AllRegionPrintArgs args;
1162  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1163  privateNeedsBarrier,
1164  /*mapIndices=*/nullptr);
1165  printBlockArgRegion(p, op, region, args);
1166 }
1167 
1169  OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1170  TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1171  ReductionModifierAttr reductionMod, ValueRange reductionVars,
1172  TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1173  ArrayAttr reductionSyms) {
1174  AllRegionPrintArgs args;
1175  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1176  privateNeedsBarrier,
1177  /*mapIndices=*/nullptr);
1178  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1179  reductionSyms, reductionMod);
1180  printBlockArgRegion(p, op, region, args);
1181 }
1182 
1184  Region &region,
1185  ValueRange taskReductionVars,
1186  TypeRange taskReductionTypes,
1187  DenseBoolArrayAttr taskReductionByref,
1188  ArrayAttr taskReductionSyms) {
1189  AllRegionPrintArgs args;
1190  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1191  taskReductionByref, taskReductionSyms);
1192  printBlockArgRegion(p, op, region, args);
1193 }
1194 
1196  Region &region,
1197  ValueRange useDeviceAddrVars,
1198  TypeRange useDeviceAddrTypes,
1199  ValueRange useDevicePtrVars,
1200  TypeRange useDevicePtrTypes) {
1201  AllRegionPrintArgs args;
1202  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1203  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1204  printBlockArgRegion(p, op, region, args);
1205 }
1206 
1207 /// Verifies Reduction Clause
1208 static LogicalResult
1209 verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1210  OperandRange reductionVars,
1211  std::optional<ArrayRef<bool>> reductionByref) {
1212  if (!reductionVars.empty()) {
1213  if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1214  return op->emitOpError()
1215  << "expected as many reduction symbol references "
1216  "as reduction variables";
1217  if (reductionByref && reductionByref->size() != reductionVars.size())
1218  return op->emitError() << "expected as many reduction variable by "
1219  "reference attributes as reduction variables";
1220  } else {
1221  if (reductionSyms)
1222  return op->emitOpError() << "unexpected reduction symbol references";
1223  return success();
1224  }
1225 
1226  // TODO: The followings should be done in
1227  // SymbolUserOpInterface::verifySymbolUses.
1228  DenseSet<Value> accumulators;
1229  for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1230  Value accum = std::get<0>(args);
1231 
1232  if (!accumulators.insert(accum).second)
1233  return op->emitOpError() << "accumulator variable used more than once";
1234 
1235  Type varType = accum.getType();
1236  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1237  auto decl =
1238  SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1239  if (!decl)
1240  return op->emitOpError() << "expected symbol reference " << symbolRef
1241  << " to point to a reduction declaration";
1242 
1243  if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1244  return op->emitOpError()
1245  << "expected accumulator (" << varType
1246  << ") to be the same type as reduction declaration ("
1247  << decl.getAccumulatorType() << ")";
1248  }
1249 
1250  return success();
1251 }
1252 
1253 //===----------------------------------------------------------------------===//
1254 // Parser, printer and verifier for Copyprivate
1255 //===----------------------------------------------------------------------===//
1256 
1257 /// copyprivate-entry-list ::= copyprivate-entry
1258 /// | copyprivate-entry-list `,` copyprivate-entry
1259 /// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1260 static ParseResult parseCopyprivate(
1261  OpAsmParser &parser,
1263  SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1265  if (failed(parser.parseCommaSeparatedList([&]() {
1266  if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1267  parser.parseArrow() ||
1268  parser.parseAttribute(symsVec.emplace_back()) ||
1269  parser.parseColonType(copyprivateTypes.emplace_back()))
1270  return failure();
1271  return success();
1272  })))
1273  return failure();
1274  SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1275  copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1276  return success();
1277 }
1278 
1279 /// Print Copyprivate clause
1281  OperandRange copyprivateVars,
1282  TypeRange copyprivateTypes,
1283  std::optional<ArrayAttr> copyprivateSyms) {
1284  if (!copyprivateSyms.has_value())
1285  return;
1286  llvm::interleaveComma(
1287  llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1288  [&](const auto &args) {
1289  p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1290  << std::get<2>(args);
1291  });
1292 }
1293 
1294 /// Verifies CopyPrivate Clause
1295 static LogicalResult
1297  std::optional<ArrayAttr> copyprivateSyms) {
1298  size_t copyprivateSymsSize =
1299  copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1300  if (copyprivateSymsSize != copyprivateVars.size())
1301  return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1302  << copyprivateVars.size()
1303  << ") and functions (= " << copyprivateSymsSize
1304  << "), both must be equal";
1305  if (!copyprivateSyms.has_value())
1306  return success();
1307 
1308  for (auto copyprivateVarAndSym :
1309  llvm::zip(copyprivateVars, *copyprivateSyms)) {
1310  auto symbolRef =
1311  llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1312  std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1313  funcOp;
1314  if (mlir::func::FuncOp mlirFuncOp =
1315  SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1316  symbolRef))
1317  funcOp = mlirFuncOp;
1318  else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1319  SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1320  op, symbolRef))
1321  funcOp = llvmFuncOp;
1322 
1323  auto getNumArguments = [&] {
1324  return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1325  };
1326 
1327  auto getArgumentType = [&](unsigned i) {
1328  return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1329  *funcOp);
1330  };
1331 
1332  if (!funcOp)
1333  return op->emitOpError() << "expected symbol reference " << symbolRef
1334  << " to point to a copy function";
1335 
1336  if (getNumArguments() != 2)
1337  return op->emitOpError()
1338  << "expected copy function " << symbolRef << " to have 2 operands";
1339 
1340  Type argTy = getArgumentType(0);
1341  if (argTy != getArgumentType(1))
1342  return op->emitOpError() << "expected copy function " << symbolRef
1343  << " arguments to have the same type";
1344 
1345  Type varType = std::get<0>(copyprivateVarAndSym).getType();
1346  if (argTy != varType)
1347  return op->emitOpError()
1348  << "expected copy function arguments' type (" << argTy
1349  << ") to be the same as copyprivate variable's type (" << varType
1350  << ")";
1351  }
1352 
1353  return success();
1354 }
1355 
1356 //===----------------------------------------------------------------------===//
1357 // Parser, printer and verifier for DependVarList
1358 //===----------------------------------------------------------------------===//
1359 
1360 /// depend-entry-list ::= depend-entry
1361 /// | depend-entry-list `,` depend-entry
1362 /// depend-entry ::= depend-kind `->` ssa-id `:` type
1363 static ParseResult
1366  SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
1368  if (failed(parser.parseCommaSeparatedList([&]() {
1369  StringRef keyword;
1370  if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1371  parser.parseOperand(dependVars.emplace_back()) ||
1372  parser.parseColonType(dependTypes.emplace_back()))
1373  return failure();
1374  if (std::optional<ClauseTaskDepend> keywordDepend =
1375  (symbolizeClauseTaskDepend(keyword)))
1376  kindsVec.emplace_back(
1377  ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1378  else
1379  return failure();
1380  return success();
1381  })))
1382  return failure();
1383  SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1384  dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1385  return success();
1386 }
1387 
1388 /// Print Depend clause
1390  OperandRange dependVars, TypeRange dependTypes,
1391  std::optional<ArrayAttr> dependKinds) {
1392 
1393  for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1394  if (i != 0)
1395  p << ", ";
1396  p << stringifyClauseTaskDepend(
1397  llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1398  .getValue())
1399  << " -> " << dependVars[i] << " : " << dependTypes[i];
1400  }
1401 }
1402 
1403 /// Verifies Depend clause
1404 static LogicalResult verifyDependVarList(Operation *op,
1405  std::optional<ArrayAttr> dependKinds,
1406  OperandRange dependVars) {
1407  if (!dependVars.empty()) {
1408  if (!dependKinds || dependKinds->size() != dependVars.size())
1409  return op->emitOpError() << "expected as many depend values"
1410  " as depend variables";
1411  } else {
1412  if (dependKinds && !dependKinds->empty())
1413  return op->emitOpError() << "unexpected depend values";
1414  return success();
1415  }
1416 
1417  return success();
1418 }
1419 
1420 //===----------------------------------------------------------------------===//
1421 // Parser, printer and verifier for Synchronization Hint (2.17.12)
1422 //===----------------------------------------------------------------------===//
1423 
1424 /// Parses a Synchronization Hint clause. The value of hint is an integer
1425 /// which is a combination of different hints from `omp_sync_hint_t`.
1426 ///
1427 /// hint-clause = `hint` `(` hint-value `)`
1428 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1429  IntegerAttr &hintAttr) {
1430  StringRef hintKeyword;
1431  int64_t hint = 0;
1432  if (succeeded(parser.parseOptionalKeyword("none"))) {
1433  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1434  return success();
1435  }
1436  auto parseKeyword = [&]() -> ParseResult {
1437  if (failed(parser.parseKeyword(&hintKeyword)))
1438  return failure();
1439  if (hintKeyword == "uncontended")
1440  hint |= 1;
1441  else if (hintKeyword == "contended")
1442  hint |= 2;
1443  else if (hintKeyword == "nonspeculative")
1444  hint |= 4;
1445  else if (hintKeyword == "speculative")
1446  hint |= 8;
1447  else
1448  return parser.emitError(parser.getCurrentLocation())
1449  << hintKeyword << " is not a valid hint";
1450  return success();
1451  };
1452  if (parser.parseCommaSeparatedList(parseKeyword))
1453  return failure();
1454  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1455  return success();
1456 }
1457 
1458 /// Prints a Synchronization Hint clause
1460  IntegerAttr hintAttr) {
1461  int64_t hint = hintAttr.getInt();
1462 
1463  if (hint == 0) {
1464  p << "none";
1465  return;
1466  }
1467 
1468  // Helper function to get n-th bit from the right end of `value`
1469  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1470 
1471  bool uncontended = bitn(hint, 0);
1472  bool contended = bitn(hint, 1);
1473  bool nonspeculative = bitn(hint, 2);
1474  bool speculative = bitn(hint, 3);
1475 
1476  SmallVector<StringRef> hints;
1477  if (uncontended)
1478  hints.push_back("uncontended");
1479  if (contended)
1480  hints.push_back("contended");
1481  if (nonspeculative)
1482  hints.push_back("nonspeculative");
1483  if (speculative)
1484  hints.push_back("speculative");
1485 
1486  llvm::interleaveComma(hints, p);
1487 }
1488 
1489 /// Verifies a synchronization hint clause
1490 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1491 
1492  // Helper function to get n-th bit from the right end of `value`
1493  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1494 
1495  bool uncontended = bitn(hint, 0);
1496  bool contended = bitn(hint, 1);
1497  bool nonspeculative = bitn(hint, 2);
1498  bool speculative = bitn(hint, 3);
1499 
1500  if (uncontended && contended)
1501  return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1502  "omp_sync_hint_contended cannot be combined";
1503  if (nonspeculative && speculative)
1504  return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1505  "omp_sync_hint_speculative cannot be combined.";
1506  return success();
1507 }
1508 
1509 //===----------------------------------------------------------------------===//
1510 // Parser, printer and verifier for Target
1511 //===----------------------------------------------------------------------===//
1512 
1513 // Helper function to get bitwise AND of `value` and 'flag'
1514 uint64_t mapTypeToBitFlag(uint64_t value,
1515  llvm::omp::OpenMPOffloadMappingFlags flag) {
1516  return value & llvm::to_underlying(flag);
1517 }
1518 
1519 /// Parses a map_entries map type from a string format back into its numeric
1520 /// value.
1521 ///
1522 /// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
1523 /// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
1524 static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
1525  llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1526  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1527 
1528  // This simply verifies the correct keyword is read in, the
1529  // keyword itself is stored inside of the operation
1530  auto parseTypeAndMod = [&]() -> ParseResult {
1531  StringRef mapTypeMod;
1532  if (parser.parseKeyword(&mapTypeMod))
1533  return failure();
1534 
1535  if (mapTypeMod == "always")
1536  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1537 
1538  if (mapTypeMod == "implicit")
1539  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1540 
1541  if (mapTypeMod == "ompx_hold")
1542  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
1543 
1544  if (mapTypeMod == "close")
1545  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1546 
1547  if (mapTypeMod == "present")
1548  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1549 
1550  if (mapTypeMod == "to")
1551  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1552 
1553  if (mapTypeMod == "from")
1554  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1555 
1556  if (mapTypeMod == "tofrom")
1557  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1558  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1559 
1560  if (mapTypeMod == "delete")
1561  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1562 
1563  if (mapTypeMod == "return_param")
1564  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
1565 
1566  return success();
1567  };
1568 
1569  if (parser.parseCommaSeparatedList(parseTypeAndMod))
1570  return failure();
1571 
1572  mapType = parser.getBuilder().getIntegerAttr(
1573  parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
1574  llvm::to_underlying(mapTypeBits));
1575 
1576  return success();
1577 }
1578 
1579 /// Prints a map_entries map type from its numeric value out into its string
1580 /// format.
1582  IntegerAttr mapType) {
1583  uint64_t mapTypeBits = mapType.getUInt();
1584 
1585  bool emitAllocRelease = true;
1587 
1588  // handling of always, close, present placed at the beginning of the string
1589  // to aid readability
1590  if (mapTypeToBitFlag(mapTypeBits,
1591  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1592  mapTypeStrs.push_back("always");
1593  if (mapTypeToBitFlag(mapTypeBits,
1594  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1595  mapTypeStrs.push_back("implicit");
1596  if (mapTypeToBitFlag(mapTypeBits,
1597  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD))
1598  mapTypeStrs.push_back("ompx_hold");
1599  if (mapTypeToBitFlag(mapTypeBits,
1600  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1601  mapTypeStrs.push_back("close");
1602  if (mapTypeToBitFlag(mapTypeBits,
1603  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1604  mapTypeStrs.push_back("present");
1605 
1606  // special handling of to/from/tofrom/delete and release/alloc, release +
1607  // alloc are the abscense of one of the other flags, whereas tofrom requires
1608  // both the to and from flag to be set.
1609  bool to = mapTypeToBitFlag(mapTypeBits,
1610  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1611  bool from = mapTypeToBitFlag(
1612  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1613  if (to && from) {
1614  emitAllocRelease = false;
1615  mapTypeStrs.push_back("tofrom");
1616  } else if (from) {
1617  emitAllocRelease = false;
1618  mapTypeStrs.push_back("from");
1619  } else if (to) {
1620  emitAllocRelease = false;
1621  mapTypeStrs.push_back("to");
1622  }
1623  if (mapTypeToBitFlag(mapTypeBits,
1624  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1625  emitAllocRelease = false;
1626  mapTypeStrs.push_back("delete");
1627  }
1628  if (mapTypeToBitFlag(
1629  mapTypeBits,
1630  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
1631  emitAllocRelease = false;
1632  mapTypeStrs.push_back("return_param");
1633  }
1634  if (emitAllocRelease)
1635  mapTypeStrs.push_back("exit_release_or_enter_alloc");
1636 
1637  for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1638  p << mapTypeStrs[i];
1639  if (i + 1 < mapTypeStrs.size()) {
1640  p << ", ";
1641  }
1642  }
1643 }
1644 
1645 static ParseResult parseMembersIndex(OpAsmParser &parser,
1646  ArrayAttr &membersIdx) {
1647  SmallVector<Attribute> values, memberIdxs;
1648 
1649  auto parseIndices = [&]() -> ParseResult {
1650  int64_t value;
1651  if (parser.parseInteger(value))
1652  return failure();
1653  values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1654  APInt(64, value, /*isSigned=*/false)));
1655  return success();
1656  };
1657 
1658  do {
1659  if (failed(parser.parseLSquare()))
1660  return failure();
1661 
1662  if (parser.parseCommaSeparatedList(parseIndices))
1663  return failure();
1664 
1665  if (failed(parser.parseRSquare()))
1666  return failure();
1667 
1668  memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1669  values.clear();
1670  } while (succeeded(parser.parseOptionalComma()));
1671 
1672  if (!memberIdxs.empty())
1673  membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
1674 
1675  return success();
1676 }
1677 
1678 static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1679  ArrayAttr membersIdx) {
1680  if (!membersIdx)
1681  return;
1682 
1683  llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
1684  p << "[";
1685  auto memberIdx = cast<ArrayAttr>(v);
1686  llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
1687  p << cast<IntegerAttr>(v2).getInt();
1688  });
1689  p << "]";
1690  });
1691 }
1692 
1694  VariableCaptureKindAttr mapCaptureType) {
1695  std::string typeCapStr;
1696  llvm::raw_string_ostream typeCap(typeCapStr);
1697  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1698  typeCap << "ByRef";
1699  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1700  typeCap << "ByCopy";
1701  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1702  typeCap << "VLAType";
1703  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1704  typeCap << "This";
1705  p << typeCapStr;
1706 }
1707 
1708 static ParseResult parseCaptureType(OpAsmParser &parser,
1709  VariableCaptureKindAttr &mapCaptureType) {
1710  StringRef mapCaptureKey;
1711  if (parser.parseKeyword(&mapCaptureKey))
1712  return failure();
1713 
1714  if (mapCaptureKey == "This")
1715  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1716  parser.getContext(), mlir::omp::VariableCaptureKind::This);
1717  if (mapCaptureKey == "ByRef")
1718  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1719  parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1720  if (mapCaptureKey == "ByCopy")
1721  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1722  parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1723  if (mapCaptureKey == "VLAType")
1724  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1725  parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1726 
1727  return success();
1728 }
1729 
1730 static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
1733 
1734  for (auto mapOp : mapVars) {
1735  if (!mapOp.getDefiningOp())
1736  return emitError(op->getLoc(), "missing map operation");
1737 
1738  if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
1739  uint64_t mapTypeBits = mapInfoOp.getMapType();
1740 
1741  bool to = mapTypeToBitFlag(
1742  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1743  bool from = mapTypeToBitFlag(
1744  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1745  bool del = mapTypeToBitFlag(
1746  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1747 
1748  bool always = mapTypeToBitFlag(
1749  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1750  bool close = mapTypeToBitFlag(
1751  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1752  bool implicit = mapTypeToBitFlag(
1753  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1754 
1755  if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1756  return emitError(op->getLoc(),
1757  "to, from, tofrom and alloc map types are permitted");
1758 
1759  if (isa<TargetEnterDataOp>(op) && (from || del))
1760  return emitError(op->getLoc(), "to and alloc map types are permitted");
1761 
1762  if (isa<TargetExitDataOp>(op) && to)
1763  return emitError(op->getLoc(),
1764  "from, release and delete map types are permitted");
1765 
1766  if (isa<TargetUpdateOp>(op)) {
1767  if (del) {
1768  return emitError(op->getLoc(),
1769  "at least one of to or from map types must be "
1770  "specified, other map types are not permitted");
1771  }
1772 
1773  if (!to && !from) {
1774  return emitError(op->getLoc(),
1775  "at least one of to or from map types must be "
1776  "specified, other map types are not permitted");
1777  }
1778 
1779  auto updateVar = mapInfoOp.getVarPtr();
1780 
1781  if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1782  (from && updateToVars.contains(updateVar))) {
1783  return emitError(
1784  op->getLoc(),
1785  "either to or from map types can be specified, not both");
1786  }
1787 
1788  if (always || close || implicit) {
1789  return emitError(
1790  op->getLoc(),
1791  "present, mapper and iterator map type modifiers are permitted");
1792  }
1793 
1794  to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1795  }
1796  } else if (!isa<DeclareMapperInfoOp>(op)) {
1797  return emitError(op->getLoc(),
1798  "map argument is not a map entry operation");
1799  }
1800  }
1801 
1802  return success();
1803 }
1804 
1805 static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
1806  std::optional<DenseI64ArrayAttr> privateMapIndices =
1807  targetOp.getPrivateMapsAttr();
1808 
1809  // None of the private operands are mapped.
1810  if (!privateMapIndices.has_value() || !privateMapIndices.value())
1811  return success();
1812 
1813  OperandRange privateVars = targetOp.getPrivateVars();
1814 
1815  if (privateMapIndices.value().size() !=
1816  static_cast<int64_t>(privateVars.size()))
1817  return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
1818  "`private_maps` attribute mismatch");
1819 
1820  return success();
1821 }
1822 
1823 //===----------------------------------------------------------------------===//
1824 // MapInfoOp
1825 //===----------------------------------------------------------------------===//
1826 
1827 static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
1828  StringRef clauseName,
1829  OperandRange vars) {
1830  for (Value var : vars)
1831  if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
1832  return op->emitOpError()
1833  << "'" << clauseName
1834  << "' arguments must be defined by 'omp.map.info' ops";
1835  return success();
1836 }
1837 
1838 LogicalResult MapInfoOp::verify() {
1839  if (getMapperId() &&
1840  !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
1841  *this, getMapperIdAttr())) {
1842  return emitError("invalid mapper id");
1843  }
1844 
1845  if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
1846  return failure();
1847 
1848  return success();
1849 }
1850 
1851 //===----------------------------------------------------------------------===//
1852 // TargetDataOp
1853 //===----------------------------------------------------------------------===//
1854 
1855 void TargetDataOp::build(OpBuilder &builder, OperationState &state,
1856  const TargetDataOperands &clauses) {
1857  TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
1858  clauses.mapVars, clauses.useDeviceAddrVars,
1859  clauses.useDevicePtrVars);
1860 }
1861 
1862 LogicalResult TargetDataOp::verify() {
1863  if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
1864  getUseDeviceAddrVars().empty()) {
1865  return ::emitError(this->getLoc(),
1866  "At least one of map, use_device_ptr_vars, or "
1867  "use_device_addr_vars operand must be present");
1868  }
1869 
1870  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
1871  getUseDevicePtrVars())))
1872  return failure();
1873 
1874  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
1875  getUseDeviceAddrVars())))
1876  return failure();
1877 
1878  return verifyMapClause(*this, getMapVars());
1879 }
1880 
1881 //===----------------------------------------------------------------------===//
1882 // TargetEnterDataOp
1883 //===----------------------------------------------------------------------===//
1884 
1885 void TargetEnterDataOp::build(
1886  OpBuilder &builder, OperationState &state,
1887  const TargetEnterExitUpdateDataOperands &clauses) {
1888  MLIRContext *ctx = builder.getContext();
1889  TargetEnterDataOp::build(builder, state,
1890  makeArrayAttr(ctx, clauses.dependKinds),
1891  clauses.dependVars, clauses.device, clauses.ifExpr,
1892  clauses.mapVars, clauses.nowait);
1893 }
1894 
1895 LogicalResult TargetEnterDataOp::verify() {
1896  LogicalResult verifyDependVars =
1897  verifyDependVarList(*this, getDependKinds(), getDependVars());
1898  return failed(verifyDependVars) ? verifyDependVars
1899  : verifyMapClause(*this, getMapVars());
1900 }
1901 
1902 //===----------------------------------------------------------------------===//
1903 // TargetExitDataOp
1904 //===----------------------------------------------------------------------===//
1905 
1906 void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
1907  const TargetEnterExitUpdateDataOperands &clauses) {
1908  MLIRContext *ctx = builder.getContext();
1909  TargetExitDataOp::build(builder, state,
1910  makeArrayAttr(ctx, clauses.dependKinds),
1911  clauses.dependVars, clauses.device, clauses.ifExpr,
1912  clauses.mapVars, clauses.nowait);
1913 }
1914 
1915 LogicalResult TargetExitDataOp::verify() {
1916  LogicalResult verifyDependVars =
1917  verifyDependVarList(*this, getDependKinds(), getDependVars());
1918  return failed(verifyDependVars) ? verifyDependVars
1919  : verifyMapClause(*this, getMapVars());
1920 }
1921 
1922 //===----------------------------------------------------------------------===//
1923 // TargetUpdateOp
1924 //===----------------------------------------------------------------------===//
1925 
1926 void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
1927  const TargetEnterExitUpdateDataOperands &clauses) {
1928  MLIRContext *ctx = builder.getContext();
1929  TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
1930  clauses.dependVars, clauses.device, clauses.ifExpr,
1931  clauses.mapVars, clauses.nowait);
1932 }
1933 
1934 LogicalResult TargetUpdateOp::verify() {
1935  LogicalResult verifyDependVars =
1936  verifyDependVarList(*this, getDependKinds(), getDependVars());
1937  return failed(verifyDependVars) ? verifyDependVars
1938  : verifyMapClause(*this, getMapVars());
1939 }
1940 
1941 //===----------------------------------------------------------------------===//
1942 // TargetOp
1943 //===----------------------------------------------------------------------===//
1944 
1945 void TargetOp::build(OpBuilder &builder, OperationState &state,
1946  const TargetOperands &clauses) {
1947  MLIRContext *ctx = builder.getContext();
1948  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
1949  // inReductionByref, inReductionSyms.
1950  TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
1951  clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
1952  clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
1953  clauses.hostEvalVars, clauses.ifExpr,
1954  /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
1955  /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
1956  clauses.mapVars, clauses.nowait, clauses.privateVars,
1957  makeArrayAttr(ctx, clauses.privateSyms),
1958  clauses.privateNeedsBarrier, clauses.threadLimit,
1959  /*private_maps=*/nullptr);
1960 }
1961 
1962 LogicalResult TargetOp::verify() {
1963  if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
1964  return failure();
1965 
1966  if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
1967  getHasDeviceAddrVars())))
1968  return failure();
1969 
1970  if (failed(verifyMapClause(*this, getMapVars())))
1971  return failure();
1972 
1973  return verifyPrivateVarsMapping(*this);
1974 }
1975 
1976 LogicalResult TargetOp::verifyRegions() {
1977  auto teamsOps = getOps<TeamsOp>();
1978  if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
1979  return emitError("target containing multiple 'omp.teams' nested ops");
1980 
1981  // Check that host_eval values are only used in legal ways.
1982  Operation *capturedOp = getInnermostCapturedOmpOp();
1983  TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
1984  for (Value hostEvalArg :
1985  cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
1986  for (Operation *user : hostEvalArg.getUsers()) {
1987  if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
1988  if (llvm::is_contained({teamsOp.getNumTeamsLower(),
1989  teamsOp.getNumTeamsUpper(),
1990  teamsOp.getThreadLimit()},
1991  hostEvalArg))
1992  continue;
1993 
1994  return emitOpError() << "host_eval argument only legal as 'num_teams' "
1995  "and 'thread_limit' in 'omp.teams'";
1996  }
1997  if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1998  if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
1999  parallelOp->isAncestor(capturedOp) &&
2000  hostEvalArg == parallelOp.getNumThreads())
2001  continue;
2002 
2003  return emitOpError()
2004  << "host_eval argument only legal as 'num_threads' in "
2005  "'omp.parallel' when representing target SPMD";
2006  }
2007  if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2008  if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2009  loopNestOp.getOperation() == capturedOp &&
2010  (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2011  llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2012  llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2013  continue;
2014 
2015  return emitOpError() << "host_eval argument only legal as loop bounds "
2016  "and steps in 'omp.loop_nest' when trip count "
2017  "must be evaluated in the host";
2018  }
2019 
2020  return emitOpError() << "host_eval argument illegal use in '"
2021  << user->getName() << "' operation";
2022  }
2023  }
2024  return success();
2025 }
2026 
2027 static Operation *
2028 findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
2029  llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
2030  assert(rootOp && "expected valid operation");
2031 
2032  Dialect *ompDialect = rootOp->getDialect();
2033  Operation *capturedOp = nullptr;
2034  DominanceInfo domInfo;
2035 
2036  // Process in pre-order to check operations from outermost to innermost,
2037  // ensuring we only enter the region of an operation if it meets the criteria
2038  // for being captured. We stop the exploration of nested operations as soon as
2039  // we process a region holding no operations to be captured.
2040  rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
2041  if (op == rootOp)
2042  return WalkResult::advance();
2043 
2044  // Ignore operations of other dialects or omp operations with no regions,
2045  // because these will only be checked if they are siblings of an omp
2046  // operation that can potentially be captured.
2047  bool isOmpDialect = op->getDialect() == ompDialect;
2048  bool hasRegions = op->getNumRegions() > 0;
2049  if (!isOmpDialect || !hasRegions)
2050  return WalkResult::skip();
2051 
2052  // This operation cannot be captured if it can be executed more than once
2053  // (i.e. its block's successors can reach it) or if it's not guaranteed to
2054  // be executed before all exits of the region (i.e. it doesn't dominate all
2055  // blocks with no successors reachable from the entry block).
2056  if (checkSingleMandatoryExec) {
2057  Region *parentRegion = op->getParentRegion();
2058  Block *parentBlock = op->getBlock();
2059 
2060  for (Block *successor : parentBlock->getSuccessors())
2061  if (successor->isReachable(parentBlock))
2062  return WalkResult::interrupt();
2063 
2064  for (Block &block : *parentRegion)
2065  if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2066  !domInfo.dominates(parentBlock, &block))
2067  return WalkResult::interrupt();
2068  }
2069 
2070  // Don't capture this op if it has a not-allowed sibling, and stop recursing
2071  // into nested operations.
2072  for (Operation &sibling : op->getParentRegion()->getOps())
2073  if (&sibling != op && !siblingAllowedFn(&sibling))
2074  return WalkResult::interrupt();
2075 
2076  // Don't continue capturing nested operations if we reach an omp.loop_nest.
2077  // Otherwise, process the contents of this operation.
2078  capturedOp = op;
2079  return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2080  : WalkResult::advance();
2081  });
2082 
2083  return capturedOp;
2084 }
2085 
2086 Operation *TargetOp::getInnermostCapturedOmpOp() {
2087  auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
2088 
2089  // Only allow OpenMP terminators and non-OpenMP ops that have known memory
2090  // effects, but don't include a memory write effect.
2091  return findCapturedOmpOp(
2092  *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2093  if (!sibling)
2094  return false;
2095 
2096  if (ompDialect == sibling->getDialect())
2097  return sibling->hasTrait<OpTrait::IsTerminator>();
2098 
2099  if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2101  effects;
2102  memOp.getEffects(effects);
2103  return !llvm::any_of(
2104  effects, [&](MemoryEffects::EffectInstance &effect) {
2105  return isa<MemoryEffects::Write>(effect.getEffect()) &&
2106  isa<SideEffects::AutomaticAllocationScopeResource>(
2107  effect.getResource());
2108  });
2109  }
2110  return true;
2111  });
2112 }
2113 
2114 TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2115  // A non-null captured op is only valid if it resides inside of a TargetOp
2116  // and is the result of calling getInnermostCapturedOmpOp() on it.
2117  TargetOp targetOp =
2118  capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
2119  assert((!capturedOp ||
2120  (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2121  "unexpected captured op");
2122 
2123  // If it's not capturing a loop, it's a default target region.
2124  if (!isa_and_present<LoopNestOp>(capturedOp))
2125  return TargetRegionFlags::generic;
2126 
2127  // Get the innermost non-simd loop wrapper.
2128  SmallVector<LoopWrapperInterface> loopWrappers;
2129  cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2130  assert(!loopWrappers.empty());
2131 
2132  LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2133  if (isa<SimdOp>(innermostWrapper))
2134  innermostWrapper = std::next(innermostWrapper);
2135 
2136  auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2137  if (numWrappers != 1 && numWrappers != 2)
2138  return TargetRegionFlags::generic;
2139 
2140  // Detect target-teams-distribute-parallel-wsloop[-simd].
2141  if (numWrappers == 2) {
2142  if (!isa<WsloopOp>(innermostWrapper))
2143  return TargetRegionFlags::generic;
2144 
2145  innermostWrapper = std::next(innermostWrapper);
2146  if (!isa<DistributeOp>(innermostWrapper))
2147  return TargetRegionFlags::generic;
2148 
2149  Operation *parallelOp = (*innermostWrapper)->getParentOp();
2150  if (!isa_and_present<ParallelOp>(parallelOp))
2151  return TargetRegionFlags::generic;
2152 
2153  Operation *teamsOp = parallelOp->getParentOp();
2154  if (!isa_and_present<TeamsOp>(teamsOp))
2155  return TargetRegionFlags::generic;
2156 
2157  if (teamsOp->getParentOp() == targetOp.getOperation())
2158  return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2159  }
2160  // Detect target-teams-distribute[-simd] and target-teams-loop.
2161  else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2162  Operation *teamsOp = (*innermostWrapper)->getParentOp();
2163  if (!isa_and_present<TeamsOp>(teamsOp))
2164  return TargetRegionFlags::generic;
2165 
2166  if (teamsOp->getParentOp() != targetOp.getOperation())
2167  return TargetRegionFlags::generic;
2168 
2169  if (isa<LoopOp>(innermostWrapper))
2170  return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2171 
2172  // Find single immediately nested captured omp.parallel and add spmd flag
2173  // (generic-spmd case).
2174  //
2175  // TODO: This shouldn't have to be done here, as it is too easy to break.
2176  // The openmp-opt pass should be updated to be able to promote kernels like
2177  // this from "Generic" to "Generic-SPMD". However, the use of the
2178  // `kmpc_distribute_static_loop` family of functions produced by the
2179  // OMPIRBuilder for these kernels prevents that from working.
2180  Dialect *ompDialect = targetOp->getDialect();
2181  Operation *nestedCapture = findCapturedOmpOp(
2182  capturedOp, /*checkSingleMandatoryExec=*/false,
2183  [&](Operation *sibling) {
2184  return sibling && (ompDialect != sibling->getDialect() ||
2185  sibling->hasTrait<OpTrait::IsTerminator>());
2186  });
2187 
2188  TargetRegionFlags result =
2189  TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2190 
2191  if (!nestedCapture)
2192  return result;
2193 
2194  while (nestedCapture->getParentOp() != capturedOp)
2195  nestedCapture = nestedCapture->getParentOp();
2196 
2197  return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2198  : result;
2199  }
2200  // Detect target-parallel-wsloop[-simd].
2201  else if (isa<WsloopOp>(innermostWrapper)) {
2202  Operation *parallelOp = (*innermostWrapper)->getParentOp();
2203  if (!isa_and_present<ParallelOp>(parallelOp))
2204  return TargetRegionFlags::generic;
2205 
2206  if (parallelOp->getParentOp() == targetOp.getOperation())
2207  return TargetRegionFlags::spmd;
2208  }
2209 
2210  return TargetRegionFlags::generic;
2211 }
2212 
2213 //===----------------------------------------------------------------------===//
2214 // ParallelOp
2215 //===----------------------------------------------------------------------===//
2216 
2217 void ParallelOp::build(OpBuilder &builder, OperationState &state,
2218  ArrayRef<NamedAttribute> attributes) {
2219  ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
2220  /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2221  /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
2222  /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2223  /*proc_bind_kind=*/nullptr,
2224  /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2225  /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2226  state.addAttributes(attributes);
2227 }
2228 
2229 void ParallelOp::build(OpBuilder &builder, OperationState &state,
2230  const ParallelOperands &clauses) {
2231  MLIRContext *ctx = builder.getContext();
2232  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2233  clauses.ifExpr, clauses.numThreads, clauses.privateVars,
2234  makeArrayAttr(ctx, clauses.privateSyms),
2235  clauses.privateNeedsBarrier, clauses.procBindKind,
2236  clauses.reductionMod, clauses.reductionVars,
2237  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2238  makeArrayAttr(ctx, clauses.reductionSyms));
2239 }
2240 
2241 template <typename OpType>
2242 static LogicalResult verifyPrivateVarList(OpType &op) {
2243  auto privateVars = op.getPrivateVars();
2244  auto privateSyms = op.getPrivateSymsAttr();
2245 
2246  if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
2247  return success();
2248 
2249  auto numPrivateVars = privateVars.size();
2250  auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
2251 
2252  if (numPrivateVars != numPrivateSyms)
2253  return op.emitError() << "inconsistent number of private variables and "
2254  "privatizer op symbols, private vars: "
2255  << numPrivateVars
2256  << " vs. privatizer op symbols: " << numPrivateSyms;
2257 
2258  for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2259  Type varType = std::get<0>(privateVarInfo).getType();
2260  SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2261  PrivateClauseOp privatizerOp =
2262  SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
2263 
2264  if (privatizerOp == nullptr)
2265  return op.emitError() << "failed to lookup privatizer op with symbol: '"
2266  << privateSym << "'";
2267 
2268  Type privatizerType = privatizerOp.getArgType();
2269 
2270  if (privatizerType && (varType != privatizerType))
2271  return op.emitError()
2272  << "type mismatch between a "
2273  << (privatizerOp.getDataSharingType() ==
2274  DataSharingClauseType::Private
2275  ? "private"
2276  : "firstprivate")
2277  << " variable and its privatizer op, var type: " << varType
2278  << " vs. privatizer op type: " << privatizerType;
2279  }
2280 
2281  return success();
2282 }
2283 
2284 LogicalResult ParallelOp::verify() {
2285  if (getAllocateVars().size() != getAllocatorVars().size())
2286  return emitError(
2287  "expected equal sizes for allocate and allocator variables");
2288 
2289  if (failed(verifyPrivateVarList(*this)))
2290  return failure();
2291 
2292  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2293  getReductionByref());
2294 }
2295 
2296 LogicalResult ParallelOp::verifyRegions() {
2297  auto distChildOps = getOps<DistributeOp>();
2298  int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2299  if (numDistChildOps > 1)
2300  return emitError()
2301  << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2302 
2303  if (numDistChildOps == 1) {
2304  if (!isComposite())
2305  return emitError()
2306  << "'omp.composite' attribute missing from composite operation";
2307 
2308  auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2309  Operation &distributeOp = **distChildOps.begin();
2310  for (Operation &childOp : getOps()) {
2311  if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2312  continue;
2313 
2314  if (!childOp.hasTrait<OpTrait::IsTerminator>())
2315  return emitError() << "unexpected OpenMP operation inside of composite "
2316  "'omp.parallel': "
2317  << childOp.getName();
2318  }
2319  } else if (isComposite()) {
2320  return emitError()
2321  << "'omp.composite' attribute present in non-composite operation";
2322  }
2323  return success();
2324 }
2325 
2326 //===----------------------------------------------------------------------===//
2327 // TeamsOp
2328 //===----------------------------------------------------------------------===//
2329 
2331  while ((op = op->getParentOp()))
2332  if (isa<OpenMPDialect>(op->getDialect()))
2333  return false;
2334  return true;
2335 }
2336 
2337 void TeamsOp::build(OpBuilder &builder, OperationState &state,
2338  const TeamsOperands &clauses) {
2339  MLIRContext *ctx = builder.getContext();
2340  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2341  TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2342  clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2343  /*private_vars=*/{}, /*private_syms=*/nullptr,
2344  /*private_needs_barrier=*/nullptr, clauses.reductionMod,
2345  clauses.reductionVars,
2346  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2347  makeArrayAttr(ctx, clauses.reductionSyms),
2348  clauses.threadLimit);
2349 }
2350 
2351 LogicalResult TeamsOp::verify() {
2352  // Check parent region
2353  // TODO If nested inside of a target region, also check that it does not
2354  // contain any statements, declarations or directives other than this
2355  // omp.teams construct. The issue is how to support the initialization of
2356  // this operation's own arguments (allow SSA values across omp.target?).
2357  Operation *op = getOperation();
2358  if (!isa<TargetOp>(op->getParentOp()) &&
2360  return emitError("expected to be nested inside of omp.target or not nested "
2361  "in any OpenMP dialect operations");
2362 
2363  // Check for num_teams clause restrictions
2364  if (auto numTeamsLowerBound = getNumTeamsLower()) {
2365  auto numTeamsUpperBound = getNumTeamsUpper();
2366  if (!numTeamsUpperBound)
2367  return emitError("expected num_teams upper bound to be defined if the "
2368  "lower bound is defined");
2369  if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2370  return emitError(
2371  "expected num_teams upper bound and lower bound to be the same type");
2372  }
2373 
2374  // Check for allocate clause restrictions
2375  if (getAllocateVars().size() != getAllocatorVars().size())
2376  return emitError(
2377  "expected equal sizes for allocate and allocator variables");
2378 
2379  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2380  getReductionByref());
2381 }
2382 
2383 //===----------------------------------------------------------------------===//
2384 // SectionOp
2385 //===----------------------------------------------------------------------===//
2386 
2387 OperandRange SectionOp::getPrivateVars() {
2388  return getParentOp().getPrivateVars();
2389 }
2390 
2391 OperandRange SectionOp::getReductionVars() {
2392  return getParentOp().getReductionVars();
2393 }
2394 
2395 //===----------------------------------------------------------------------===//
2396 // SectionsOp
2397 //===----------------------------------------------------------------------===//
2398 
2399 void SectionsOp::build(OpBuilder &builder, OperationState &state,
2400  const SectionsOperands &clauses) {
2401  MLIRContext *ctx = builder.getContext();
2402  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2403  SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2404  clauses.nowait, /*private_vars=*/{},
2405  /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2406  clauses.reductionMod, clauses.reductionVars,
2407  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2408  makeArrayAttr(ctx, clauses.reductionSyms));
2409 }
2410 
2411 LogicalResult SectionsOp::verify() {
2412  if (getAllocateVars().size() != getAllocatorVars().size())
2413  return emitError(
2414  "expected equal sizes for allocate and allocator variables");
2415 
2416  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2417  getReductionByref());
2418 }
2419 
2420 LogicalResult SectionsOp::verifyRegions() {
2421  for (auto &inst : *getRegion().begin()) {
2422  if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2423  return emitOpError()
2424  << "expected omp.section op or terminator op inside region";
2425  }
2426  }
2427 
2428  return success();
2429 }
2430 
2431 //===----------------------------------------------------------------------===//
2432 // SingleOp
2433 //===----------------------------------------------------------------------===//
2434 
2435 void SingleOp::build(OpBuilder &builder, OperationState &state,
2436  const SingleOperands &clauses) {
2437  MLIRContext *ctx = builder.getContext();
2438  // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2439  SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2440  clauses.copyprivateVars,
2441  makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2442  /*private_vars=*/{}, /*private_syms=*/nullptr,
2443  /*private_needs_barrier=*/nullptr);
2444 }
2445 
2446 LogicalResult SingleOp::verify() {
2447  // Check for allocate clause restrictions
2448  if (getAllocateVars().size() != getAllocatorVars().size())
2449  return emitError(
2450  "expected equal sizes for allocate and allocator variables");
2451 
2452  return verifyCopyprivateVarList(*this, getCopyprivateVars(),
2453  getCopyprivateSyms());
2454 }
2455 
2456 //===----------------------------------------------------------------------===//
2457 // WorkshareOp
2458 //===----------------------------------------------------------------------===//
2459 
2460 void WorkshareOp::build(OpBuilder &builder, OperationState &state,
2461  const WorkshareOperands &clauses) {
2462  WorkshareOp::build(builder, state, clauses.nowait);
2463 }
2464 
2465 //===----------------------------------------------------------------------===//
2466 // WorkshareLoopWrapperOp
2467 //===----------------------------------------------------------------------===//
2468 
2469 LogicalResult WorkshareLoopWrapperOp::verify() {
2470  if (!(*this)->getParentOfType<WorkshareOp>())
2471  return emitOpError() << "must be nested in an omp.workshare";
2472  return success();
2473 }
2474 
2475 LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2476  if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2477  getNestedWrapper())
2478  return emitOpError() << "expected to be a standalone loop wrapper";
2479 
2480  return success();
2481 }
2482 
2483 //===----------------------------------------------------------------------===//
2484 // LoopWrapperInterface
2485 //===----------------------------------------------------------------------===//
2486 
2487 LogicalResult LoopWrapperInterface::verifyImpl() {
2488  Operation *op = this->getOperation();
2489  if (!op->hasTrait<OpTrait::NoTerminator>() ||
2491  return emitOpError() << "loop wrapper must also have the `NoTerminator` "
2492  "and `SingleBlock` traits";
2493 
2494  if (op->getNumRegions() != 1)
2495  return emitOpError() << "loop wrapper does not contain exactly one region";
2496 
2497  Region &region = op->getRegion(0);
2498  if (range_size(region.getOps()) != 1)
2499  return emitOpError()
2500  << "loop wrapper does not contain exactly one nested op";
2501 
2502  Operation &firstOp = *region.op_begin();
2503  if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2504  return emitOpError() << "nested in loop wrapper is not another loop "
2505  "wrapper or `omp.loop_nest`";
2506 
2507  return success();
2508 }
2509 
2510 //===----------------------------------------------------------------------===//
2511 // LoopOp
2512 //===----------------------------------------------------------------------===//
2513 
2514 void LoopOp::build(OpBuilder &builder, OperationState &state,
2515  const LoopOperands &clauses) {
2516  MLIRContext *ctx = builder.getContext();
2517 
2518  LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2519  makeArrayAttr(ctx, clauses.privateSyms),
2520  clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2521  clauses.reductionMod, clauses.reductionVars,
2522  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2523  makeArrayAttr(ctx, clauses.reductionSyms));
2524 }
2525 
2526 LogicalResult LoopOp::verify() {
2527  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2528  getReductionByref());
2529 }
2530 
2531 LogicalResult LoopOp::verifyRegions() {
2532  if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2533  getNestedWrapper())
2534  return emitOpError() << "expected to be a standalone loop wrapper";
2535 
2536  return success();
2537 }
2538 
2539 //===----------------------------------------------------------------------===//
2540 // WsloopOp
2541 //===----------------------------------------------------------------------===//
2542 
2543 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2544  ArrayRef<NamedAttribute> attributes) {
2545  build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2546  /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
2547  /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
2548  /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
2549  /*private_needs_barrier=*/false,
2550  /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
2551  /*reduction_byref=*/nullptr,
2552  /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
2553  /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
2554  /*schedule_simd=*/false);
2555  state.addAttributes(attributes);
2556 }
2557 
2558 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2559  const WsloopOperands &clauses) {
2560  MLIRContext *ctx = builder.getContext();
2561  // TODO: Store clauses in op: allocateVars, allocatorVars
2562  WsloopOp::build(
2563  builder, state,
2564  /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
2565  clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
2566  clauses.ordered, clauses.privateVars,
2567  makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2568  clauses.reductionMod, clauses.reductionVars,
2569  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2570  makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
2571  clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
2572 }
2573 
2574 LogicalResult WsloopOp::verify() {
2575  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2576  getReductionByref());
2577 }
2578 
2579 LogicalResult WsloopOp::verifyRegions() {
2580  bool isCompositeChildLeaf =
2581  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2582 
2583  if (LoopWrapperInterface nested = getNestedWrapper()) {
2584  if (!isComposite())
2585  return emitError()
2586  << "'omp.composite' attribute missing from composite wrapper";
2587 
2588  // Check for the allowed leaf constructs that may appear in a composite
2589  // construct directly after DO/FOR.
2590  if (!isa<SimdOp>(nested))
2591  return emitError() << "only supported nested wrapper is 'omp.simd'";
2592 
2593  } else if (isComposite() && !isCompositeChildLeaf) {
2594  return emitError()
2595  << "'omp.composite' attribute present in non-composite wrapper";
2596  } else if (!isComposite() && isCompositeChildLeaf) {
2597  return emitError()
2598  << "'omp.composite' attribute missing from composite wrapper";
2599  }
2600 
2601  return success();
2602 }
2603 
2604 //===----------------------------------------------------------------------===//
2605 // Simd construct [2.9.3.1]
2606 //===----------------------------------------------------------------------===//
2607 
2608 void SimdOp::build(OpBuilder &builder, OperationState &state,
2609  const SimdOperands &clauses) {
2610  MLIRContext *ctx = builder.getContext();
2611  // TODO Store clauses in op: linearVars, linearStepVars
2612  SimdOp::build(builder, state, clauses.alignedVars,
2613  makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
2614  /*linear_vars=*/{}, /*linear_step_vars=*/{},
2615  clauses.nontemporalVars, clauses.order, clauses.orderMod,
2616  clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
2617  clauses.privateNeedsBarrier, clauses.reductionMod,
2618  clauses.reductionVars,
2619  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2620  makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
2621  clauses.simdlen);
2622 }
2623 
2624 LogicalResult SimdOp::verify() {
2625  if (getSimdlen().has_value() && getSafelen().has_value() &&
2626  getSimdlen().value() > getSafelen().value())
2627  return emitOpError()
2628  << "simdlen clause and safelen clause are both present, but the "
2629  "simdlen value is not less than or equal to safelen value";
2630 
2631  if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
2632  return failure();
2633 
2634  if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
2635  return failure();
2636 
2637  bool isCompositeChildLeaf =
2638  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2639 
2640  if (!isComposite() && isCompositeChildLeaf)
2641  return emitError()
2642  << "'omp.composite' attribute missing from composite wrapper";
2643 
2644  if (isComposite() && !isCompositeChildLeaf)
2645  return emitError()
2646  << "'omp.composite' attribute present in non-composite wrapper";
2647 
2648  // Firstprivate is not allowed for SIMD in the standard. Check that none of
2649  // the private decls are for firstprivate.
2650  std::optional<ArrayAttr> privateSyms = getPrivateSyms();
2651  if (privateSyms) {
2652  for (const Attribute &sym : *privateSyms) {
2653  auto symRef = cast<SymbolRefAttr>(sym);
2654  omp::PrivateClauseOp privatizer =
2655  SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
2656  getOperation(), symRef);
2657  if (!privatizer)
2658  return emitError() << "Cannot find privatizer '" << symRef << "'";
2659  if (privatizer.getDataSharingType() ==
2660  DataSharingClauseType::FirstPrivate)
2661  return emitError() << "FIRSTPRIVATE cannot be used with SIMD";
2662  }
2663  }
2664 
2665  return success();
2666 }
2667 
2668 LogicalResult SimdOp::verifyRegions() {
2669  if (getNestedWrapper())
2670  return emitOpError() << "must wrap an 'omp.loop_nest' directly";
2671 
2672  return success();
2673 }
2674 
2675 //===----------------------------------------------------------------------===//
2676 // Distribute construct [2.9.4.1]
2677 //===----------------------------------------------------------------------===//
2678 
2679 void DistributeOp::build(OpBuilder &builder, OperationState &state,
2680  const DistributeOperands &clauses) {
2681  DistributeOp::build(builder, state, clauses.allocateVars,
2682  clauses.allocatorVars, clauses.distScheduleStatic,
2683  clauses.distScheduleChunkSize, clauses.order,
2684  clauses.orderMod, clauses.privateVars,
2685  makeArrayAttr(builder.getContext(), clauses.privateSyms),
2686  clauses.privateNeedsBarrier);
2687 }
2688 
2689 LogicalResult DistributeOp::verify() {
2690  if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2691  return emitOpError() << "chunk size set without "
2692  "dist_schedule_static being present";
2693 
2694  if (getAllocateVars().size() != getAllocatorVars().size())
2695  return emitError(
2696  "expected equal sizes for allocate and allocator variables");
2697 
2698  return success();
2699 }
2700 
2701 LogicalResult DistributeOp::verifyRegions() {
2702  if (LoopWrapperInterface nested = getNestedWrapper()) {
2703  if (!isComposite())
2704  return emitError()
2705  << "'omp.composite' attribute missing from composite wrapper";
2706  // Check for the allowed leaf constructs that may appear in a composite
2707  // construct directly after DISTRIBUTE.
2708  if (isa<WsloopOp>(nested)) {
2709  Operation *parentOp = (*this)->getParentOp();
2710  if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
2711  !cast<ComposableOpInterface>(parentOp).isComposite()) {
2712  return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
2713  "when a composite 'omp.parallel' is the direct "
2714  "parent";
2715  }
2716  } else if (!isa<SimdOp>(nested))
2717  return emitError() << "only supported nested wrappers are 'omp.simd' and "
2718  "'omp.wsloop'";
2719  } else if (isComposite()) {
2720  return emitError()
2721  << "'omp.composite' attribute present in non-composite wrapper";
2722  }
2723 
2724  return success();
2725 }
2726 
2727 //===----------------------------------------------------------------------===//
2728 // DeclareMapperOp / DeclareMapperInfoOp
2729 //===----------------------------------------------------------------------===//
2730 
2731 LogicalResult DeclareMapperInfoOp::verify() {
2732  return verifyMapClause(*this, getMapVars());
2733 }
2734 
2735 LogicalResult DeclareMapperOp::verifyRegions() {
2736  if (!llvm::isa_and_present<DeclareMapperInfoOp>(
2737  getRegion().getBlocks().front().getTerminator()))
2738  return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
2739 
2740  return success();
2741 }
2742 
2743 //===----------------------------------------------------------------------===//
2744 // DeclareReductionOp
2745 //===----------------------------------------------------------------------===//
2746 
2747 LogicalResult DeclareReductionOp::verifyRegions() {
2748  if (!getAllocRegion().empty()) {
2749  for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
2750  if (yieldOp.getResults().size() != 1 ||
2751  yieldOp.getResults().getTypes()[0] != getType())
2752  return emitOpError() << "expects alloc region to yield a value "
2753  "of the reduction type";
2754  }
2755  }
2756 
2757  if (getInitializerRegion().empty())
2758  return emitOpError() << "expects non-empty initializer region";
2759  Block &initializerEntryBlock = getInitializerRegion().front();
2760 
2761  if (initializerEntryBlock.getNumArguments() == 1) {
2762  if (!getAllocRegion().empty())
2763  return emitOpError() << "expects two arguments to the initializer region "
2764  "when an allocation region is used";
2765  } else if (initializerEntryBlock.getNumArguments() == 2) {
2766  if (getAllocRegion().empty())
2767  return emitOpError() << "expects one argument to the initializer region "
2768  "when no allocation region is used";
2769  } else {
2770  return emitOpError()
2771  << "expects one or two arguments to the initializer region";
2772  }
2773 
2774  for (mlir::Value arg : initializerEntryBlock.getArguments())
2775  if (arg.getType() != getType())
2776  return emitOpError() << "expects initializer region argument to match "
2777  "the reduction type";
2778 
2779  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
2780  if (yieldOp.getResults().size() != 1 ||
2781  yieldOp.getResults().getTypes()[0] != getType())
2782  return emitOpError() << "expects initializer region to yield a value "
2783  "of the reduction type";
2784  }
2785 
2786  if (getReductionRegion().empty())
2787  return emitOpError() << "expects non-empty reduction region";
2788  Block &reductionEntryBlock = getReductionRegion().front();
2789  if (reductionEntryBlock.getNumArguments() != 2 ||
2790  reductionEntryBlock.getArgumentTypes()[0] !=
2791  reductionEntryBlock.getArgumentTypes()[1] ||
2792  reductionEntryBlock.getArgumentTypes()[0] != getType())
2793  return emitOpError() << "expects reduction region with two arguments of "
2794  "the reduction type";
2795  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
2796  if (yieldOp.getResults().size() != 1 ||
2797  yieldOp.getResults().getTypes()[0] != getType())
2798  return emitOpError() << "expects reduction region to yield a value "
2799  "of the reduction type";
2800  }
2801 
2802  if (!getAtomicReductionRegion().empty()) {
2803  Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
2804  if (atomicReductionEntryBlock.getNumArguments() != 2 ||
2805  atomicReductionEntryBlock.getArgumentTypes()[0] !=
2806  atomicReductionEntryBlock.getArgumentTypes()[1])
2807  return emitOpError() << "expects atomic reduction region with two "
2808  "arguments of the same type";
2809  auto ptrType = llvm::dyn_cast<PointerLikeType>(
2810  atomicReductionEntryBlock.getArgumentTypes()[0]);
2811  if (!ptrType ||
2812  (ptrType.getElementType() && ptrType.getElementType() != getType()))
2813  return emitOpError() << "expects atomic reduction region arguments to "
2814  "be accumulators containing the reduction type";
2815  }
2816 
2817  if (getCleanupRegion().empty())
2818  return success();
2819  Block &cleanupEntryBlock = getCleanupRegion().front();
2820  if (cleanupEntryBlock.getNumArguments() != 1 ||
2821  cleanupEntryBlock.getArgument(0).getType() != getType())
2822  return emitOpError() << "expects cleanup region with one argument "
2823  "of the reduction type";
2824 
2825  return success();
2826 }
2827 
2828 //===----------------------------------------------------------------------===//
2829 // TaskOp
2830 //===----------------------------------------------------------------------===//
2831 
2832 void TaskOp::build(OpBuilder &builder, OperationState &state,
2833  const TaskOperands &clauses) {
2834  MLIRContext *ctx = builder.getContext();
2835  TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2836  makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2837  clauses.final, clauses.ifExpr, clauses.inReductionVars,
2838  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2839  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2840  clauses.priority, /*private_vars=*/clauses.privateVars,
2841  /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
2842  clauses.privateNeedsBarrier, clauses.untied,
2843  clauses.eventHandle);
2844 }
2845 
2846 LogicalResult TaskOp::verify() {
2847  LogicalResult verifyDependVars =
2848  verifyDependVarList(*this, getDependKinds(), getDependVars());
2849  return failed(verifyDependVars)
2850  ? verifyDependVars
2851  : verifyReductionVarList(*this, getInReductionSyms(),
2852  getInReductionVars(),
2853  getInReductionByref());
2854 }
2855 
2856 //===----------------------------------------------------------------------===//
2857 // TaskgroupOp
2858 //===----------------------------------------------------------------------===//
2859 
2860 void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
2861  const TaskgroupOperands &clauses) {
2862  MLIRContext *ctx = builder.getContext();
2863  TaskgroupOp::build(builder, state, clauses.allocateVars,
2864  clauses.allocatorVars, clauses.taskReductionVars,
2865  makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
2866  makeArrayAttr(ctx, clauses.taskReductionSyms));
2867 }
2868 
2869 LogicalResult TaskgroupOp::verify() {
2870  return verifyReductionVarList(*this, getTaskReductionSyms(),
2871  getTaskReductionVars(),
2872  getTaskReductionByref());
2873 }
2874 
2875 //===----------------------------------------------------------------------===//
2876 // TaskloopOp
2877 //===----------------------------------------------------------------------===//
2878 
2879 void TaskloopOp::build(OpBuilder &builder, OperationState &state,
2880  const TaskloopOperands &clauses) {
2881  MLIRContext *ctx = builder.getContext();
2882  TaskloopOp::build(
2883  builder, state, clauses.allocateVars, clauses.allocatorVars,
2884  clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
2885  clauses.inReductionVars,
2886  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2887  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2888  clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
2889  /*private_vars=*/clauses.privateVars,
2890  /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
2891  clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
2892  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2893  makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
2894 }
2895 
2896 LogicalResult TaskloopOp::verify() {
2897  if (getAllocateVars().size() != getAllocatorVars().size())
2898  return emitError(
2899  "expected equal sizes for allocate and allocator variables");
2900  if (failed(verifyReductionVarList(*this, getReductionSyms(),
2901  getReductionVars(), getReductionByref())) ||
2902  failed(verifyReductionVarList(*this, getInReductionSyms(),
2903  getInReductionVars(),
2904  getInReductionByref())))
2905  return failure();
2906 
2907  if (!getReductionVars().empty() && getNogroup())
2908  return emitError("if a reduction clause is present on the taskloop "
2909  "directive, the nogroup clause must not be specified");
2910  for (auto var : getReductionVars()) {
2911  if (llvm::is_contained(getInReductionVars(), var))
2912  return emitError("the same list item cannot appear in both a reduction "
2913  "and an in_reduction clause");
2914  }
2915 
2916  if (getGrainsize() && getNumTasks()) {
2917  return emitError(
2918  "the grainsize clause and num_tasks clause are mutually exclusive and "
2919  "may not appear on the same taskloop directive");
2920  }
2921 
2922  return success();
2923 }
2924 
2925 LogicalResult TaskloopOp::verifyRegions() {
2926  if (LoopWrapperInterface nested = getNestedWrapper()) {
2927  if (!isComposite())
2928  return emitError()
2929  << "'omp.composite' attribute missing from composite wrapper";
2930 
2931  // Check for the allowed leaf constructs that may appear in a composite
2932  // construct directly after TASKLOOP.
2933  if (!isa<SimdOp>(nested))
2934  return emitError() << "only supported nested wrapper is 'omp.simd'";
2935  } else if (isComposite()) {
2936  return emitError()
2937  << "'omp.composite' attribute present in non-composite wrapper";
2938  }
2939 
2940  return success();
2941 }
2942 
2943 //===----------------------------------------------------------------------===//
2944 // LoopNestOp
2945 //===----------------------------------------------------------------------===//
2946 
/// Parses the custom assembly form of a loop nest:
///   `(` ivs `)` `:` type `=` `(` lbs `)` `to` `(` ubs `)` [`inclusive`]
///   `step` `(` steps `)` [`collapse` `(` N `)`] [`tiles` `(` t, ... `)`]
///   region [attr-dict]
/// All induction variables and lower/upper bound and step operands share the
/// single parsed type. The collapse value is recorded as `collapse_num_loops`
/// only when it is greater than 1. `tile_sizes` is recorded only when at
/// least one tile size is given.
ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse an opening `(` followed by induction variables followed by `)`
  Type loopVarType;
  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(loopVarType) ||
      // Parse loop bounds.
      parser.parseEqual() ||
      parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("to") ||
      parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
    return failure();

  // Every induction variable takes the one type given after the `:`.
  for (auto &iv : ivs)
    iv.type = loopVarType;

  auto *ctx = parser.getBuilder().getContext();
  // Parse "inclusive" flag.
  if (succeeded(parser.parseOptionalKeyword("inclusive")))
    result.addAttribute("loop_inclusive", UnitAttr::get(ctx));

  // Parse step values.
  if (parser.parseKeyword("step") ||
      parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
    return failure();

  // Parse collapse. `parseOptionalKeyword` returns success (i.e. converts to
  // false) when the keyword is present, hence the negation.
  int64_t value = 0;
  if (!parser.parseOptionalKeyword("collapse") &&
      (parser.parseLParen() || parser.parseInteger(value) ||
       parser.parseRParen()))
    return failure();
  // A collapse of 1 (or an absent clause) is the default and is not stored.
  if (value > 1)
    result.addAttribute(
        "collapse_num_loops",
        IntegerAttr::get(parser.getBuilder().getI64Type(), value));

  // Parse tiles
  SmallVector<int64_t> tiles;
  auto parseTiles = [&]() -> ParseResult {
    int64_t tile;
    if (parser.parseInteger(tile))
      return failure();
    tiles.push_back(tile);
    return success();
  };

  if (!parser.parseOptionalKeyword("tiles") &&
      (parser.parseLParen() || parser.parseCommaSeparatedList(parseTiles) ||
       parser.parseRParen()))
    return failure();

  if (tiles.size() > 0)
    result.addAttribute("tile_sizes", DenseI64ArrayAttr::get(ctx, tiles));

  // Parse the body; the induction variables become its entry block arguments.
  Region *region = result.addRegion();
  if (parser.parseRegion(*region, ivs))
    return failure();

  // Resolve operands. The order (lower bounds, upper bounds, steps) defines
  // the operand layout of the op.
  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
      parser.resolveOperands(ubs, loopVarType, result.operands) ||
      parser.resolveOperands(steps, loopVarType, result.operands))
    return failure();

  // Parse the optional attribute list.
  return parser.parseOptionalAttrDict(result.attributes);
}
3018 
// Prints the custom loop-nest form: the induction variables and their type,
// the lower/upper bounds, an optional `inclusive`, the steps, `collapse(N)`
// only when N > 1, `tiles(...)` when tile sizes are set, and then the body
// region without its entry-block arguments (they were already printed as IVs).
  Region &region = getRegion();
  auto args = region.getArguments();
  // All IVs share one type, so only the first argument's type is printed.
  p << " (" << args << ") : " << args[0].getType() << " = ("
    << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
  if (getLoopInclusive())
    p << "inclusive ";
  p << "step (" << getLoopSteps() << ") ";
  if (int64_t numCollapse = getCollapseNumLoops())
    if (numCollapse > 1)
      p << "collapse(" << numCollapse << ") ";

  if (const auto tiles = getTileSizes())
    p << "tiles(" << tiles.value() << ") ";

  p.printRegion(region, /*printEntryBlockArgs=*/false);
}
3036 
3037 void LoopNestOp::build(OpBuilder &builder, OperationState &state,
3038  const LoopNestOperands &clauses) {
3039  MLIRContext *ctx = builder.getContext();
3040  LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3041  clauses.loopLowerBounds, clauses.loopUpperBounds,
3042  clauses.loopSteps, clauses.loopInclusive,
3043  makeDenseI64ArrayAttr(ctx, clauses.tileSizes));
3044 }
3045 
3046 LogicalResult LoopNestOp::verify() {
3047  if (getLoopLowerBounds().empty())
3048  return emitOpError() << "must represent at least one loop";
3049 
3050  if (getLoopLowerBounds().size() != getIVs().size())
3051  return emitOpError() << "number of range arguments and IVs do not match";
3052 
3053  for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3054  if (lb.getType() != iv.getType())
3055  return emitOpError()
3056  << "range argument type does not match corresponding IV type";
3057  }
3058 
3059  uint64_t numIVs = getIVs().size();
3060 
3061  if (const auto &numCollapse = getCollapseNumLoops())
3062  if (numCollapse > numIVs)
3063  return emitOpError()
3064  << "collapse value is larger than the number of loops";
3065 
3066  if (const auto &tiles = getTileSizes())
3067  if (tiles.value().size() > numIVs)
3068  return emitOpError() << "too few canonical loops for tile dimensions";
3069 
3070  if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3071  return emitOpError() << "expects parent op to be a loop wrapper";
3072 
3073  return success();
3074 }
3075 
/// Appends to `wrappers` every consecutive enclosing loop-wrapper op, starting
/// with the direct parent (innermost first) and walking outwards until an
/// ancestor that is not a LoopWrapperInterface (or no ancestor) is reached.
void LoopNestOp::gatherWrappers(
  Operation *parent = (*this)->getParentOp();
  while (auto wrapper =
             llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
    wrappers.push_back(wrapper);
    // Continue with the next ancestor.
    parent = parent->getParentOp();
  }
}
3085 
3086 //===----------------------------------------------------------------------===//
3087 // OpenMP canonical loop handling
3088 //===----------------------------------------------------------------------===//
3089 
/// Decodes a canonical-loop-info (CLI) value into its parts. The result holds
/// the NewCliOp that defines it, the use through which a loop transformation
/// generates it (or nullptr), and the use through which one consumes it
/// (or nullptr). A null `cli` yields an all-empty tuple.
/// Every user of the CLI is expected to implement LoopTransformationInterface.
/// Multiple generators or consumers are asserted against, not diagnosed.
std::tuple<NewCliOp, OpOperand *, OpOperand *>

  // Defining a CLI for a generated loop is optional; if there is none then
  // there is no follow-up transformation
  if (!cli)
    return {{}, nullptr, nullptr};

  assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
         "Unexpected type of cli");

  NewCliOp create = cast<NewCliOp>(cli.getDefiningOp());
  OpOperand *gen = nullptr;
  OpOperand *cons = nullptr;
  for (OpOperand &use : cli.getUses()) {
    auto op = cast<LoopTransformationInterface>(use.getOwner());

    // Classify the use by which operand slot of the transformation it is.
    unsigned opnum = use.getOperandNumber();
    if (op.isGeneratee(opnum)) {
      assert(!gen && "Each CLI may have at most one def");
      gen = &use;
    } else if (op.isApplyee(opnum)) {
      assert(!cons && "Each CLI may have at most one consumer");
      cons = &use;
    } else {
      llvm_unreachable("Unexpected operand for a CLI");
    }
  }

  return {create, gen, cons};
}
3121 
3122 void NewCliOp::build(::mlir::OpBuilder &odsBuilder,
3123  ::mlir::OperationState &odsState) {
3124  odsState.addTypes(CanonicalLoopInfoType::get(odsBuilder.getContext()));
3125 }
3126 
/// Chooses a readable SSA name for the CLI value produced by this op, derived
/// from the op that generates the loop it refers to.
void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
  Value result = getResult();
  auto [newCli, gen, cons] = decodeCli(result);

  // Derive the CLI variable name from its generator:
  // * "canonloop" for omp.canonical_loop
  // * custom name for loop transformation generatees
  // * "cli" as fallback if no generator
  // * "_s<idx>" suffix for each nesting level, where <idx> is the sequential
  //   order of the op among the region-holding ops at that level
  // * "_r<idx>" suffix for operations with multiple regions, where <idx> is
  //   the index of that region
  std::string cliName{"cli"};
  if (gen) {
    cliName =
        TypeSwitch<Operation *, std::string>(gen->getOwner())
            .Case([&](CanonicalLoopOp op) {
              // Find the canonical loop nesting: walk outwards from the loop
              // and, for each level, record an "s<idx>" component (plus an
              // "r<idx>" component when the parent has several regions).
              // Components are collected innermost-first and reversed below.
              SmallVector<std::string> components;
              Operation *o = op.getOperation();
              while (o) {
                break;

                Region *r = o->getParentRegion();
                if (!r)
                  break;

                // Position of `o` among the region-holding ops of `r`,
                // visiting blocks in reverse post-order from the entry block.
                auto getSequentialIndex = [](Region *r, Operation *o) {
                  llvm::ReversePostOrderTraversal<Block *> traversal(
                      &r->getBlocks().front());
                  size_t idx = 0;
                  for (Block *b : traversal) {
                    for (Operation &op : *b) {
                      if (&op == o)
                        return idx;
                      // Only consider operations that are containers as
                      // possible children
                      if (!op.getRegions().empty())
                        idx += 1;
                    }
                  }
                  llvm_unreachable("Operation not part of the region");
                };
                size_t sequentialIdx = getSequentialIndex(r, o);
                components.push_back(("s" + Twine(sequentialIdx)).str());

                Operation *parent = r->getParentOp();
                if (!parent)
                  break;

                // If the operation has more than one region, also count in
                // which of the regions
                if (parent->getRegions().size() > 1) {
                  auto getRegionIndex = [](Operation *o, Region *r) {
                    for (auto [idx, region] :
                         llvm::enumerate(o->getRegions())) {
                      if (&region == r)
                        return idx;
                    }
                    llvm_unreachable("Region not child its parent operation");
                  };
                  size_t regionIdx = getRegionIndex(parent, r);
                  components.push_back(("r" + Twine(regionIdx)).str());
                }

                // next parent
                o = parent;
              }

              // Emit outermost components first.
              SmallString<64> Name("canonloop");
              for (const std::string &s : reverse(components)) {
                Name += '_';
                Name += s;
              }

              return Name;
            })
            .Case([&](UnrollHeuristicOp op) -> std::string {
              llvm_unreachable("heuristic unrolling does not generate a loop");
            })
            .Default([&](Operation *op) {
              assert(false && "TODO: Custom name for this operation");
              return "transformed";
            });
  }

  setNameFn(result, cliName);
}
3217 
3218 LogicalResult NewCliOp::verify() {
3219  Value cli = getResult();
3220 
3221  assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3222  "Unexpected type of cli");
3223 
3224  // Check that the CLI is used in at most generator and one consumer
3225  OpOperand *gen = nullptr;
3226  OpOperand *cons = nullptr;
3227  for (mlir::OpOperand &use : cli.getUses()) {
3228  auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3229 
3230  unsigned opnum = use.getOperandNumber();
3231  if (op.isGeneratee(opnum)) {
3232  if (gen) {
3233  InFlightDiagnostic error =
3234  emitOpError("CLI must have at most one generator");
3235  error.attachNote(gen->getOwner()->getLoc())
3236  .append("first generator here:");
3237  error.attachNote(use.getOwner()->getLoc())
3238  .append("second generator here:");
3239  return error;
3240  }
3241 
3242  gen = &use;
3243  } else if (op.isApplyee(opnum)) {
3244  if (cons) {
3245  InFlightDiagnostic error =
3246  emitOpError("CLI must have at most one consumer");
3247  error.attachNote(cons->getOwner()->getLoc())
3248  .append("first consumer here:")
3249  .appendOp(*cons->getOwner(),
3250  OpPrintingFlags().printGenericOpForm());
3251  error.attachNote(use.getOwner()->getLoc())
3252  .append("second consumer here:")
3253  .appendOp(*use.getOwner(), OpPrintingFlags().printGenericOpForm());
3254  return error;
3255  }
3256 
3257  cons = &use;
3258  } else {
3259  llvm_unreachable("Unexpected operand for a CLI");
3260  }
3261  }
3262 
3263  // If the CLI is source of a transformation, it must have a generator
3264  if (cons && !gen) {
3265  InFlightDiagnostic error = emitOpError("CLI has no generator");
3266  error.attachNote(cons->getOwner()->getLoc())
3267  .append("see consumer here: ")
3268  .appendOp(*cons->getOwner(), OpPrintingFlags().printGenericOpForm());
3269  return error;
3270  }
3271 
3272  return success();
3273 }
3274 
3275 void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3276  Value tripCount) {
3277  odsState.addOperands(tripCount);
3278  odsState.addOperands(Value());
3279  (void)odsState.addRegion();
3280 }
3281 
3282 void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3283  Value tripCount, ::mlir::Value cli) {
3284  odsState.addOperands(tripCount);
3285  odsState.addOperands(cli);
3286  (void)odsState.addRegion();
3287 }
3288 
3289 void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
3290  setNameFn(&getRegion().front(), "body_entry");
3291 }
3292 
3293 void CanonicalLoopOp::getAsmBlockArgumentNames(Region &region,
3294  OpAsmSetValueNameFn setNameFn) {
3295  setNameFn(region.getArgument(0), "iv");
3296 }
3297 
// Prints `[(%cli)] %iv : type in range(%tripcount)`, then the body region
// (without its entry-block argument, which was printed as %iv, but with block
// terminators), then any attributes.
  if (getCli())
    p << '(' << getCli() << ')';
  p << ' ' << getInductionVar() << " : " << getInductionVar().getType()
    << " in range(" << getTripCount() << ") ";

  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);

  p.printOptionalAttrDict((*this)->getAttrs());
}
3309 
/// Parses `[(%cli)] %iv : type in range(%tripcount) region [attr-dict]`.
/// The trip count takes the induction variable's type, and the optional CLI
/// operand is placed after the trip count in the operand list.
mlir::ParseResult CanonicalLoopOp::parse(::mlir::OpAsmParser &parser,
                                         ::mlir::OperationState &result) {
  CanonicalLoopInfoType cliType =

  // Parse (optional) omp.cli identifier
  SmallVector<mlir::Value, 1> cliOperand;
  if (!parser.parseOptionalLParen()) {
    if (parser.parseOperand(cli) ||
        parser.resolveOperand(cli, cliType, cliOperand) || parser.parseRParen())
      return failure();
  }

  // We derive the type of tripCount from inductionVariable. MLIR requires the
  // type of tripCount to be known when calling resolveOperand so we have to
  // parse the type before processing the inductionVariable.
  OpAsmParser::Argument inductionVariable;
  if (parser.parseArgument(inductionVariable, /*allowType*/ true) ||
      parser.parseKeyword("in") || parser.parseKeyword("range") ||
      parser.parseLParen() || parser.parseOperand(tripcount) ||
      parser.parseRParen() ||
      parser.resolveOperand(tripcount, inductionVariable.type, result.operands))
    return failure();

  // Parse the loop body.
  Region *region = result.addRegion();
  if (parser.parseRegion(*region, {inductionVariable}))
    return failure();

  // We parsed the cli operand first, but because it is optional, it must be
  // last in the operand list.
  result.operands.append(cliOperand);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return mlir::success();
}
3351 
3352 LogicalResult CanonicalLoopOp::verify() {
3353  // The region's entry must accept the induction variable
3354  // It can also be empty if just created
3355  if (!getRegion().empty()) {
3356  Region &region = getRegion();
3357  if (region.getNumArguments() != 1)
3358  return emitOpError(
3359  "Canonical loop region must have exactly one argument");
3360 
3361  if (getInductionVar().getType() != getTripCount().getType())
3362  return emitOpError(
3363  "Region argument must be the same type as the trip count");
3364  }
3365 
3366  return success();
3367 }
3368 
3369 Value CanonicalLoopOp::getInductionVar() { return getRegion().getArgument(0); }
3370 
3371 std::pair<unsigned, unsigned>
3372 CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3373  // No applyees
3374  return {0, 0};
3375 }
3376 
3377 std::pair<unsigned, unsigned>
3378 CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3379  return getODSOperandIndexAndLength(odsIndex_cli);
3380 }
3381 
3382 //===----------------------------------------------------------------------===//
3383 // UnrollHeuristicOp
3384 //===----------------------------------------------------------------------===//
3385 
3386 void UnrollHeuristicOp::build(::mlir::OpBuilder &odsBuilder,
3387  ::mlir::OperationState &odsState,
3388  ::mlir::Value cli) {
3389  odsState.addOperands(cli);
3390 }
3391 
// Prints `(%applyee)` followed by any attributes.
  p << '(' << getApplyee() << ')';

  p.printOptionalAttrDict((*this)->getAttrs());
}
3397 
/// Parses `(%applyee) [-> ()] [attr-dict]`. The applyee is resolved as a
/// CanonicalLoopInfoType operand. The optional `-> ()` names no generated
/// loop, and nothing is recorded for it.
mlir::ParseResult UnrollHeuristicOp::parse(::mlir::OpAsmParser &parser,
                                           ::mlir::OperationState &result) {
  auto cliType = CanonicalLoopInfoType::get(parser.getContext());

  if (parser.parseLParen())
    return failure();

  if (parser.parseOperand(applyee) ||
      parser.resolveOperand(applyee, cliType, result.operands))
    return failure();

  if (parser.parseRParen())
    return failure();

  // Optional output loop (full unrolling has none)
  if (!parser.parseOptionalArrow()) {
    if (parser.parseLParen() || parser.parseRParen())
      return failure();
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return mlir::success();
}
3425 
3426 std::pair<unsigned, unsigned>
3427 UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3428  return getODSOperandIndexAndLength(odsIndex_applyee);
3429 }
3430 
3431 std::pair<unsigned, unsigned>
3432 UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3433  return {0, 0};
3434 }
3435 
3436 //===----------------------------------------------------------------------===//
3437 // Critical construct (2.17.1)
3438 //===----------------------------------------------------------------------===//
3439 
3440 void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
3441  const CriticalDeclareOperands &clauses) {
3442  CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
3443 }
3444 
3445 LogicalResult CriticalDeclareOp::verify() {
3446  return verifySynchronizationHint(*this, getHint());
3447 }
3448 
3449 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3450  if (getNameAttr()) {
3451  SymbolRefAttr symbolRef = getNameAttr();
3452  auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
3453  *this, symbolRef);
3454  if (!decl) {
3455  return emitOpError() << "expected symbol reference " << symbolRef
3456  << " to point to a critical declaration";
3457  }
3458  }
3459 
3460  return success();
3461 }
3462 
3463 //===----------------------------------------------------------------------===//
3464 // Ordered construct
3465 //===----------------------------------------------------------------------===//
3466 
3467 static LogicalResult verifyOrderedParent(Operation &op) {
3468  bool hasRegion = op.getNumRegions() > 0;
3469  auto loopOp = op.getParentOfType<LoopNestOp>();
3470  if (!loopOp) {
3471  if (hasRegion)
3472  return success();
3473 
3474  // TODO: Consider if this needs to be the case only for the standalone
3475  // variant of the ordered construct.
3476  return op.emitOpError() << "must be nested inside of a loop";
3477  }
3478 
3479  Operation *wrapper = loopOp->getParentOp();
3480  if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
3481  IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
3482  if (!orderedAttr)
3483  return op.emitOpError() << "the enclosing worksharing-loop region must "
3484  "have an ordered clause";
3485 
3486  if (hasRegion && orderedAttr.getInt() != 0)
3487  return op.emitOpError() << "the enclosing loop's ordered clause must not "
3488  "have a parameter present";
3489 
3490  if (!hasRegion && orderedAttr.getInt() == 0)
3491  return op.emitOpError() << "the enclosing loop's ordered clause must "
3492  "have a parameter present";
3493  } else if (!isa<SimdOp>(wrapper)) {
3494  return op.emitOpError() << "must be nested inside of a worksharing, simd "
3495  "or worksharing simd loop";
3496  }
3497  return success();
3498 }
3499 
3500 void OrderedOp::build(OpBuilder &builder, OperationState &state,
3501  const OrderedOperands &clauses) {
3502  OrderedOp::build(builder, state, clauses.doacrossDependType,
3503  clauses.doacrossNumLoops, clauses.doacrossDependVars);
3504 }
3505 
3506 LogicalResult OrderedOp::verify() {
3507  if (failed(verifyOrderedParent(**this)))
3508  return failure();
3509 
3510  auto wrapper = (*this)->getParentOfType<WsloopOp>();
3511  if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
3512  return emitOpError() << "number of variables in depend clause does not "
3513  << "match number of iteration variables in the "
3514  << "doacross loop";
3515 
3516  return success();
3517 }
3518 
3519 void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
3520  const OrderedRegionOperands &clauses) {
3521  OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
3522 }
3523 
3524 LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
3525 
3526 //===----------------------------------------------------------------------===//
3527 // TaskwaitOp
3528 //===----------------------------------------------------------------------===//
3529 
3530 void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
3531  const TaskwaitOperands &clauses) {
3532  // TODO Store clauses in op: dependKinds, dependVars, nowait.
3533  TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
3534  /*depend_vars=*/{}, /*nowait=*/nullptr);
3535 }
3536 
3537 //===----------------------------------------------------------------------===//
3538 // Verifier for AtomicReadOp
3539 //===----------------------------------------------------------------------===//
3540 
3541 LogicalResult AtomicReadOp::verify() {
3542  if (verifyCommon().failed())
3543  return mlir::failure();
3544 
3545  if (auto mo = getMemoryOrder()) {
3546  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3547  *mo == ClauseMemoryOrderKind::Release) {
3548  return emitError(
3549  "memory-order must not be acq_rel or release for atomic reads");
3550  }
3551  }
3552  return verifySynchronizationHint(*this, getHint());
3553 }
3554 
3555 //===----------------------------------------------------------------------===//
3556 // Verifier for AtomicWriteOp
3557 //===----------------------------------------------------------------------===//
3558 
3559 LogicalResult AtomicWriteOp::verify() {
3560  if (verifyCommon().failed())
3561  return mlir::failure();
3562 
3563  if (auto mo = getMemoryOrder()) {
3564  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3565  *mo == ClauseMemoryOrderKind::Acquire) {
3566  return emitError(
3567  "memory-order must not be acq_rel or acquire for atomic writes");
3568  }
3569  }
3570  return verifySynchronizationHint(*this, getHint());
3571 }
3572 
3573 //===----------------------------------------------------------------------===//
3574 // Verifier for AtomicUpdateOp
3575 //===----------------------------------------------------------------------===//
3576 
3577 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3578  PatternRewriter &rewriter) {
3579  if (op.isNoOp()) {
3580  rewriter.eraseOp(op);
3581  return success();
3582  }
3583  if (Value writeVal = op.getWriteOpVal()) {
3584  rewriter.replaceOpWithNewOp<AtomicWriteOp>(
3585  op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
3586  return success();
3587  }
3588  return failure();
3589 }
3590 
3591 LogicalResult AtomicUpdateOp::verify() {
3592  if (verifyCommon().failed())
3593  return mlir::failure();
3594 
3595  if (auto mo = getMemoryOrder()) {
3596  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3597  *mo == ClauseMemoryOrderKind::Acquire) {
3598  return emitError(
3599  "memory-order must not be acq_rel or acquire for atomic updates");
3600  }
3601  }
3602 
3603  return verifySynchronizationHint(*this, getHint());
3604 }
3605 
3606 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3607 
3608 //===----------------------------------------------------------------------===//
3609 // Verifier for AtomicCaptureOp
3610 //===----------------------------------------------------------------------===//
3611 
3612 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3613  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3614  return op;
3615  return dyn_cast<AtomicReadOp>(getSecondOp());
3616 }
3617 
3618 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3619  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3620  return op;
3621  return dyn_cast<AtomicWriteOp>(getSecondOp());
3622 }
3623 
3624 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3625  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3626  return op;
3627  return dyn_cast<AtomicUpdateOp>(getSecondOp());
3628 }
3629 
3630 LogicalResult AtomicCaptureOp::verify() {
3631  return verifySynchronizationHint(*this, getHint());
3632 }
3633 
3634 LogicalResult AtomicCaptureOp::verifyRegions() {
3635  if (verifyRegionsCommon().failed())
3636  return mlir::failure();
3637 
3638  if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
3639  return emitOpError(
3640  "operations inside capture region must not have hint clause");
3641 
3642  if (getFirstOp()->getAttr("memory_order") ||
3643  getSecondOp()->getAttr("memory_order"))
3644  return emitOpError(
3645  "operations inside capture region must not have memory_order clause");
3646  return success();
3647 }
3648 
3649 //===----------------------------------------------------------------------===//
3650 // CancelOp
3651 //===----------------------------------------------------------------------===//
3652 
3653 void CancelOp::build(OpBuilder &builder, OperationState &state,
3654  const CancelOperands &clauses) {
3655  CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
3656 }
3657 
// Body of getParentInSameDialect (the function header, source line 3658, is
// absent from this listing). Returns the nearest ancestor of `thisOp` whose
// dialect equals `thisOp`'s dialect, or nullptr when no ancestor matches.
3659  Operation *parent = thisOp->getParentOp();
3660  while (parent) {
// Walk outward one parent at a time until the dialects match.
3661  if (parent->getDialect() == thisOp->getDialect())
3662  return parent;
3663  parent = parent->getParentOp();
3664  }
3665  return nullptr;
3666 }
3667 
3668 LogicalResult CancelOp::verify() {
3669  ClauseCancellationConstructType cct = getCancelDirective();
3670  // The next OpenMP operation in the chain of parents
3671  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
3672  if (!structuralParent)
3673  return emitOpError() << "Orphaned cancel construct";
3674 
3675  if ((cct == ClauseCancellationConstructType::Parallel) &&
3676  !mlir::isa<ParallelOp>(structuralParent)) {
3677  return emitOpError() << "cancel parallel must appear "
3678  << "inside a parallel region";
3679  }
3680  if (cct == ClauseCancellationConstructType::Loop) {
3681  // structural parent will be omp.loop_nest, directly nested inside
3682  // omp.wsloop
3683  auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
3684 
3685  if (!wsloopOp) {
3686  return emitOpError()
3687  << "cancel loop must appear inside a worksharing-loop region";
3688  }
3689  if (wsloopOp.getNowaitAttr()) {
3690  return emitError() << "A worksharing construct that is canceled "
3691  << "must not have a nowait clause";
3692  }
3693  if (wsloopOp.getOrderedAttr()) {
3694  return emitError() << "A worksharing construct that is canceled "
3695  << "must not have an ordered clause";
3696  }
3697 
3698  } else if (cct == ClauseCancellationConstructType::Sections) {
3699  // structural parent will be an omp.section, directly nested inside
3700  // omp.sections
3701  auto sectionsOp =
3702  mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
3703  if (!sectionsOp) {
3704  return emitOpError() << "cancel sections must appear "
3705  << "inside a sections region";
3706  }
3707  if (sectionsOp.getNowait()) {
3708  return emitError() << "A sections construct that is canceled "
3709  << "must not have a nowait clause";
3710  }
3711  }
3712  if ((cct == ClauseCancellationConstructType::Taskgroup) &&
3713  (!mlir::isa<omp::TaskOp>(structuralParent) &&
3714  !mlir::isa<omp::TaskloopOp>(structuralParent->getParentOp()))) {
3715  return emitOpError() << "cancel taskgroup must appear "
3716  << "inside a task region";
3717  }
3718  return success();
3719 }
3720 
3721 //===----------------------------------------------------------------------===//
3722 // CancellationPointOp
3723 //===----------------------------------------------------------------------===//
3724 
3725 void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
3726  const CancellationPointOperands &clauses) {
3727  CancellationPointOp::build(builder, state, clauses.cancelDirective);
3728 }
3729 
3730 LogicalResult CancellationPointOp::verify() {
3731  ClauseCancellationConstructType cct = getCancelDirective();
3732  // The next OpenMP operation in the chain of parents
3733  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
3734  if (!structuralParent)
3735  return emitOpError() << "Orphaned cancellation point";
3736 
3737  if ((cct == ClauseCancellationConstructType::Parallel) &&
3738  !mlir::isa<ParallelOp>(structuralParent)) {
3739  return emitOpError() << "cancellation point parallel must appear "
3740  << "inside a parallel region";
3741  }
3742  // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
3743  // find the wsloop
3744  if ((cct == ClauseCancellationConstructType::Loop) &&
3745  !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
3746  return emitOpError() << "cancellation point loop must appear "
3747  << "inside a worksharing-loop region";
3748  }
3749  if ((cct == ClauseCancellationConstructType::Sections) &&
3750  !mlir::isa<omp::SectionOp>(structuralParent)) {
3751  return emitOpError() << "cancellation point sections must appear "
3752  << "inside a sections region";
3753  }
3754  if ((cct == ClauseCancellationConstructType::Taskgroup) &&
3755  !mlir::isa<omp::TaskOp>(structuralParent)) {
3756  return emitOpError() << "cancellation point taskgroup must appear "
3757  << "inside a task region";
3758  }
3759  return success();
3760 }
3761 
3762 //===----------------------------------------------------------------------===//
3763 // MapBoundsOp
3764 //===----------------------------------------------------------------------===//
3765 
3766 LogicalResult MapBoundsOp::verify() {
3767  auto extent = getExtent();
3768  auto upperbound = getUpperBound();
3769  if (!extent && !upperbound)
3770  return emitError("expected extent or upperbound.");
3771  return success();
3772 }
3773 
// Builds a privatizer whose data-sharing type is `Private`. The result-types
// argument is accepted but ignored.
// NOTE(review): source line 3779, which presumably opens the data-sharing
// attribute construction, is absent from this listing.
3774 void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3775  TypeRange /*result_types*/, StringAttr symName,
3776  TypeAttr type) {
3777  PrivateClauseOp::build(
3778  odsBuilder, odsState, symName, type,
3780  DataSharingClauseType::Private));
3781 }
3782 
3783 LogicalResult PrivateClauseOp::verifyRegions() {
3784  Type argType = getArgType();
3785  auto verifyTerminator = [&](Operation *terminator,
3786  bool yieldsValue) -> LogicalResult {
3787  if (!terminator->getBlock()->getSuccessors().empty())
3788  return success();
3789 
3790  if (!llvm::isa<YieldOp>(terminator))
3791  return mlir::emitError(terminator->getLoc())
3792  << "expected exit block terminator to be an `omp.yield` op.";
3793 
3794  YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
3795  TypeRange yieldedTypes = yieldOp.getResults().getTypes();
3796 
3797  if (!yieldsValue) {
3798  if (yieldedTypes.empty())
3799  return success();
3800 
3801  return mlir::emitError(terminator->getLoc())
3802  << "Did not expect any values to be yielded.";
3803  }
3804 
3805  if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
3806  return success();
3807 
3808  auto error = mlir::emitError(yieldOp.getLoc())
3809  << "Invalid yielded value. Expected type: " << argType
3810  << ", got: ";
3811 
3812  if (yieldedTypes.empty())
3813  error << "None";
3814  else
3815  error << yieldedTypes;
3816 
3817  return error;
3818  };
3819 
3820  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
3821  StringRef regionName,
3822  bool yieldsValue) -> LogicalResult {
3823  assert(!region.empty());
3824 
3825  if (region.getNumArguments() != expectedNumArgs)
3826  return mlir::emitError(region.getLoc())
3827  << "`" << regionName << "`: "
3828  << "expected " << expectedNumArgs
3829  << " region arguments, got: " << region.getNumArguments();
3830 
3831  for (Block &block : region) {
3832  // MLIR will verify the absence of the terminator for us.
3833  if (!block.mightHaveTerminator())
3834  continue;
3835 
3836  if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
3837  return failure();
3838  }
3839 
3840  return success();
3841  };
3842 
3843  // Ensure all of the region arguments have the same type
3844  for (Region *region : getRegions())
3845  for (Type ty : region->getArgumentTypes())
3846  if (ty != argType)
3847  return emitError() << "Region argument type mismatch: got " << ty
3848  << " expected " << argType << ".";
3849 
3850  mlir::Region &initRegion = getInitRegion();
3851  if (!initRegion.empty() &&
3852  failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
3853  /*yieldsValue=*/true)))
3854  return failure();
3855 
3856  DataSharingClauseType dsType = getDataSharingType();
3857 
3858  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
3859  return emitError("`private` clauses do not require a `copy` region.");
3860 
3861  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
3862  return emitError(
3863  "`firstprivate` clauses require at least a `copy` region.");
3864 
3865  if (dsType == DataSharingClauseType::FirstPrivate &&
3866  failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
3867  /*yieldsValue=*/true)))
3868  return failure();
3869 
3870  if (!getDeallocRegion().empty() &&
3871  failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
3872  /*yieldsValue=*/false)))
3873  return failure();
3874 
3875  return success();
3876 }
3877 
3878 //===----------------------------------------------------------------------===//
3879 // Spec 5.2: Masked construct (10.5)
3880 //===----------------------------------------------------------------------===//
3881 
3882 void MaskedOp::build(OpBuilder &builder, OperationState &state,
3883  const MaskedOperands &clauses) {
3884  MaskedOp::build(builder, state, clauses.filteredThreadId);
3885 }
3886 
3887 //===----------------------------------------------------------------------===//
3888 // Spec 5.2: Scan construct (5.6)
3889 //===----------------------------------------------------------------------===//
3890 
3891 void ScanOp::build(OpBuilder &builder, OperationState &state,
3892  const ScanOperands &clauses) {
3893  ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
3894 }
3895 
3896 LogicalResult ScanOp::verify() {
3897  if (hasExclusiveVars() == hasInclusiveVars())
3898  return emitError(
3899  "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
3900  if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
3901  if (parentWsLoopOp.getReductionModAttr() &&
3902  parentWsLoopOp.getReductionModAttr().getValue() ==
3903  ReductionModifier::inscan)
3904  return success();
3905  }
3906  if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
3907  if (parentSimdOp.getReductionModAttr() &&
3908  parentSimdOp.getReductionModAttr().getValue() ==
3909  ReductionModifier::inscan)
3910  return success();
3911  }
3912  return emitError("SCAN directive needs to be enclosed within a parent "
3913  "worksharing loop construct or SIMD construct with INSCAN "
3914  "reduction modifier");
3915 }
3916 
3917 /// Verifies align clause in allocate directive
3918 
3919 LogicalResult AllocateDirOp::verify() {
3920  std::optional<uint64_t> align = this->getAlign();
3921 
3922  if (align.has_value()) {
3923  if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
3924  return emitError() << "ALIGN value : " << align.value()
3925  << " must be power of 2";
3926  }
3927 
3928  return success();
3929 }
3930 
3931 //===----------------------------------------------------------------------===//
3932 // TargetAllocMemOp
3933 //===----------------------------------------------------------------------===//
3934 
3935 mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
3936  return getInTypeAttr().getValue();
3937 }
3938 
3939 /// Parses `omp.target_alloc_mem` (result type is always i64):
3940 /// $device `:` type($device) `,` $in_type
3941 /// ( `(` $typeparams `:` type($typeparams) `)` )? ( `,` $shape-list )? attr-dict
3942 static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
3943  mlir::OperationState &result) {
3944  auto &builder = parser.getBuilder();
3945  bool hasOperands = false;
3946  std::int32_t typeparamsSize = 0;
3947 
3948  // Parse device number as a new operand
// NOTE(review): source lines 3949, 3962-3963 and 3975 are absent from this
// listing. They presumably declare `deviceOperand`, `operands` and `typeVec`,
// and parse the shape operand list.
3950  mlir::Type deviceType;
3951  if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
3952  return mlir::failure();
3953  if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
3954  return mlir::failure();
3955  if (parser.parseComma())
3956  return mlir::failure();
3957 
3958  mlir::Type intype;
3959  if (parser.parseType(intype))
3960  return mlir::failure();
3961  result.addAttribute("in_type", mlir::TypeAttr::get(intype));
3964  if (!parser.parseOptionalLParen()) {
3965  // parse the LEN params of the derived type. (<params> : <types>)
3966  if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
3967  parser.parseColonTypeList(typeVec) || parser.parseRParen())
3968  return mlir::failure();
3969  typeparamsSize = operands.size();
3970  hasOperands = true;
3971  }
3972  std::int32_t shapeSize = 0;
3973  if (!parser.parseOptionalComma()) {
3974  // parse size to scale by, vector of n dimensions of type index
3976  return mlir::failure();
3977  shapeSize = operands.size() - typeparamsSize;
3978  auto idxTy = builder.getIndexType();
// Every shape operand is given the index type.
3979  for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
3980  typeVec.push_back(idxTy);
3981  hasOperands = true;
3982  }
// Typeparams and shape operands are resolved together against `typeVec`.
3983  if (hasOperands &&
3984  parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
3985  result.operands))
3986  return mlir::failure();
3987 
// The result type is fixed to i64 and does not depend on `intype`.
3988  mlir::Type restype = builder.getIntegerType(64);
// NOTE(review): getIntegerType(64) does not appear to ever return a null
// type, so this branch looks unreachable. Confirm before relying on it.
3989  if (!restype) {
3990  parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
3991  return mlir::failure();
3992  }
// Segments: 1 device operand, then the typeparams, then the shape operands.
3993  llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
3994  result.addAttribute("operandSegmentSizes",
3995  builder.getDenseI32ArrayAttr(segmentSizes));
3996  if (parser.parseOptionalAttrDict(result.attributes) ||
3997  parser.addTypeToList(restype, result.types))
3998  return mlir::failure();
3999  return mlir::success();
4000 }
4001 
4002 mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
4003  mlir::OperationState &result) {
4004  return parseTargetAllocMemOp(parser, result);
4005 }
4006 
// Body of TargetAllocMemOp::print (the function header, source line 4007, is
// absent from this listing). Prints
// `device : type, in_type[(typeparams : types)][, shape...]` followed by the
// remaining attributes.
4008  p << " ";
4009  p.printOperand(getDevice());
4010  p << " : ";
4011  p << getDevice().getType();
4012  p << ", ";
4013  p << getInType();
4014  if (!getTypeparams().empty()) {
4015  p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
4016  }
4017  for (auto sh : getShape()) {
4018  p << ", ";
4019  p.printOperand(sh);
4020  }
// `in_type` and `operandSegmentSizes` are elided: the parser reconstructs
// both from the printed form.
4021  p.printOptionalAttrDict((*this)->getAttrs(),
4022  {"in_type", "operandSegmentSizes"});
4023 }
4024 
4025 llvm::LogicalResult omp::TargetAllocMemOp::verify() {
4026  mlir::Type outType = getType();
4027  if (!mlir::dyn_cast<IntegerType>(outType))
4028  return emitOpError("must be a integer type");
4029  return mlir::success();
4030 }
4031 
4032 //===----------------------------------------------------------------------===//
4033 // WorkdistributeOp
4034 //===----------------------------------------------------------------------===//
4035 
4036 LogicalResult WorkdistributeOp::verify() {
4037  // Check that region exists and is not empty
4038  Region &region = getRegion();
4039  if (region.empty())
4040  return emitOpError("region cannot be empty");
4041  // Verify single entry point.
4042  Block &entryBlock = region.front();
4043  if (entryBlock.empty())
4044  return emitOpError("region must contain a structured block");
4045  // Verify single exit point.
4046  bool hasTerminator = false;
4047  for (Block &block : region) {
4048  if (isa<TerminatorOp>(block.back())) {
4049  if (hasTerminator) {
4050  return emitOpError("region must have exactly one terminator");
4051  }
4052  hasTerminator = true;
4053  }
4054  }
4055  if (!hasTerminator) {
4056  return emitOpError("region must be terminated with omp.terminator");
4057  }
4058  auto walkResult = region.walk([&](Operation *op) -> WalkResult {
4059  // No implicit barrier at end
4060  if (isa<BarrierOp>(op)) {
4061  return emitOpError(
4062  "explicit barriers are not allowed in workdistribute region");
4063  }
4064  // Check for invalid nested constructs
4065  if (isa<ParallelOp>(op)) {
4066  return emitOpError(
4067  "nested parallel constructs not allowed in workdistribute");
4068  }
4069  if (isa<TeamsOp>(op)) {
4070  return emitOpError(
4071  "nested teams constructs not allowed in workdistribute");
4072  }
4073  return WalkResult::advance();
4074  });
4075  if (walkResult.wasInterrupted())
4076  return failure();
4077 
4078  Operation *parentOp = (*this)->getParentOp();
4079  if (!llvm::dyn_cast<TeamsOp>(parentOp))
4080  return emitOpError("workdistribute must be nested under teams");
4081  return success();
4082 }
4083 
4084 #define GET_ATTRDEF_CLASSES
4085 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
4086 
4087 #define GET_OP_CLASSES
4088 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
4089 
4090 #define GET_TYPEDEF_CLASSES
4091 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
Definition: AMXDialect.cpp:70
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:757
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1365
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:62
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
operation ::= res = (omp.target_alloc_mem) $device : devicetype, $in_type ( ( $typeparams ) )?...
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static Operation * getParentInSameDialect(Operation *thisOp)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseMembersIndex(OpAsmParser &parser, ArrayAttr &membersIdx)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
SuccessorRange getSuccessors()
Definition: Block.h:270
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:162
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:227
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
MLIRContext * getContext() const
Definition: Builders.h:56
IndexType getIndexType()
Definition: Builders.cpp:50
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Definition: Diagnostics.h:228
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
Definition: Diagnostics.h:352
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
Definition: Builders.h:207
This class represents an operand of an operation.
Definition: Value.h:257
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:773
This class indicates that the regions associated with this op don't have terminators.
Definition: OpDefinition.h:769
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
BlockArgListType getArguments()
Definition: Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
BlockListType & getBlocks()
Definition: Region.h:45
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:108
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Definition: Dominance.cpp:306
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
SmallVector< SmallVector< AffineForOp, 8 >, 8 > tile(ArrayRef< AffineForOp > forOps, ArrayRef< uint64_t > sizes, ArrayRef< AffineForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: LoopUtils.cpp:1584
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Runtime
Potential runtimes for AMD GPU kernels.
Definition: Runtimes.h:15
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
Definition: OpDefinition.h:881
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.