MLIR  19.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/IR/Attributes.h"
22 
23 #include "llvm/ADT/BitVector.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/STLForwardCompat.h"
26 #include "llvm/ADT/SmallString.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Frontend/OpenMP/OMPConstants.h"
31 #include <cstddef>
32 #include <iterator>
33 #include <optional>
34 
35 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
36 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
37 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
38 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
40 
41 using namespace mlir;
42 using namespace mlir::omp;
43 
44 namespace {
/// External model attaching `PointerLikeType` to `MemRefType`: the reported
/// element type is the memref's own element type.
45 struct MemRefPointerLikeModel
46  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
47  MemRefType> {
48  Type getElementType(Type pointer) const {
49  return llvm::cast<MemRefType>(pointer).getElementType();
50  }
51 };
52 
/// External model attaching `PointerLikeType` to `LLVM::LLVMPointerType`.
/// `getElementType` always returns a null `Type()`.
// NOTE(review): presumably because the LLVM pointer type carries no element
// type in this representation — confirm.
53 struct LLVMPointerPointerLikeModel
54  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
55  LLVM::LLVMPointerType> {
56  Type getElementType(Type pointer) const { return Type(); }
57 };
58 
/// Fold-related dialect hooks for the OpenMP dialect.
59 struct OpenMPDialectFoldInterface : public DialectFoldInterface {
61 
  /// Returns true iff `region`'s parent op is a `TargetOp`. The
  /// DialectFoldInterface hook uses this to decide where folded constants are
  /// materialized, so constants are kept inside target regions instead of
  /// being placed across their boundary.
62  bool shouldMaterializeInto(Region *region) const final {
63  // Avoid folding constants across target regions
64  return isa<TargetOp>(region->getParentOp());
65  }
66 };
67 } // namespace
68 
/// Registers the OpenMP dialect's operations, attributes and types, adds its
/// fold interface, and attaches external interface models to types and ops
/// owned elsewhere (memref and LLVM pointer types, the builtin module op, and
/// LLVM/func global and function ops).
69 void OpenMPDialect::initialize() {
70  addOperations<
71 #define GET_OP_LIST
72 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
73  >();
74  addAttributes<
75 #define GET_ATTRDEF_LIST
76 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
77  >();
78  addTypes<
79 #define GET_TYPEDEF_LIST
80 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
81  >();
82 
83  addInterface<OpenMPDialectFoldInterface>();
// Let memref and LLVM pointer types be used where a PointerLikeType is
// expected.
84  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
85  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
86  *getContext());
87 
88  // Attach the default offload module interface to the module op so that
89  // offload functionality can be accessed through the module.
90  mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
91  *getContext());
92 
93  // Attach default declare target interfaces to operations which can be marked
94  // as declare target (global operations and functions/subroutines in dialects
95  // that Fortran, or other languages that lower to MLIR, translate to).
96  mlir::LLVM::GlobalOp::attachInterface<
98  *getContext());
99  mlir::LLVM::LLVMFuncOp::attachInterface<
101  *getContext());
102  mlir::func::FuncOp::attachInterface<
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // Parser and printer for Allocate Clause
108 //===----------------------------------------------------------------------===//
109 
110 /// Parse an allocate clause: a comma-separated list of allocator/variable
111 /// pairs, each operand followed by its type.
112 /// allocate-operand-list ::= allocate-operand |
113 ///   allocate-operand `,` allocate-operand-list
114 /// allocate-operand ::= ssa-id-and-type `->` ssa-id-and-type
115 /// ssa-id-and-type ::= ssa-id `:` type
///
/// In each allocate-operand the ssa-id-and-type before `->` is the allocator
/// (appended to `operandsAllocator`/`typesAllocator`) and the one after it is
/// the allocated variable (appended to `operandsAllocate`/`typesAllocate`).
117  OpAsmParser &parser,
119  SmallVectorImpl<Type> &typesAllocate,
121  SmallVectorImpl<Type> &typesAllocator) {
122 
123  return parser.parseCommaSeparatedList([&]() {
125  Type type;
// `operand` and `type` are reused for both halves of the entry.
126  if (parser.parseOperand(operand) || parser.parseColonType(type))
127  return failure();
128  operandsAllocator.push_back(operand);
129  typesAllocator.push_back(type);
130  if (parser.parseArrow())
131  return failure();
132  if (parser.parseOperand(operand) || parser.parseColonType(type))
133  return failure();
134 
135  operandsAllocate.push_back(operand);
136  typesAllocate.push_back(type);
137  return success();
138  });
139 }
140 
141 /// Print allocate clause as `allocator : type -> var : type` entries
/// separated by ", ". The allocator ranges are indexed with the same `i` as
/// `varsAllocate`, so they must have at least as many entries.
// NOTE(review): presumably guaranteed by the op's operand segments — confirm.
143  OperandRange varsAllocate,
144  TypeRange typesAllocate,
145  OperandRange varsAllocator,
146  TypeRange typesAllocator) {
147  for (unsigned i = 0; i < varsAllocate.size(); ++i) {
148  std::string separator = i == varsAllocate.size() - 1 ? "" : ", ";
149  p << varsAllocator[i] << " : " << typesAllocator[i] << " -> ";
150  p << varsAllocate[i] << " : " << typesAllocate[i] << separator;
151  }
152 }
153 
154 //===----------------------------------------------------------------------===//
155 // Parser and printer for a clause attribute (StringEnumAttr)
156 //===----------------------------------------------------------------------===//
157 
158 template <typename ClauseAttr>
159 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
160  using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
161  StringRef enumStr;
162  SMLoc loc = parser.getCurrentLocation();
163  if (parser.parseKeyword(&enumStr))
164  return failure();
165  if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
166  attr = ClauseAttr::get(parser.getContext(), *enumValue);
167  return success();
168  }
169  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
170 }
171 
172 template <typename ClauseAttr>
173 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
174  p << stringifyEnum(attr.getValue());
175 }
176 
177 //===----------------------------------------------------------------------===//
178 // Parser and printer for Linear Clause
179 //===----------------------------------------------------------------------===//
180 
181 /// linear ::= `linear` `(` linear-list `)`
182 /// linear-list := linear-val | linear-val `,` linear-list
183 /// linear-val := ssa-id `=` ssa-id `:` type
///
/// The ssa-id before `=` is the linear variable (appended to `vars`), the one
/// after it is the step (appended to `stepVars`), and the single trailing type
/// is appended to `types`.
184 static ParseResult
187  SmallVectorImpl<Type> &types,
189  return parser.parseCommaSeparatedList([&]() {
191  Type type;
193  if (parser.parseOperand(var) || parser.parseEqual() ||
194  parser.parseOperand(stepVar) || parser.parseColonType(type))
195  return failure();
196 
197  vars.push_back(var);
198  types.push_back(type);
199  stepVars.push_back(stepVar);
200  return success();
201  });
202 }
203 
204 /// Print Linear Clause as a comma-separated list of `var = step : type`
/// entries. The ` = step` part is printed only when `linearStepVars` has an
/// entry at that index. The printed type is taken from the variable itself;
/// `linearVarTypes` is not read here.
206  ValueRange linearVars, TypeRange linearVarTypes,
207  ValueRange linearStepVars) {
208  size_t linearVarsSize = linearVars.size();
209  for (unsigned i = 0; i < linearVarsSize; ++i) {
210  std::string separator = i == linearVarsSize - 1 ? "" : ", ";
211  p << linearVars[i];
212  if (linearStepVars.size() > i)
213  p << " = " << linearStepVars[i];
214  p << " : " << linearVars[i].getType() << separator;
215  }
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // Verifier for Nontemporal Clause
220 //===----------------------------------------------------------------------===//
221 
222 static LogicalResult
223 verifyNontemporalClause(Operation *op, OperandRange nontemporalVariables) {
224 
225  // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
226  DenseSet<Value> nontemporalItems;
227  for (const auto &it : nontemporalVariables)
228  if (!nontemporalItems.insert(it).second)
229  return op->emitOpError() << "nontemporal variable used more than once";
230 
231  return success();
232 }
233 
234 //===----------------------------------------------------------------------===//
235 // Parser, verifier and printer for Aligned Clause
236 //===----------------------------------------------------------------------===//
237 static LogicalResult
238 verifyAlignedClause(Operation *op, std::optional<ArrayAttr> alignmentValues,
239  OperandRange alignedVariables) {
240  // Check if number of alignment values equals to number of aligned variables
241  if (!alignedVariables.empty()) {
242  if (!alignmentValues || alignmentValues->size() != alignedVariables.size())
243  return op->emitOpError()
244  << "expected as many alignment values as aligned variables";
245  } else {
246  if (alignmentValues)
247  return op->emitOpError() << "unexpected alignment values attribute";
248  return success();
249  }
250 
251  // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
252  DenseSet<Value> alignedItems;
253  for (auto it : alignedVariables)
254  if (!alignedItems.insert(it).second)
255  return op->emitOpError() << "aligned variable used more than once";
256 
257  if (!alignmentValues)
258  return success();
259 
260  // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
261  for (unsigned i = 0; i < (*alignmentValues).size(); ++i) {
262  if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
263  if (intAttr.getValue().sle(0))
264  return op->emitOpError() << "alignment should be greater than 0";
265  } else {
266  return op->emitOpError() << "expected integer alignment";
267  }
268  }
269 
270  return success();
271 }
272 
273 /// aligned ::= `aligned` `(` aligned-list `)`
274 /// aligned-list := aligned-val | aligned-val `,` aligned-list
275 /// aligned-val := ssa-id-and-type `->` alignment
///
/// Each variable and its type are appended to `alignedItems`/`types`; the
/// alignment attributes are collected into the `alignmentValues` ArrayAttr.
277  OpAsmParser &parser,
279  SmallVectorImpl<Type> &types, ArrayAttr &alignmentValues) {
280  SmallVector<Attribute> alignmentVec;
281  if (failed(parser.parseCommaSeparatedList([&]() {
282  if (parser.parseOperand(alignedItems.emplace_back()) ||
283  parser.parseColonType(types.emplace_back()) ||
284  parser.parseArrow() ||
285  parser.parseAttribute(alignmentVec.emplace_back())) {
286  return failure();
287  }
288  return success();
289  })))
290  return failure();
// NOTE(review): `alignmentVec` is already a SmallVector<Attribute>, so this
// copy looks redundant.
291  SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
292  alignmentValues = ArrayAttr::get(parser.getContext(), alignments);
293  return success();
294 }
295 
296 /// Print Aligned Clause as `var : type -> alignment` entries separated by
/// ", ". The type is taken from each variable; `alignedVarTypes` is not read.
/// `alignmentValues` is dereferenced unconditionally, so it must hold a value
/// whenever `alignedVars` is non-empty.
// NOTE(review): presumably guaranteed by verifyAlignedClause — confirm.
298  ValueRange alignedVars,
299  TypeRange alignedVarTypes,
300  std::optional<ArrayAttr> alignmentValues) {
301  for (unsigned i = 0; i < alignedVars.size(); ++i) {
302  if (i != 0)
303  p << ", ";
304  p << alignedVars[i] << " : " << alignedVars[i].getType();
305  p << " -> " << (*alignmentValues)[i];
306  }
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // Parser, printer and verifier for Schedule Clause
311 //===----------------------------------------------------------------------===//
312 
/// Validates and normalizes the schedule modifiers parsed after the schedule
/// kind. Emits an error if there are more than two modifiers, if any is not a
/// known ScheduleModifier, or if two are given and they are not in the order
/// `<non-simd>, simd`. A lone `simd` is rewritten in place to `none, simd`,
/// so a non-simd modifier (possibly `none`) is always at index 0.
313 static ParseResult
315  SmallVectorImpl<SmallString<12>> &modifiers) {
316  if (modifiers.size() > 2)
317  return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
318  for (const auto &mod : modifiers) {
319  // Translate the string. If it has no value, then it was not a valid
320  // modifier!
321  auto symbol = symbolizeScheduleModifier(mod);
322  if (!symbol)
323  return parser.emitError(parser.getNameLoc())
324  << " unknown modifier type: " << mod;
325  }
326 
327  // If we have one modifier that is "simd", then stick a "none" modifier in
328  // index 0.
329  if (modifiers.size() == 1) {
330  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
331  modifiers.push_back(modifiers[0]);
332  modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
333  }
334  } else if (modifiers.size() == 2) {
335  // If there are two modifiers:
336  // First modifier should not be simd, second one should be simd
337  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
338  symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
339  return parser.emitError(parser.getNameLoc())
340  << " incorrect modifier order";
341  }
342  return success();
343 }
344 
345 /// schedule ::= `schedule` `(` sched-list `)`
346 /// sched-list ::= sched-val |
347 /// sched-val `,` sched-modifier
348 /// sched-val ::= sched-with-chunk | sched-wo-chunk
349 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
350 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
351 /// sched-wo-chunk ::= `auto` | `runtime`
352 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
353 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
///
/// On success `scheduleAttr` holds the schedule kind. `chunkSize`/`chunkType`
/// are filled only for static/dynamic/guided with an `= chunk : type` suffix;
/// otherwise `chunkSize` is std::nullopt. If modifiers are present,
/// `scheduleModifier` holds the first one and `simdModifier` is set when a
/// second (`simd`) modifier follows. A lone `simd` is normalized by
/// verifyScheduleModifiers to `none` plus the simd flag.
355  OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
356  ScheduleModifierAttr &scheduleModifier, UnitAttr &simdModifier,
357  std::optional<OpAsmParser::UnresolvedOperand> &chunkSize, Type &chunkType) {
358  StringRef keyword;
359  if (parser.parseKeyword(&keyword))
360  return failure();
361  std::optional<mlir::omp::ClauseScheduleKind> schedule =
362  symbolizeClauseScheduleKind(keyword);
363  if (!schedule)
364  return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
365 
366  scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
367  switch (*schedule) {
368  case ClauseScheduleKind::Static:
369  case ClauseScheduleKind::Dynamic:
370  case ClauseScheduleKind::Guided:
371  if (succeeded(parser.parseOptionalEqual())) {
372  chunkSize = OpAsmParser::UnresolvedOperand{};
373  if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
374  return failure();
375  } else {
376  chunkSize = std::nullopt;
377  }
378  break;
379  case ClauseScheduleKind::Auto:
381  chunkSize = std::nullopt;
382  }
383 
384  // If there is a comma, we have one or more modifiers.
385  SmallVector<SmallString<12>> modifiers;
386  while (succeeded(parser.parseOptionalComma())) {
387  StringRef mod;
388  if (parser.parseKeyword(&mod))
389  return failure();
390  modifiers.push_back(mod);
391  }
392 
393  if (verifyScheduleModifiers(parser, modifiers))
394  return failure();
395 
396  if (!modifiers.empty()) {
397  SMLoc loc = parser.getCurrentLocation();
398  if (std::optional<ScheduleModifier> mod =
399  symbolizeScheduleModifier(modifiers[0])) {
400  scheduleModifier = ScheduleModifierAttr::get(parser.getContext(), *mod);
401  } else {
402  return parser.emitError(loc, "invalid schedule modifier");
403  }
404  // verifyScheduleModifiers guarantees that a second modifier is `simd`.
405  if (modifiers.size() > 1) {
406  assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
407  simdModifier = UnitAttr::get(parser.getBuilder().getContext());
408  }
409  }
410 
411  return success();
412 }
413 
414 /// Print schedule clause as `kind[ = chunk : type][, modifier][, simd]`.
/// The printed chunk type is `scheduleChunkVar.getType()`; the
/// `scheduleChunkType` parameter is not read here.
416  ClauseScheduleKindAttr schedAttr,
417  ScheduleModifierAttr modifier, UnitAttr simd,
418  Value scheduleChunkVar,
419  Type scheduleChunkType) {
420  p << stringifyClauseScheduleKind(schedAttr.getValue());
421  if (scheduleChunkVar)
422  p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType();
423  if (modifier)
424  p << ", " << stringifyScheduleModifier(modifier.getValue());
425  if (simd)
426  p << ", simd";
427 }
428 
429 //===----------------------------------------------------------------------===//
430 // Parser, printer and verifier for ReductionVarList
431 //===----------------------------------------------------------------------===//
432 
434  OpAsmParser &parser, Region &region,
436  SmallVectorImpl<Type> &types, ArrayAttr &symbols,
437  SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs) {
// Parses a comma-separated list of `symbol-ref operand -> region-arg : type`
// entries. Operands and types are appended to `operands`/`types`, the symbol
// references are collected into `symbols`, and one region argument per entry
// is appended to `regionPrivateArgs` and given that entry's type.
438  SmallVector<SymbolRefAttr> reductionVec;
439  unsigned regionArgOffset = regionPrivateArgs.size();
440 
441  if (failed(
443  if (parser.parseAttribute(reductionVec.emplace_back()) ||
444  parser.parseOperand(operands.emplace_back()) ||
445  parser.parseArrow() ||
446  parser.parseArgument(regionPrivateArgs.emplace_back()) ||
447  parser.parseColonType(types.emplace_back()))
448  return failure();
449  return success();
450  })))
451  return failure();
452 
// Give the region arguments added by this call their parsed types. zip_equal
// yields references, so `prv.type = type` updates regionPrivateArgs in place.
// NOTE(review): the subrange length is types.size(), which assumes `types` was
// empty on entry; a non-empty `types` would overrun regionPrivateArgs.
// Confirm that all callers pass fresh vectors.
453  auto *argsBegin = regionPrivateArgs.begin();
454  MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
455  argsBegin + regionArgOffset + types.size());
456  for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
457  prv.type = type;
458  }
459  SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
460  symbols = ArrayAttr::get(parser.getContext(), reductions);
461  return success();
462 }
463 
465  ValueRange argsSubrange,
466  StringRef clauseName, ValueRange operands,
467  TypeRange types, ArrayAttr symbols) {
// Prints `clauseName(sym operand -> arg : type, ...) ` with a trailing space.
// llvm::zip_equal requires all four ranges to have the same length.
468  p << clauseName << "(";
469  llvm::interleaveComma(
470  llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
471  auto [sym, op, arg, type] = t;
472  p << sym << " " << op << " -> " << arg << " : " << type;
473  });
474  p << ") ";
475 }
476 
478  OpAsmParser &parser, Region &region,
480  SmallVectorImpl<Type> &reductionVarTypes, ArrayAttr &reductionSymbols,
482  llvm::SmallVectorImpl<Type> &privateVarsTypes,
483  ArrayAttr &privatizerSymbols) {
// Parses an optional `reduction` clause, then an optional `private` clause,
// then the region. Both clauses append to the same `regionPrivateArgs` list
// (declared on a line not shown in this listing). The region's entry-block
// arguments are therefore the reduction arguments followed by the private
// arguments.
485 
486  if (succeeded(parser.parseOptionalKeyword("reduction"))) {
487  if (failed(parseClauseWithRegionArgs(parser, region, reductionVarOperands,
488  reductionVarTypes, reductionSymbols,
489  regionPrivateArgs)))
490  return failure();
491  }
492 
493  if (succeeded(parser.parseOptionalKeyword("private"))) {
494  if (failed(parseClauseWithRegionArgs(parser, region, privateVarOperands,
495  privateVarsTypes, privatizerSymbols,
496  regionPrivateArgs)))
497  return failure();
498  }
499 
500  return parser.parseRegion(region, regionPrivateArgs);
501 }
502 
503 static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
504  ValueRange reductionVarOperands,
505  TypeRange reductionVarTypes,
506  ArrayAttr reductionSymbols,
507  ValueRange privateVarOperands,
508  TypeRange privateVarTypes,
509  ArrayAttr privatizerSymbols) {
510  if (reductionSymbols) {
511  auto *argsBegin = region.front().getArguments().begin();
512  MutableArrayRef argsSubrange(argsBegin,
513  argsBegin + reductionVarTypes.size());
514  printClauseWithRegionArgs(p, op, argsSubrange, "reduction",
515  reductionVarOperands, reductionVarTypes,
516  reductionSymbols);
517  }
518 
519  if (privatizerSymbols) {
520  auto *argsBegin = region.front().getArguments().begin();
521  MutableArrayRef argsSubrange(argsBegin + reductionVarOperands.size(),
522  argsBegin + reductionVarOperands.size() +
523  privateVarTypes.size());
524  printClauseWithRegionArgs(p, op, argsSubrange, "private",
525  privateVarOperands, privateVarTypes,
526  privatizerSymbols);
527  }
528 
529  p.printRegion(region, /*printEntryBlockArgs=*/false);
530 }
531 
532 /// reduction-entry-list ::= reduction-entry
533 /// | reduction-entry-list `,` reduction-entry
534 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
///
/// Operands and types are appended to `operands`/`types`; the reduction
/// symbol references are collected into the `redcuctionSymbols` ArrayAttr.
// NOTE(review): `redcuctionSymbols` is a misspelling of `reductionSymbols`.
// Renaming this parameter is local-only and safe.
535 static ParseResult
538  SmallVectorImpl<Type> &types,
539  ArrayAttr &redcuctionSymbols) {
540  SmallVector<SymbolRefAttr> reductionVec;
541  if (failed(parser.parseCommaSeparatedList([&]() {
542  if (parser.parseAttribute(reductionVec.emplace_back()) ||
543  parser.parseArrow() ||
544  parser.parseOperand(operands.emplace_back()) ||
545  parser.parseColonType(types.emplace_back()))
546  return failure();
547  return success();
548  })))
549  return failure();
550  SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
551  redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
552  return success();
553 }
554 
555 /// Print Reduction clause as `sym -> var : type` entries separated by ", ".
/// `reductions` is dereferenced unconditionally, so it must hold a value when
/// this is called. The printed type comes from each variable; `reductionTypes`
/// is not read.
557  OperandRange reductionVars,
558  TypeRange reductionTypes,
559  std::optional<ArrayAttr> reductions) {
560  for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
561  if (i != 0)
562  p << ", ";
563  p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
564  << reductionVars[i].getType();
565  }
566 }
567 
568 /// Verifies Reduction Clause:
///  - the number of reduction symbols must equal the number of reduction
///    variables, and there must be no symbols when there are no variables;
///  - each accumulator variable may appear only once;
///  - each symbol must resolve to a DeclareReductionOp whose accumulator type,
///    when set, equals the variable's type.
570  std::optional<ArrayAttr> reductions,
571  OperandRange reductionVars) {
572  if (!reductionVars.empty()) {
573  if (!reductions || reductions->size() != reductionVars.size())
574  return op->emitOpError()
575  << "expected as many reduction symbol references "
576  "as reduction variables";
577  } else {
578  if (reductions)
579  return op->emitOpError() << "unexpected reduction symbol references";
580  return success();
581  }
582 
583  // TODO: The following checks should be done in
584  // SymbolUserOpInterface::verifySymbolUses.
585  DenseSet<Value> accumulators;
586  for (auto args : llvm::zip(reductionVars, *reductions)) {
587  Value accum = std::get<0>(args);
588 
589  if (!accumulators.insert(accum).second)
590  return op->emitOpError() << "accumulator variable used more than once";
591 
592  Type varType = accum.getType();
593  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
594  auto decl =
595  SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
596  if (!decl)
597  return op->emitOpError() << "expected symbol reference " << symbolRef
598  << " to point to a reduction declaration";
599 
600  if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
601  return op->emitOpError()
602  << "expected accumulator (" << varType
603  << ") to be the same type as reduction declaration ("
604  << decl.getAccumulatorType() << ")";
605  }
606 
607  return success();
608 }
609 
610 //===----------------------------------------------------------------------===//
611 // Parser, printer and verifier for CopyPrivateVarList
612 //===----------------------------------------------------------------------===//
613 
614 /// copyprivate-entry-list ::= copyprivate-entry
615 /// | copyprivate-entry-list `,` copyprivate-entry
616 /// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
///
/// Operands and types are appended to `operands`/`types`; the copy-function
/// symbol references are collected into the `copyPrivateSymbols` ArrayAttr.
618  OpAsmParser &parser,
620  SmallVectorImpl<Type> &types, ArrayAttr &copyPrivateSymbols) {
621  SmallVector<SymbolRefAttr> copyPrivateFuncsVec;
622  if (failed(parser.parseCommaSeparatedList([&]() {
623  if (parser.parseOperand(operands.emplace_back()) ||
624  parser.parseArrow() ||
625  parser.parseAttribute(copyPrivateFuncsVec.emplace_back()) ||
626  parser.parseColonType(types.emplace_back()))
627  return failure();
628  return success();
629  })))
630  return failure();
631  SmallVector<Attribute> copyPrivateFuncs(copyPrivateFuncsVec.begin(),
632  copyPrivateFuncsVec.end());
633  copyPrivateSymbols = ArrayAttr::get(parser.getContext(), copyPrivateFuncs);
634  return success();
635 }
636 
637 /// Print CopyPrivate clause as `var -> func : type` entries separated by
/// ", ". Prints nothing when `copyPrivateFuncs` holds no value. llvm::zip
/// stops at the shortest of the three ranges.
639  OperandRange copyPrivateVars,
640  TypeRange copyPrivateTypes,
641  std::optional<ArrayAttr> copyPrivateFuncs) {
642  if (!copyPrivateFuncs.has_value())
643  return;
644  llvm::interleaveComma(
645  llvm::zip(copyPrivateVars, *copyPrivateFuncs, copyPrivateTypes), p,
646  [&](const auto &args) {
647  p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
648  << std::get<2>(args);
649  });
650 }
651 
652 /// Verifies CopyPrivate Clause:
///  - the number of copy functions must equal the number of copyprivate
///    variables (an absent attribute counts as zero);
///  - each symbol must resolve to a `func::FuncOp` or `LLVM::LLVMFuncOp`;
///  - that function must take exactly two arguments, both of the same type as
///    the corresponding copyprivate variable.
653 static LogicalResult
655  std::optional<ArrayAttr> copyPrivateFuncs) {
656  size_t copyPrivateFuncsSize =
657  copyPrivateFuncs.has_value() ? copyPrivateFuncs->size() : 0;
658  if (copyPrivateFuncsSize != copyPrivateVars.size())
659  return op->emitOpError() << "inconsistent number of copyPrivate vars (= "
660  << copyPrivateVars.size()
661  << ") and functions (= " << copyPrivateFuncsSize
662  << "), both must be equal";
663  if (!copyPrivateFuncs.has_value())
664  return success();
665 
666  for (auto copyPrivateVarAndFunc :
667  llvm::zip(copyPrivateVars, *copyPrivateFuncs)) {
668  auto symbolRef =
669  llvm::cast<SymbolRefAttr>(std::get<1>(copyPrivateVarAndFunc));
// Resolve the symbol as a func.func first, then as an llvm.func.
670  std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
671  funcOp;
672  if (mlir::func::FuncOp mlirFuncOp =
673  SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
674  symbolRef))
675  funcOp = mlirFuncOp;
676  else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
677  SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
678  op, symbolRef))
679  funcOp = llvmFuncOp;
680 
// These helpers dereference `funcOp`; they are only called after the null
// check below.
681  auto getNumArguments = [&] {
682  return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
683  };
684 
685  auto getArgumentType = [&](unsigned i) {
686  return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
687  *funcOp);
688  };
689 
690  if (!funcOp)
691  return op->emitOpError() << "expected symbol reference " << symbolRef
692  << " to point to a copy function";
693 
694  if (getNumArguments() != 2)
695  return op->emitOpError()
696  << "expected copy function " << symbolRef << " to have 2 operands";
697 
698  Type argTy = getArgumentType(0);
699  if (argTy != getArgumentType(1))
700  return op->emitOpError() << "expected copy function " << symbolRef
701  << " arguments to have the same type";
702 
703  Type varType = std::get<0>(copyPrivateVarAndFunc).getType();
704  if (argTy != varType)
705  return op->emitOpError()
706  << "expected copy function arguments' type (" << argTy
707  << ") to be the same as copyprivate variable's type (" << varType
708  << ")";
709  }
710 
711  return success();
712 }
713 
714 //===----------------------------------------------------------------------===//
715 // Parser, printer and verifier for DependVarList
716 //===----------------------------------------------------------------------===//
717 
718 /// depend-entry-list ::= depend-entry
719 /// | depend-entry-list `,` depend-entry
720 /// depend-entry ::= depend-kind `->` ssa-id `:` type
///
/// depend-kind is a ClauseTaskDepend keyword. Operands and types are appended
/// to `operands`/`types`; the kinds are collected into `dependsArray` as
/// ClauseTaskDependAttr values.
721 static ParseResult
724  SmallVectorImpl<Type> &types, ArrayAttr &dependsArray) {
726  if (failed(parser.parseCommaSeparatedList([&]() {
727  StringRef keyword;
728  if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
729  parser.parseOperand(operands.emplace_back()) ||
730  parser.parseColonType(types.emplace_back()))
731  return failure();
// NOTE(review): an unrecognized depend kind makes the parse fail without
// emitting a diagnostic.
732  if (std::optional<ClauseTaskDepend> keywordDepend =
733  (symbolizeClauseTaskDepend(keyword)))
734  dependVec.emplace_back(
735  ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
736  else
737  return failure();
738  return success();
739  })))
740  return failure();
741  SmallVector<Attribute> depends(dependVec.begin(), dependVec.end());
742  dependsArray = ArrayAttr::get(parser.getContext(), depends);
743  return success();
744 }
745 
746 /// Print Depend clause as `kind -> var : type` entries separated by ", ".
/// `depends` is dereferenced unconditionally, so it must hold a value when
/// this is called.
748  OperandRange dependVars, TypeRange dependTypes,
749  std::optional<ArrayAttr> depends) {
750 
751  for (unsigned i = 0, e = depends->size(); i < e; ++i) {
752  if (i != 0)
753  p << ", ";
754  p << stringifyClauseTaskDepend(
755  llvm::cast<mlir::omp::ClauseTaskDependAttr>((*depends)[i])
756  .getValue())
757  << " -> " << dependVars[i] << " : " << dependTypes[i];
758  }
759 }
760 
761 /// Verifies Depend clause: the number of depend kinds must equal the number
/// of depend variables. With no variables, the kinds array may be absent or
/// empty.
763  std::optional<ArrayAttr> depends,
764  OperandRange dependVars) {
765  if (!dependVars.empty()) {
766  if (!depends || depends->size() != dependVars.size())
767  return op->emitOpError() << "expected as many depend values"
768  " as depend variables";
769  } else {
770  if (depends && !depends->empty())
771  return op->emitOpError() << "unexpected depend values";
772  return success();
773  }
774 
775  return success();
776 }
777 
778 //===----------------------------------------------------------------------===//
779 // Parser, printer and verifier for Synchronization Hint (2.17.12)
780 //===----------------------------------------------------------------------===//
781 
782 /// Parses a Synchronization Hint clause. The value of hint is an integer
783 /// which is a combination of different hints from `omp_sync_hint_t`.
784 ///
785 /// hint-clause = `hint` `(` hint-value `)`
/// hint-value = `none` | hint-kind (`,` hint-kind)*
/// hint-kind = `uncontended` (1) | `contended` (2) | `nonspeculative` (4) |
///             `speculative` (8)
///
/// `none` yields 0. Otherwise the bits of all listed kinds are OR-ed together
/// into an i64 IntegerAttr. Conflicting combinations are not rejected here
/// (see verifySynchronizationHint).
787  IntegerAttr &hintAttr) {
788  StringRef hintKeyword;
789  int64_t hint = 0;
790  if (succeeded(parser.parseOptionalKeyword("none"))) {
791  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
792  return success();
793  }
794  auto parseKeyword = [&]() -> ParseResult {
795  if (failed(parser.parseKeyword(&hintKeyword)))
796  return failure();
797  if (hintKeyword == "uncontended")
798  hint |= 1;
799  else if (hintKeyword == "contended")
800  hint |= 2;
801  else if (hintKeyword == "nonspeculative")
802  hint |= 4;
803  else if (hintKeyword == "speculative")
804  hint |= 8;
805  else
806  return parser.emitError(parser.getCurrentLocation())
807  << hintKeyword << " is not a valid hint";
808  return success();
809  };
810  if (parser.parseCommaSeparatedList(parseKeyword))
811  return failure();
812  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
813  return success();
814 }
815 
816 /// Prints a Synchronization Hint clause: `none` for 0, otherwise a
/// comma-separated list of the hint kinds whose bits (0-3) are set, in the
/// order uncontended, contended, nonspeculative, speculative.
// NOTE(review): a non-zero hint with none of bits 0-3 set prints an empty
// list. `bitn` also takes an `int`, so the int64_t hint is narrowed.
818  IntegerAttr hintAttr) {
819  int64_t hint = hintAttr.getInt();
820 
821  if (hint == 0) {
822  p << "none";
823  return;
824  }
825 
826  // Helper function to get n-th bit from the right end of `value`
827  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
828 
829  bool uncontended = bitn(hint, 0);
830  bool contended = bitn(hint, 1);
831  bool nonspeculative = bitn(hint, 2);
832  bool speculative = bitn(hint, 3);
833 
835  if (uncontended)
836  hints.push_back("uncontended");
837  if (contended)
838  hints.push_back("contended");
839  if (nonspeculative)
840  hints.push_back("nonspeculative");
841  if (speculative)
842  hints.push_back("speculative");
843 
844  llvm::interleaveComma(hints, p);
845 }
846 
847 /// Verifies a synchronization hint clause: uncontended (bit 0) and contended
/// (bit 1) are mutually exclusive, as are nonspeculative (bit 2) and
/// speculative (bit 3).
849 
850  // Helper function to get n-th bit from the right end of `value`
851  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
852 
853  bool uncontended = bitn(hint, 0);
854  bool contended = bitn(hint, 1);
855  bool nonspeculative = bitn(hint, 2);
856  bool speculative = bitn(hint, 3);
857 
858  if (uncontended && contended)
859  return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
860  "omp_sync_hint_contended cannot be combined";
// NOTE(review): this message ends with a period and the one above does not.
// Left as-is because tests may match the exact text.
861  if (nonspeculative && speculative)
862  return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
863  "omp_sync_hint_speculative cannot be combined.";
864  return success();
865 }
866 
867 //===----------------------------------------------------------------------===//
868 // Parser, printer and verifier for Target
869 //===----------------------------------------------------------------------===//
870 
871 // Helper function to get bitwise AND of `value` and 'flag'
872 uint64_t mapTypeToBitFlag(uint64_t value,
873  llvm::omp::OpenMPOffloadMappingFlags flag) {
874  return value & llvm::to_underlying(flag);
875 }
876 
877 /// Parses a map_entries map type from a string format back into its numeric
878 /// value.
879 ///
880 /// map-clause = `map_clauses ( ( `(` `always, `? `close, `? `present, `? (
881 /// `to` | `from` | `delete` `)` )+ `)` )
882 static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
883  llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
884  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
885 
886  // This simply verifies the correct keyword is read in, the
887  // keyword itself is stored inside of the operation
888  auto parseTypeAndMod = [&]() -> ParseResult {
889  StringRef mapTypeMod;
890  if (parser.parseKeyword(&mapTypeMod))
891  return failure();
892 
893  if (mapTypeMod == "always")
894  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
895 
896  if (mapTypeMod == "implicit")
897  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
898 
899  if (mapTypeMod == "close")
900  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
901 
902  if (mapTypeMod == "present")
903  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
904 
905  if (mapTypeMod == "to")
906  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
907 
908  if (mapTypeMod == "from")
909  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
910 
911  if (mapTypeMod == "tofrom")
912  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
913  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
914 
915  if (mapTypeMod == "delete")
916  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
917 
918  return success();
919  };
920 
921  if (parser.parseCommaSeparatedList(parseTypeAndMod))
922  return failure();
923 
924  mapType = parser.getBuilder().getIntegerAttr(
925  parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
926  llvm::to_underlying(mapTypeBits));
927 
928  return success();
929 }
930 
931 /// Prints a map_entries map type from its numeric value out into its string
932 /// format.
934  IntegerAttr mapType) {
935  uint64_t mapTypeBits = mapType.getUInt();
936 
937  bool emitAllocRelease = true;
939 
940  // handling of always, close, present placed at the beginning of the string
941  // to aid readability
942  if (mapTypeToBitFlag(mapTypeBits,
943  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
944  mapTypeStrs.push_back("always");
945  if (mapTypeToBitFlag(mapTypeBits,
946  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
947  mapTypeStrs.push_back("implicit");
948  if (mapTypeToBitFlag(mapTypeBits,
949  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
950  mapTypeStrs.push_back("close");
951  if (mapTypeToBitFlag(mapTypeBits,
952  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
953  mapTypeStrs.push_back("present");
954 
955  // special handling of to/from/tofrom/delete and release/alloc, release +
956  // alloc are the abscense of one of the other flags, whereas tofrom requires
957  // both the to and from flag to be set.
958  bool to = mapTypeToBitFlag(mapTypeBits,
959  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
960  bool from = mapTypeToBitFlag(
961  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
962  if (to && from) {
963  emitAllocRelease = false;
964  mapTypeStrs.push_back("tofrom");
965  } else if (from) {
966  emitAllocRelease = false;
967  mapTypeStrs.push_back("from");
968  } else if (to) {
969  emitAllocRelease = false;
970  mapTypeStrs.push_back("to");
971  }
972  if (mapTypeToBitFlag(mapTypeBits,
973  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
974  emitAllocRelease = false;
975  mapTypeStrs.push_back("delete");
976  }
977  if (emitAllocRelease)
978  mapTypeStrs.push_back("exit_release_or_enter_alloc");
979 
980  for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
981  p << mapTypeStrs[i];
982  if (i + 1 < mapTypeStrs.size()) {
983  p << ", ";
984  }
985  }
986 }
987 
988 static ParseResult
991  SmallVectorImpl<Type> &mapOperandTypes) {
994  Type argType;
995  auto parseEntries = [&]() -> ParseResult {
996  if (parser.parseOperand(arg) || parser.parseArrow() ||
997  parser.parseOperand(blockArg))
998  return failure();
999  mapOperands.push_back(arg);
1000  return success();
1001  };
1002 
1003  auto parseTypes = [&]() -> ParseResult {
1004  if (parser.parseType(argType))
1005  return failure();
1006  mapOperandTypes.push_back(argType);
1007  return success();
1008  };
1009 
1010  if (parser.parseCommaSeparatedList(parseEntries))
1011  return failure();
1012 
1013  if (parser.parseColon())
1014  return failure();
1015 
1016  if (parser.parseCommaSeparatedList(parseTypes))
1017  return failure();
1018 
1019  return success();
1020 }
1021 
1023  OperandRange mapOperands,
1024  TypeRange mapOperandTypes) {
1025  auto &region = op->getRegion(0);
1026  unsigned argIndex = 0;
1027 
1028  for (const auto &mapOp : mapOperands) {
1029  const auto &blockArg = region.front().getArgument(argIndex);
1030  p << mapOp << " -> " << blockArg;
1031  argIndex++;
1032  if (argIndex < mapOperands.size())
1033  p << ", ";
1034  }
1035  p << " : ";
1036 
1037  argIndex = 0;
1038  for (const auto &mapType : mapOperandTypes) {
1039  p << mapType;
1040  argIndex++;
1041  if (argIndex < mapOperands.size())
1042  p << ", ";
1043  }
1044 }
1045 
1047  VariableCaptureKindAttr mapCaptureType) {
1048  std::string typeCapStr;
1049  llvm::raw_string_ostream typeCap(typeCapStr);
1050  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1051  typeCap << "ByRef";
1052  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1053  typeCap << "ByCopy";
1054  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1055  typeCap << "VLAType";
1056  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1057  typeCap << "This";
1058  p << typeCap.str();
1059 }
1060 
1062  VariableCaptureKindAttr &mapCapture) {
1063  StringRef mapCaptureKey;
1064  if (parser.parseKeyword(&mapCaptureKey))
1065  return failure();
1066 
1067  if (mapCaptureKey == "This")
1069  parser.getContext(), mlir::omp::VariableCaptureKind::This);
1070  if (mapCaptureKey == "ByRef")
1072  parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1073  if (mapCaptureKey == "ByCopy")
1075  parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1076  if (mapCaptureKey == "VLAType")
1078  parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1079 
1080  return success();
1081 }
1082 
1086 
1087  for (auto mapOp : mapOperands) {
1088  if (!mapOp.getDefiningOp())
1089  emitError(op->getLoc(), "missing map operation");
1090 
1091  if (auto mapInfoOp =
1092  mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
1093  if (!mapInfoOp.getMapType().has_value())
1094  emitError(op->getLoc(), "missing map type for map operand");
1095 
1096  if (!mapInfoOp.getMapCaptureType().has_value())
1097  emitError(op->getLoc(), "missing map capture type for map operand");
1098 
1099  uint64_t mapTypeBits = mapInfoOp.getMapType().value();
1100 
1101  bool to = mapTypeToBitFlag(
1102  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1103  bool from = mapTypeToBitFlag(
1104  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1105  bool del = mapTypeToBitFlag(
1106  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1107 
1108  bool always = mapTypeToBitFlag(
1109  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1110  bool close = mapTypeToBitFlag(
1111  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1112  bool implicit = mapTypeToBitFlag(
1113  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1114 
1115  if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1116  return emitError(op->getLoc(),
1117  "to, from, tofrom and alloc map types are permitted");
1118 
1119  if (isa<TargetEnterDataOp>(op) && (from || del))
1120  return emitError(op->getLoc(), "to and alloc map types are permitted");
1121 
1122  if (isa<TargetExitDataOp>(op) && to)
1123  return emitError(op->getLoc(),
1124  "from, release and delete map types are permitted");
1125 
1126  if (isa<TargetUpdateOp>(op)) {
1127  if (del) {
1128  return emitError(op->getLoc(),
1129  "at least one of to or from map types must be "
1130  "specified, other map types are not permitted");
1131  }
1132 
1133  if (!to && !from) {
1134  return emitError(op->getLoc(),
1135  "at least one of to or from map types must be "
1136  "specified, other map types are not permitted");
1137  }
1138 
1139  auto updateVar = mapInfoOp.getVarPtr();
1140 
1141  if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1142  (from && updateToVars.contains(updateVar))) {
1143  return emitError(
1144  op->getLoc(),
1145  "either to or from map types can be specified, not both");
1146  }
1147 
1148  if (always || close || implicit) {
1149  return emitError(
1150  op->getLoc(),
1151  "present, mapper and iterator map type modifiers are permitted");
1152  }
1153 
1154  to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1155  }
1156  } else {
1157  emitError(op->getLoc(), "map argument is not a map entry operation");
1158  }
1159  }
1160 
1161  return success();
1162 }
1163 
1165  if (getMapOperands().empty() && getUseDevicePtr().empty() &&
1166  getUseDeviceAddr().empty()) {
1167  return ::emitError(this->getLoc(), "At least one of map, useDevicePtr, or "
1168  "useDeviceAddr operand must be present");
1169  }
1170  return verifyMapClause(*this, getMapOperands());
1171 }
1172 
1174  LogicalResult verifyDependVars =
1175  verifyDependVarList(*this, getDepends(), getDependVars());
1176  return failed(verifyDependVars) ? verifyDependVars
1177  : verifyMapClause(*this, getMapOperands());
1178 }
1179 
1181  LogicalResult verifyDependVars =
1182  verifyDependVarList(*this, getDepends(), getDependVars());
1183  return failed(verifyDependVars) ? verifyDependVars
1184  : verifyMapClause(*this, getMapOperands());
1185 }
1186 
1188  LogicalResult verifyDependVars =
1189  verifyDependVarList(*this, getDepends(), getDependVars());
1190  return failed(verifyDependVars) ? verifyDependVars
1191  : verifyMapClause(*this, getMapOperands());
1192 }
1193 
1195  LogicalResult verifyDependVars =
1196  verifyDependVarList(*this, getDepends(), getDependVars());
1197  return failed(verifyDependVars) ? verifyDependVars
1198  : verifyMapClause(*this, getMapOperands());
1199 }
1200 
1201 //===----------------------------------------------------------------------===//
1202 // ParallelOp
1203 //===----------------------------------------------------------------------===//
1204 
1205 void ParallelOp::build(OpBuilder &builder, OperationState &state,
1206  ArrayRef<NamedAttribute> attributes) {
1207  ParallelOp::build(
1208  builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
1209  /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
1210  /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
1211  /*proc_bind_val=*/nullptr, /*private_vars=*/ValueRange(),
1212  /*privatizers=*/nullptr, /*byref=*/false);
1213  state.addAttributes(attributes);
1214 }
1215 
1216 template <typename OpType>
1218  auto privateVars = op.getPrivateVars();
1219  auto privatizers = op.getPrivatizersAttr();
1220 
1221  if (privateVars.empty() && (privatizers == nullptr || privatizers.empty()))
1222  return success();
1223 
1224  auto numPrivateVars = privateVars.size();
1225  auto numPrivatizers = (privatizers == nullptr) ? 0 : privatizers.size();
1226 
1227  if (numPrivateVars != numPrivatizers)
1228  return op.emitError() << "inconsistent number of private variables and "
1229  "privatizer op symbols, private vars: "
1230  << numPrivateVars
1231  << " vs. privatizer op symbols: " << numPrivatizers;
1232 
1233  for (auto privateVarInfo : llvm::zip_equal(privateVars, privatizers)) {
1234  Type varType = std::get<0>(privateVarInfo).getType();
1235  SymbolRefAttr privatizerSym =
1236  std::get<1>(privateVarInfo).template cast<SymbolRefAttr>();
1237  PrivateClauseOp privatizerOp =
1238  SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
1239  privatizerSym);
1240 
1241  if (privatizerOp == nullptr)
1242  return op.emitError() << "failed to lookup privatizer op with symbol: '"
1243  << privatizerSym << "'";
1244 
1245  Type privatizerType = privatizerOp.getType();
1246 
1247  if (varType != privatizerType)
1248  return op.emitError()
1249  << "type mismatch between a "
1250  << (privatizerOp.getDataSharingType() ==
1251  DataSharingClauseType::Private
1252  ? "private"
1253  : "firstprivate")
1254  << " variable and its privatizer op, var type: " << varType
1255  << " vs. privatizer op type: " << privatizerType;
1256  }
1257 
1258  return success();
1259 }
1260 
1262  if (getAllocateVars().size() != getAllocatorsVars().size())
1263  return emitError(
1264  "expected equal sizes for allocate and allocator variables");
1265 
1266  if (failed(verifyPrivateVarList(*this)))
1267  return failure();
1268 
1269  return verifyReductionVarList(*this, getReductions(), getReductionVars());
1270 }
1271 
1272 //===----------------------------------------------------------------------===//
1273 // TeamsOp
1274 //===----------------------------------------------------------------------===//
1275 
1277  while ((op = op->getParentOp()))
1278  if (isa<OpenMPDialect>(op->getDialect()))
1279  return false;
1280  return true;
1281 }
1282 
1284  // Check parent region
1285  // TODO If nested inside of a target region, also check that it does not
1286  // contain any statements, declarations or directives other than this
1287  // omp.teams construct. The issue is how to support the initialization of
1288  // this operation's own arguments (allow SSA values across omp.target?).
1289  Operation *op = getOperation();
1290  if (!isa<TargetOp>(op->getParentOp()) &&
1292  return emitError("expected to be nested inside of omp.target or not nested "
1293  "in any OpenMP dialect operations");
1294 
1295  // Check for num_teams clause restrictions
1296  if (auto numTeamsLowerBound = getNumTeamsLower()) {
1297  auto numTeamsUpperBound = getNumTeamsUpper();
1298  if (!numTeamsUpperBound)
1299  return emitError("expected num_teams upper bound to be defined if the "
1300  "lower bound is defined");
1301  if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
1302  return emitError(
1303  "expected num_teams upper bound and lower bound to be the same type");
1304  }
1305 
1306  // Check for allocate clause restrictions
1307  if (getAllocateVars().size() != getAllocatorsVars().size())
1308  return emitError(
1309  "expected equal sizes for allocate and allocator variables");
1310 
1311  return verifyReductionVarList(*this, getReductions(), getReductionVars());
1312 }
1313 
1314 //===----------------------------------------------------------------------===//
1315 // Verifier for SectionsOp
1316 //===----------------------------------------------------------------------===//
1317 
1319  if (getAllocateVars().size() != getAllocatorsVars().size())
1320  return emitError(
1321  "expected equal sizes for allocate and allocator variables");
1322 
1323  return verifyReductionVarList(*this, getReductions(), getReductionVars());
1324 }
1325 
1326 LogicalResult SectionsOp::verifyRegions() {
1327  for (auto &inst : *getRegion().begin()) {
1328  if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
1329  return emitOpError()
1330  << "expected omp.section op or terminator op inside region";
1331  }
1332  }
1333 
1334  return success();
1335 }
1336 
1338  // Check for allocate clause restrictions
1339  if (getAllocateVars().size() != getAllocatorsVars().size())
1340  return emitError(
1341  "expected equal sizes for allocate and allocator variables");
1342 
1343  return verifyCopyPrivateVarList(*this, getCopyprivateVars(),
1344  getCopyprivateFuncs());
1345 }
1346 
1347 //===----------------------------------------------------------------------===//
1348 // WsloopOp
1349 //===----------------------------------------------------------------------===//
1350 
1351 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds
1352 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
1353 /// steps := `step` `(`ssa-id-list`)`
1355 parseWsloop(OpAsmParser &parser, Region &region,
1359  SmallVectorImpl<Type> &loopVarTypes,
1361  SmallVectorImpl<Type> &reductionTypes, ArrayAttr &reductionSymbols,
1362  UnitAttr &inclusive) {
1363 
1364  // Parse an optional reduction clause
1366  bool hasReduction = succeeded(parser.parseOptionalKeyword("reduction")) &&
1368  parser, region, reductionOperands, reductionTypes,
1369  reductionSymbols, privates));
1370 
1371  if (parser.parseKeyword("for"))
1372  return failure();
1373 
1374  // Parse an opening `(` followed by induction variables followed by `)`
1376  Type loopVarType;
1378  parser.parseColonType(loopVarType) ||
1379  // Parse loop bounds.
1380  parser.parseEqual() ||
1381  parser.parseOperandList(lowerBound, ivs.size(),
1383  parser.parseKeyword("to") ||
1384  parser.parseOperandList(upperBound, ivs.size(),
1386  return failure();
1387 
1388  if (succeeded(parser.parseOptionalKeyword("inclusive")))
1389  inclusive = UnitAttr::get(parser.getBuilder().getContext());
1390 
1391  // Parse step values.
1392  if (parser.parseKeyword("step") ||
1393  parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
1394  return failure();
1395 
1396  // Now parse the body.
1397  loopVarTypes = SmallVector<Type>(ivs.size(), loopVarType);
1398  for (auto &iv : ivs)
1399  iv.type = loopVarType;
1400 
1401  SmallVector<OpAsmParser::Argument> regionArgs{ivs};
1402  if (hasReduction)
1403  llvm::copy(privates, std::back_inserter(regionArgs));
1404 
1405  return parser.parseRegion(region, regionArgs);
1406 }
1407 
1409  ValueRange lowerBound, ValueRange upperBound, ValueRange steps,
1410  TypeRange loopVarTypes, ValueRange reductionOperands,
1411  TypeRange reductionTypes, ArrayAttr reductionSymbols,
1412  UnitAttr inclusive) {
1413  if (reductionSymbols) {
1414  auto reductionArgs =
1415  region.front().getArguments().drop_front(loopVarTypes.size());
1416  printClauseWithRegionArgs(p, op, reductionArgs, "reduction",
1417  reductionOperands, reductionTypes,
1418  reductionSymbols);
1419  }
1420 
1421  p << " for ";
1422  auto args = region.front().getArguments().drop_back(reductionOperands.size());
1423  p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound
1424  << ") to (" << upperBound << ") ";
1425  if (inclusive)
1426  p << "inclusive ";
1427  p << "step (" << steps << ") ";
1428  p.printRegion(region, /*printEntryBlockArgs=*/false);
1429 }
1430 
1431 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds
1432 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
1433 /// steps := `step` `(`ssa-id-list`)`
1439  SmallVectorImpl<Type> &loopVarTypes, UnitAttr &inclusive) {
1440  // Parse an opening `(` followed by induction variables followed by `)`
1442  Type loopVarType;
1444  parser.parseColonType(loopVarType) ||
1445  // Parse loop bounds.
1446  parser.parseEqual() ||
1447  parser.parseOperandList(lowerBound, ivs.size(),
1449  parser.parseKeyword("to") ||
1450  parser.parseOperandList(upperBound, ivs.size(),
1452  return failure();
1453 
1454  if (succeeded(parser.parseOptionalKeyword("inclusive")))
1455  inclusive = UnitAttr::get(parser.getBuilder().getContext());
1456 
1457  // Parse step values.
1458  if (parser.parseKeyword("step") ||
1459  parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
1460  return failure();
1461 
1462  // Now parse the body.
1463  loopVarTypes = SmallVector<Type>(ivs.size(), loopVarType);
1464  for (auto &iv : ivs)
1465  iv.type = loopVarType;
1466 
1467  return parser.parseRegion(region, ivs);
1468 }
1469 
1471  ValueRange lowerBound, ValueRange upperBound,
1472  ValueRange steps, TypeRange loopVarTypes,
1473  UnitAttr inclusive) {
1474  auto args = region.front().getArguments();
1475  p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound
1476  << ") to (" << upperBound << ") ";
1477  if (inclusive)
1478  p << "inclusive ";
1479  p << "step (" << steps << ") ";
1480  p.printRegion(region, /*printEntryBlockArgs=*/false);
1481 }
1482 
1483 //===----------------------------------------------------------------------===//
1484 // Verifier for Simd construct [2.9.3.1]
1485 //===----------------------------------------------------------------------===//
1486 
1488  if (this->getLowerBound().empty()) {
1489  return emitOpError() << "empty lowerbound for simd loop operation";
1490  }
1491  if (this->getSimdlen().has_value() && this->getSafelen().has_value() &&
1492  this->getSimdlen().value() > this->getSafelen().value()) {
1493  return emitOpError()
1494  << "simdlen clause and safelen clause are both present, but the "
1495  "simdlen value is not less than or equal to safelen value";
1496  }
1497  if (verifyAlignedClause(*this, this->getAlignmentValues(),
1498  this->getAlignedVars())
1499  .failed())
1500  return failure();
1501  if (verifyNontemporalClause(*this, this->getNontemporalVars()).failed())
1502  return failure();
1503  return success();
1504 }
1505 
1506 //===----------------------------------------------------------------------===//
1507 // Verifier for Distribute construct [2.9.4.1]
1508 //===----------------------------------------------------------------------===//
1509 
1511  if (this->getChunkSize() && !this->getDistScheduleStatic())
1512  return emitOpError() << "chunk size set without "
1513  "dist_schedule_static being present";
1514 
1515  if (getAllocateVars().size() != getAllocatorsVars().size())
1516  return emitError(
1517  "expected equal sizes for allocate and allocator variables");
1518 
1519  return success();
1520 }
1521 
1522 //===----------------------------------------------------------------------===//
1523 // ReductionOp
1524 //===----------------------------------------------------------------------===//
1525 
1527  Region &region) {
1528  if (parser.parseOptionalKeyword("atomic"))
1529  return success();
1530  return parser.parseRegion(region);
1531 }
1532 
1534  DeclareReductionOp op, Region &region) {
1535  if (region.empty())
1536  return;
1537  printer << "atomic ";
1538  printer.printRegion(region);
1539 }
1540 
1541 LogicalResult DeclareReductionOp::verifyRegions() {
1542  if (getInitializerRegion().empty())
1543  return emitOpError() << "expects non-empty initializer region";
1544  Block &initializerEntryBlock = getInitializerRegion().front();
1545  if (initializerEntryBlock.getNumArguments() != 1 ||
1546  initializerEntryBlock.getArgument(0).getType() != getType()) {
1547  return emitOpError() << "expects initializer region with one argument "
1548  "of the reduction type";
1549  }
1550 
1551  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
1552  if (yieldOp.getResults().size() != 1 ||
1553  yieldOp.getResults().getTypes()[0] != getType())
1554  return emitOpError() << "expects initializer region to yield a value "
1555  "of the reduction type";
1556  }
1557 
1558  if (getReductionRegion().empty())
1559  return emitOpError() << "expects non-empty reduction region";
1560  Block &reductionEntryBlock = getReductionRegion().front();
1561  if (reductionEntryBlock.getNumArguments() != 2 ||
1562  reductionEntryBlock.getArgumentTypes()[0] !=
1563  reductionEntryBlock.getArgumentTypes()[1] ||
1564  reductionEntryBlock.getArgumentTypes()[0] != getType())
1565  return emitOpError() << "expects reduction region with two arguments of "
1566  "the reduction type";
1567  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
1568  if (yieldOp.getResults().size() != 1 ||
1569  yieldOp.getResults().getTypes()[0] != getType())
1570  return emitOpError() << "expects reduction region to yield a value "
1571  "of the reduction type";
1572  }
1573 
1574  if (getAtomicReductionRegion().empty())
1575  return success();
1576 
1577  Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
1578  if (atomicReductionEntryBlock.getNumArguments() != 2 ||
1579  atomicReductionEntryBlock.getArgumentTypes()[0] !=
1580  atomicReductionEntryBlock.getArgumentTypes()[1])
1581  return emitOpError() << "expects atomic reduction region with two "
1582  "arguments of the same type";
1583  auto ptrType = llvm::dyn_cast<PointerLikeType>(
1584  atomicReductionEntryBlock.getArgumentTypes()[0]);
1585  if (!ptrType ||
1586  (ptrType.getElementType() && ptrType.getElementType() != getType()))
1587  return emitOpError() << "expects atomic reduction region arguments to "
1588  "be accumulators containing the reduction type";
1589  return success();
1590 }
1591 
1593  auto *op = (*this)->getParentWithTrait<ReductionClauseInterface::Trait>();
1594  if (!op)
1595  return emitOpError() << "must be used within an operation supporting "
1596  "reduction clause interface";
1597  while (op) {
1598  for (const auto &var :
1599  cast<ReductionClauseInterface>(op).getAllReductionVars())
1600  if (var == getAccumulator())
1601  return success();
1602  op = op->getParentWithTrait<ReductionClauseInterface::Trait>();
1603  }
1604  return emitOpError() << "the accumulator is not used by the parent";
1605 }
1606 
1607 //===----------------------------------------------------------------------===//
1608 // TaskOp
1609 //===----------------------------------------------------------------------===//
1611  LogicalResult verifyDependVars =
1612  verifyDependVarList(*this, getDepends(), getDependVars());
1613  return failed(verifyDependVars)
1614  ? verifyDependVars
1615  : verifyReductionVarList(*this, getInReductions(),
1616  getInReductionVars());
1617 }
1618 
1619 //===----------------------------------------------------------------------===//
1620 // TaskgroupOp
1621 //===----------------------------------------------------------------------===//
1623  return verifyReductionVarList(*this, getTaskReductions(),
1624  getTaskReductionVars());
1625 }
1626 
1627 //===----------------------------------------------------------------------===//
1628 // TaskloopOp
1629 //===----------------------------------------------------------------------===//
1630 SmallVector<Value> TaskloopOp::getAllReductionVars() {
1631  SmallVector<Value> allReductionNvars(getInReductionVars().begin(),
1632  getInReductionVars().end());
1633  allReductionNvars.insert(allReductionNvars.end(), getReductionVars().begin(),
1634  getReductionVars().end());
1635  return allReductionNvars;
1636 }
1637 
1639  if (getAllocateVars().size() != getAllocatorsVars().size())
1640  return emitError(
1641  "expected equal sizes for allocate and allocator variables");
1642  if (failed(
1643  verifyReductionVarList(*this, getReductions(), getReductionVars())) ||
1644  failed(verifyReductionVarList(*this, getInReductions(),
1645  getInReductionVars())))
1646  return failure();
1647 
1648  if (!getReductionVars().empty() && getNogroup())
1649  return emitError("if a reduction clause is present on the taskloop "
1650  "directive, the nogroup clause must not be specified");
1651  for (auto var : getReductionVars()) {
1652  if (llvm::is_contained(getInReductionVars(), var))
1653  return emitError("the same list item cannot appear in both a reduction "
1654  "and an in_reduction clause");
1655  }
1656 
1657  if (getGrainSize() && getNumTasks()) {
1658  return emitError(
1659  "the grainsize clause and num_tasks clause are mutually exclusive and "
1660  "may not appear on the same taskloop directive");
1661  }
1662  return success();
1663 }
1664 
1665 //===----------------------------------------------------------------------===//
1666 // WsloopOp
1667 //===----------------------------------------------------------------------===//
1668 
1669 void WsloopOp::build(OpBuilder &builder, OperationState &state,
1670  ValueRange lowerBound, ValueRange upperBound,
1671  ValueRange step, ArrayRef<NamedAttribute> attributes) {
1672  build(builder, state, lowerBound, upperBound, step,
1673  /*linear_vars=*/ValueRange(),
1674  /*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(),
1675  /*reductions=*/nullptr, /*schedule_val=*/nullptr,
1676  /*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr,
1677  /*simd_modifier=*/false, /*nowait=*/false, /*byref=*/false,
1678  /*ordered_val=*/nullptr,
1679  /*order_val=*/nullptr, /*inclusive=*/false);
1680  state.addAttributes(attributes);
1681 }
1682 
1684  return verifyReductionVarList(*this, getReductions(), getReductionVars());
1685 }
1686 
1687 //===----------------------------------------------------------------------===//
1688 // Verifier for critical construct (2.17.1)
1689 //===----------------------------------------------------------------------===//
1690 
1692  return verifySynchronizationHint(*this, getHintVal());
1693 }
1694 
1695 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1696  if (getNameAttr()) {
1697  SymbolRefAttr symbolRef = getNameAttr();
1698  auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
1699  *this, symbolRef);
1700  if (!decl) {
1701  return emitOpError() << "expected symbol reference " << symbolRef
1702  << " to point to a critical declaration";
1703  }
1704  }
1705 
1706  return success();
1707 }
1708 
1709 //===----------------------------------------------------------------------===//
1710 // Verifier for ordered construct
1711 //===----------------------------------------------------------------------===//
1712 
1714  auto container = (*this)->getParentOfType<WsloopOp>();
1715  if (!container || !container.getOrderedValAttr() ||
1716  container.getOrderedValAttr().getInt() == 0)
1717  return emitOpError() << "ordered depend directive must be closely "
1718  << "nested inside a worksharing-loop with ordered "
1719  << "clause with parameter present";
1720 
1721  if (container.getOrderedValAttr().getInt() != (int64_t)*getNumLoopsVal())
1722  return emitOpError() << "number of variables in depend clause does not "
1723  << "match number of iteration variables in the "
1724  << "doacross loop";
1725 
1726  return success();
1727 }
1728 
1730  // TODO: The code generation for ordered simd directive is not supported yet.
1731  if (getSimd())
1732  return failure();
1733 
1734  if (auto container = (*this)->getParentOfType<WsloopOp>()) {
1735  if (!container.getOrderedValAttr() ||
1736  container.getOrderedValAttr().getInt() != 0)
1737  return emitOpError() << "ordered region must be closely nested inside "
1738  << "a worksharing-loop region with an ordered "
1739  << "clause without parameter present";
1740  }
1741 
1742  return success();
1743 }
1744 
1745 //===----------------------------------------------------------------------===//
1746 // Verifier for AtomicReadOp
1747 //===----------------------------------------------------------------------===//
1748 
1750  if (verifyCommon().failed())
1751  return mlir::failure();
1752 
1753  if (auto mo = getMemoryOrderVal()) {
1754  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
1755  *mo == ClauseMemoryOrderKind::Release) {
1756  return emitError(
1757  "memory-order must not be acq_rel or release for atomic reads");
1758  }
1759  }
1760  return verifySynchronizationHint(*this, getHintVal());
1761 }
1762 
1763 //===----------------------------------------------------------------------===//
1764 // Verifier for AtomicWriteOp
1765 //===----------------------------------------------------------------------===//
1766 
1768  if (verifyCommon().failed())
1769  return mlir::failure();
1770 
1771  if (auto mo = getMemoryOrderVal()) {
1772  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
1773  *mo == ClauseMemoryOrderKind::Acquire) {
1774  return emitError(
1775  "memory-order must not be acq_rel or acquire for atomic writes");
1776  }
1777  }
1778  return verifySynchronizationHint(*this, getHintVal());
1779 }
1780 
1781 //===----------------------------------------------------------------------===//
1782 // Verifier for AtomicUpdateOp
1783 //===----------------------------------------------------------------------===//
1784 
1785 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
1786  PatternRewriter &rewriter) {
1787  if (op.isNoOp()) {
1788  rewriter.eraseOp(op);
1789  return success();
1790  }
1791  if (Value writeVal = op.getWriteOpVal()) {
1792  rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
1793  op.getHintValAttr(),
1794  op.getMemoryOrderValAttr());
1795  return success();
1796  }
1797  return failure();
1798 }
1799 
1801  if (verifyCommon().failed())
1802  return mlir::failure();
1803 
1804  if (auto mo = getMemoryOrderVal()) {
1805  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
1806  *mo == ClauseMemoryOrderKind::Acquire) {
1807  return emitError(
1808  "memory-order must not be acq_rel or acquire for atomic updates");
1809  }
1810  }
1811 
1812  return verifySynchronizationHint(*this, getHintVal());
1813 }
1814 
1815 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
1816 
1817 //===----------------------------------------------------------------------===//
1818 // Verifier for AtomicCaptureOp
1819 //===----------------------------------------------------------------------===//
1820 
1821 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
1822  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
1823  return op;
1824  return dyn_cast<AtomicReadOp>(getSecondOp());
1825 }
1826 
1827 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
1828  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
1829  return op;
1830  return dyn_cast<AtomicWriteOp>(getSecondOp());
1831 }
1832 
1833 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
1834  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
1835  return op;
1836  return dyn_cast<AtomicUpdateOp>(getSecondOp());
1837 }
1838 
// Verifier body for AtomicCaptureOp. The signature line (original line 1839) is
// not present in this listing; presumably `LogicalResult AtomicCaptureOp::verify()`.
// The only op-level check is validating the hint value. Constraints on the
// nested operations are enforced in AtomicCaptureOp::verifyRegions().
1840  return verifySynchronizationHint(*this, getHintVal());
1841 }
1842 
1843 LogicalResult AtomicCaptureOp::verifyRegions() {
1844  if (verifyRegionsCommon().failed())
1845  return mlir::failure();
1846 
1847  if (getFirstOp()->getAttr("hint_val") || getSecondOp()->getAttr("hint_val"))
1848  return emitOpError(
1849  "operations inside capture region must not have hint clause");
1850 
1851  if (getFirstOp()->getAttr("memory_order_val") ||
1852  getSecondOp()->getAttr("memory_order_val"))
1853  return emitOpError(
1854  "operations inside capture region must not have memory_order clause");
1855  return success();
1856 }
1857 
1858 //===----------------------------------------------------------------------===//
1859 // Verifier for CancelOp
1860 //===----------------------------------------------------------------------===//
1861 
// Verifier body for CancelOp. The signature line (original line 1862) is not
// present in this listing; presumably `LogicalResult CancelOp::verify()`.
// It checks that the immediate parent op matches the cancelled construct type:
//   - Parallel -> ParallelOp;
//   - Loop -> WsloopOp, which must have neither nowait nor ordered;
//   - Sections -> SectionsOp or SectionOp, and an enclosing SectionsOp
//     grandparent must not have nowait.
// Any other construct type is accepted as-is (see the taskgroup TODO).
1863  ClauseCancellationConstructType cct = getCancellationConstructTypeVal();
1864  Operation *parentOp = (*this)->getParentOp();
1865 
1866  if (!parentOp) {
1867  return emitOpError() << "must be used within a region supporting "
1868  "cancel directive";
1869  }
1870 
1871  if ((cct == ClauseCancellationConstructType::Parallel) &&
1872  !isa<ParallelOp>(parentOp)) {
1873  return emitOpError() << "cancel parallel must appear "
1874  << "inside a parallel region";
1875  }
1876  if (cct == ClauseCancellationConstructType::Loop) {
1877  if (!isa<WsloopOp>(parentOp)) {
1878  return emitOpError() << "cancel loop must appear "
1879  << "inside a worksharing-loop region";
1880  }
// NOTE(review): the next two diagnostics use emitError() while the rest of
// this verifier uses emitOpError(), so they lack the op-name prefix. Confirm
// whether that is intentional.
1881  if (cast<WsloopOp>(parentOp).getNowaitAttr()) {
1882  return emitError() << "A worksharing construct that is canceled "
1883  << "must not have a nowait clause";
1884  }
1885  if (cast<WsloopOp>(parentOp).getOrderedValAttr()) {
1886  return emitError() << "A worksharing construct that is canceled "
1887  << "must not have an ordered clause";
1888  }
1889 
1890  } else if (cct == ClauseCancellationConstructType::Sections) {
1891  if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
1892  return emitOpError() << "cancel sections must appear "
1893  << "inside a sections region";
1894  }
// Only the grandparent is inspected for nowait. NOTE(review): if the parent is
// itself a SectionsOp, that op's own nowait attribute is not checked here.
// Confirm whether a cancel may appear directly in a SectionsOp region.
1895  if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
1896  cast<SectionsOp>(parentOp->getParentOp()).getNowaitAttr()) {
1897  return emitError() << "A sections construct that is canceled "
1898  << "must not have a nowait clause";
1899  }
1900  }
1901  // TODO : Add more when we support taskgroup.
1902  return success();
1903 }
1904 //===----------------------------------------------------------------------===//
1905 // Verifier for CancellationPointOp
1906 //===----------------------------------------------------------------------===//
1907 
// Verifier body for the cancellation point op. The signature line (original
// line 1908) is not present in this listing; presumably
// `LogicalResult CancellationPointOp::verify()`.
// It checks only that the immediate parent op matches the construct type:
//   - Parallel -> ParallelOp;
//   - Loop -> WsloopOp;
//   - Sections -> SectionsOp or SectionOp.
// Unlike the cancel verifier, it does not inspect nowait/ordered attributes.
1909  ClauseCancellationConstructType cct = getCancellationConstructTypeVal();
1910  Operation *parentOp = (*this)->getParentOp();
1911 
1912  if (!parentOp) {
1913  return emitOpError() << "must be used within a region supporting "
1914  "cancellation point directive";
1915  }
1916 
1917  if ((cct == ClauseCancellationConstructType::Parallel) &&
1918  !(isa<ParallelOp>(parentOp))) {
1919  return emitOpError() << "cancellation point parallel must appear "
1920  << "inside a parallel region";
1921  }
1922  if ((cct == ClauseCancellationConstructType::Loop) &&
1923  !isa<WsloopOp>(parentOp)) {
1924  return emitOpError() << "cancellation point loop must appear "
1925  << "inside a worksharing-loop region";
1926  }
1927  if ((cct == ClauseCancellationConstructType::Sections) &&
1928  !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
1929  return emitOpError() << "cancellation point sections must appear "
1930  << "inside a sections region";
1931  }
1932  // TODO : Add more when we support taskgroup.
1933  return success();
1934 }
1935 
1936 //===----------------------------------------------------------------------===//
1937 // MapBoundsOp
1938 //===----------------------------------------------------------------------===//
1939 
// Verifier body for MapBoundsOp. The signature line (original line 1940) is not
// present in this listing; presumably `LogicalResult MapBoundsOp::verify()`.
// Requires that at least one of the extent and upper-bound operands is present.
// Either one alone is accepted.
1941  auto extent = getExtent();
1942  auto upperbound = getUpperBound();
1943  if (!extent && !upperbound)
1944  return emitError("expected extent or upperbound.");
1945  return success();
1946 }
1947 
// Convenience builder for PrivateClauseOp.
// - It ignores the supplied result types.
// - It forwards the symbol name and type to the other build() overload.
// - The data-sharing type is fixed to DataSharingClauseType::Private.
// NOTE: one argument line (original line 1953, presumably the attribute
// construction wrapping DataSharingClauseType::Private) is missing from this
// listing.
1948 void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1949  TypeRange /*result_types*/, StringAttr symName,
1950  TypeAttr type) {
1951  PrivateClauseOp::build(
1952  odsBuilder, odsState, symName, type,
1954  DataSharingClauseType::Private));
1955 }
1956 
// Verifier body for PrivateClauseOp. The signature line (original line 1957) is
// not present in this listing; presumably `LogicalResult PrivateClauseOp::verify()`.
// Checks performed:
//   - the `alloc` region has exactly 1 argument;
//   - `private` clauses must have an empty `copy` region;
//   - `firstprivate` clauses must have a `copy` region with exactly 2 arguments;
//   - in every checked region, each block with no successors must end in an
//     omp.yield that yields exactly one value of the clause's symbol type
//     (getType()).
1958  Type symType = getType();
1959 
// Validates one block terminator. Terminators of blocks that have successors
// are skipped; only exit blocks must end in omp.yield with a single value of
// `symType`.
1960  auto verifyTerminator = [&](Operation *terminator) -> LogicalResult {
1961  if (!terminator->getBlock()->getSuccessors().empty())
1962  return success();
1963 
1964  if (!llvm::isa<YieldOp>(terminator))
1965  return mlir::emitError(terminator->getLoc())
1966  << "expected exit block terminator to be an `omp.yield` op.";
1967 
1968  YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
1969  TypeRange yieldedTypes = yieldOp.getResults().getTypes();
1970 
1971  if (yieldedTypes.size() == 1 && yieldedTypes.front() == symType)
1972  return success();
1973 
1974  auto error = mlir::emitError(yieldOp.getLoc())
1975  << "Invalid yielded value. Expected type: " << symType
1976  << ", got: ";
1977 
1978  if (yieldedTypes.empty())
1979  error << "None";
1980  else
1981  error << yieldedTypes;
1982 
1983  return error;
1984  };
1985 
// Validates a non-empty region:
//   - it has `expectedNumArgs` arguments;
//   - every block that may carry a terminator passes verifyTerminator.
// `regionName` is used only in diagnostics.
1986  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
1987  StringRef regionName) -> LogicalResult {
1988  assert(!region.empty());
1989 
1990  if (region.getNumArguments() != expectedNumArgs)
1991  return mlir::emitError(region.getLoc())
1992  << "`" << regionName << "`: "
1993  << "expected " << expectedNumArgs
1994  << " region arguments, got: " << region.getNumArguments();
1995 
1996  for (Block &block : region) {
1997  // MLIR will verify the absence of the terminator for us.
1998  if (!block.mightHaveTerminator())
1999  continue;
2000 
2001  if (failed(verifyTerminator(block.getTerminator())))
2002  return failure();
2003  }
2004 
2005  return success();
2006  };
2007 
2008  if (failed(verifyRegion(getAllocRegion(), /*expectedNumArgs=*/1, "alloc")))
2009  return failure();
2010 
2011  DataSharingClauseType dsType = getDataSharingType();
2012 
// Cross-check the data-sharing kind against the presence of a copy region.
2013  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
2014  return emitError("`private` clauses require only an `alloc` region.");
2015 
2016  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
2017  return emitError(
2018  "`firstprivate` clauses require both `alloc` and `copy` regions.");
2019 
2020  if (dsType == DataSharingClauseType::FirstPrivate &&
2021  failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy")))
2022  return failure();
2023 
2024  return success();
2025 }
2026 
2027 #define GET_ATTRDEF_CLASSES
2028 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
2029 
2030 #define GET_OP_CLASSES
2031 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
2032 
2033 #define GET_TYPEDEF_CLASSES
2034 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:716
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:708
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:63
static MLIRContext * getContext(OpFoldResult val)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedItems, SmallVectorImpl< Type > &types, ArrayAttr &alignmentValues)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerBound, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperBound, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &loopVarTypes, UnitAttr &inclusive)
loop-control ::= ( ssa-id-list ) : type = loop-bounds loop-bounds := ( ssa-id-list ) to ( ssa-id-list...
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > depends)
Print Depend clause.
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCapture)
ParseResult parseWsloop(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerBound, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperBound, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &loopVarTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionOperands, SmallVectorImpl< Type > &reductionTypes, ArrayAttr &reductionSymbols, UnitAttr &inclusive)
loop-control ::= ( ssa-id-list ) : type = loop-bounds loop-bounds := ( ssa-id-list ) to ( ssa-id-list...
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &vars, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &stepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange varsAllocate, TypeRange typesAllocate, OperandRange varsAllocator, TypeRange typesAllocator)
Print allocate clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignmentValues, OperandRange alignedVariables)
static void printReductionVarList(OpAsmPrinter &p, Operation *op, OperandRange reductionVars, TypeRange reductionTypes, std::optional< ArrayAttr > reductions)
Print Reduction clause.
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductions, OperandRange reductionVars)
Verifies Reduction Clause.
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operandsAllocate, SmallVectorImpl< Type > &typesAllocate, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operandsAllocator, SmallVectorImpl< Type > &typesAllocator)
Parse an allocate clause with allocators and a list of operands with types.
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedVarTypes, std::optional< ArrayAttr > alignmentValues)
Print Aligned Clause.
void printWsloop(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerBound, ValueRange upperBound, ValueRange steps, TypeRange loopVarTypes, ValueRange reductionOperands, TypeRange reductionTypes, ArrayAttr reductionSymbols, UnitAttr inclusive)
static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > depends, OperandRange dependVars)
Verifies Depend clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printAtomicReductionRegion(OpAsmPrinter &printer, DeclareReductionOp op, Region &region)
static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op, ValueRange argsSubrange, StringRef clauseName, ValueRange operands, TypeRange types, ArrayAttr symbols)
static ParseResult parseCopyPrivateVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, ArrayAttr &copyPrivateSymbols)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static void printCopyPrivateVarList(OpAsmPrinter &p, Operation *op, OperandRange copyPrivateVars, TypeRange copyPrivateTypes, std::optional< ArrayAttr > copyPrivateFuncs)
Print CopyPrivate clause.
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearVarTypes, ValueRange linearStepVars)
Print Linear Clause.
static ParseResult parseReductionVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, ArrayAttr &redcuctionSymbols)
reduction-entry-list ::= reduction-entry | reduction-entry-list , reduction-entry reduction-entry ::=...
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleModifier, UnitAttr &simdModifier, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVariables)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerBound, ValueRange upperBound, ValueRange steps, TypeRange loopVarTypes, UnitAttr inclusive)
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr schedAttr, ScheduleModifierAttr modifier, UnitAttr simd, Value scheduleChunkVar, Type scheduleChunkType)
Print schedule clause.
static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, Region &region)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange reductionVarOperands, TypeRange reductionVarTypes, ArrayAttr reductionSymbols, ValueRange privateVarOperands, TypeRange privateVarTypes, ArrayAttr privatizerSymbols)
static bool opInGlobalImplicitParallelRegion(Operation *op)
ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, ArrayAttr &symbols, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs)
static LogicalResult verifyCopyPrivateVarList(Operation *op, OperandRange copyPrivateVars, std::optional< ArrayAttr > copyPrivateFuncs)
Verifies CopyPrivate Clause.
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, ArrayAttr &dependsArray)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static ParseResult parseParallelRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVarOperands, SmallVectorImpl< Type > &reductionVarTypes, ArrayAttr &reductionSymbols, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVarOperands, llvm::SmallVectorImpl< Type > &privateVarsTypes, ArrayAttr &privatizerSymbols)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseMapEntries(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapOperands, SmallVectorImpl< Type > &mapOperandTypes)
static void printMapEntries(OpAsmPrinter &p, Operation *op, OperandRange mapOperands, TypeRange mapOperandTypes)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Block represents an ordered list of Operations.
Definition: Block.h:30
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
unsigned getNumArguments()
Definition: Block.h:125
BlockArgListType getArguments()
Definition: Block.h:84
Operation & front()
Definition: Block.h:150
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
IntegerType getI64Type()
Definition: Builders.cpp:85
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
MLIRContext * getContext() const
Definition: Builders.h:55
Define a fold interface to allow for dialects to control specific aspects of the folding behavior for...
DialectFoldInterface(Dialect *dialect)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:209
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getType() const
Definition: ValueRange.cpp:30
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:775
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
Block & front()
Definition: Region.h:65
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:534
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Runtime
Potential runtimes for AMD GPU kernels.
Definition: Runtimes.h:15
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
Definition: MathExtras.h:45
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.