MLIR  19.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/IR/Attributes.h"
23 
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/BitVector.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/STLForwardCompat.h"
28 #include "llvm/ADT/SmallString.h"
29 #include "llvm/ADT/StringExtras.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Frontend/OpenMP/OMPConstants.h"
33 #include <cstddef>
34 #include <iterator>
35 #include <optional>
36 
37 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
38 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
39 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
40 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
42 
43 using namespace mlir;
44 using namespace mlir::omp;
45 
46 static ArrayAttr makeArrayAttr(MLIRContext *context,
48  return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
49 }
50 
51 namespace {
52 struct MemRefPointerLikeModel
53  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
54  MemRefType> {
55  Type getElementType(Type pointer) const {
56  return llvm::cast<MemRefType>(pointer).getElementType();
57  }
58 };
59 
60 struct LLVMPointerPointerLikeModel
61  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
62  LLVM::LLVMPointerType> {
63  Type getElementType(Type pointer) const { return Type(); }
64 };
65 
66 struct OpenMPDialectFoldInterface : public DialectFoldInterface {
68 
69  bool shouldMaterializeInto(Region *region) const final {
70  // Avoid folding constants across target regions
71  return isa<TargetOp>(region->getParentOp());
72  }
73 };
74 } // namespace
75 
76 void OpenMPDialect::initialize() {
77  addOperations<
78 #define GET_OP_LIST
79 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
80  >();
81  addAttributes<
82 #define GET_ATTRDEF_LIST
83 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
84  >();
85  addTypes<
86 #define GET_TYPEDEF_LIST
87 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
88  >();
89 
90  addInterface<OpenMPDialectFoldInterface>();
91  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
92  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
93  *getContext());
94 
95  // Attach default offload module interface to module op to access
96  // offload functionality through
97  mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
98  *getContext());
99 
100  // Attach default declare target interfaces to operations which can be marked
101  // as declare target (Global Operations and Functions/Subroutines in dialects
102  // that Fortran (or other languages that lower to MLIR) translates too
103  mlir::LLVM::GlobalOp::attachInterface<
105  *getContext());
106  mlir::LLVM::LLVMFuncOp::attachInterface<
108  *getContext());
109  mlir::func::FuncOp::attachInterface<
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // Parser and printer for Allocate Clause
115 //===----------------------------------------------------------------------===//
116 
117 /// Parse an allocate clause with allocators and a list of operands with types.
118 ///
119 /// allocate-operand-list :: = allocate-operand |
120 /// allocator-operand `,` allocate-operand-list
121 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
122 /// ssa-id-and-type ::= ssa-id `:` type
124  OpAsmParser &parser,
126  SmallVectorImpl<Type> &typesAllocate,
128  SmallVectorImpl<Type> &typesAllocator) {
129 
130  return parser.parseCommaSeparatedList([&]() {
132  Type type;
133  if (parser.parseOperand(operand) || parser.parseColonType(type))
134  return failure();
135  operandsAllocator.push_back(operand);
136  typesAllocator.push_back(type);
137  if (parser.parseArrow())
138  return failure();
139  if (parser.parseOperand(operand) || parser.parseColonType(type))
140  return failure();
141 
142  operandsAllocate.push_back(operand);
143  typesAllocate.push_back(type);
144  return success();
145  });
146 }
147 
148 /// Print allocate clause
150  OperandRange varsAllocate,
151  TypeRange typesAllocate,
152  OperandRange varsAllocator,
153  TypeRange typesAllocator) {
154  for (unsigned i = 0; i < varsAllocate.size(); ++i) {
155  std::string separator = i == varsAllocate.size() - 1 ? "" : ", ";
156  p << varsAllocator[i] << " : " << typesAllocator[i] << " -> ";
157  p << varsAllocate[i] << " : " << typesAllocate[i] << separator;
158  }
159 }
160 
161 //===----------------------------------------------------------------------===//
162 // Parser and printer for a clause attribute (StringEnumAttr)
163 //===----------------------------------------------------------------------===//
164 
165 template <typename ClauseAttr>
166 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
167  using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
168  StringRef enumStr;
169  SMLoc loc = parser.getCurrentLocation();
170  if (parser.parseKeyword(&enumStr))
171  return failure();
172  if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
173  attr = ClauseAttr::get(parser.getContext(), *enumValue);
174  return success();
175  }
176  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
177 }
178 
179 template <typename ClauseAttr>
180 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
181  p << stringifyEnum(attr.getValue());
182 }
183 
184 //===----------------------------------------------------------------------===//
185 // Parser and printer for Linear Clause
186 //===----------------------------------------------------------------------===//
187 
188 /// linear ::= `linear` `(` linear-list `)`
189 /// linear-list := linear-val | linear-val linear-list
190 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
191 static ParseResult
194  SmallVectorImpl<Type> &types,
196  return parser.parseCommaSeparatedList([&]() {
198  Type type;
200  if (parser.parseOperand(var) || parser.parseEqual() ||
201  parser.parseOperand(stepVar) || parser.parseColonType(type))
202  return failure();
203 
204  vars.push_back(var);
205  types.push_back(type);
206  stepVars.push_back(stepVar);
207  return success();
208  });
209 }
210 
211 /// Print Linear Clause
213  ValueRange linearVars, TypeRange linearVarTypes,
214  ValueRange linearStepVars) {
215  size_t linearVarsSize = linearVars.size();
216  for (unsigned i = 0; i < linearVarsSize; ++i) {
217  std::string separator = i == linearVarsSize - 1 ? "" : ", ";
218  p << linearVars[i];
219  if (linearStepVars.size() > i)
220  p << " = " << linearStepVars[i];
221  p << " : " << linearVars[i].getType() << separator;
222  }
223 }
224 
225 //===----------------------------------------------------------------------===//
226 // Verifier for Nontemporal Clause
227 //===----------------------------------------------------------------------===//
228 
229 static LogicalResult
230 verifyNontemporalClause(Operation *op, OperandRange nontemporalVariables) {
231 
232  // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
233  DenseSet<Value> nontemporalItems;
234  for (const auto &it : nontemporalVariables)
235  if (!nontemporalItems.insert(it).second)
236  return op->emitOpError() << "nontemporal variable used more than once";
237 
238  return success();
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // Parser, verifier and printer for Aligned Clause
243 //===----------------------------------------------------------------------===//
244 static LogicalResult
245 verifyAlignedClause(Operation *op, std::optional<ArrayAttr> alignmentValues,
246  OperandRange alignedVariables) {
247  // Check if number of alignment values equals to number of aligned variables
248  if (!alignedVariables.empty()) {
249  if (!alignmentValues || alignmentValues->size() != alignedVariables.size())
250  return op->emitOpError()
251  << "expected as many alignment values as aligned variables";
252  } else {
253  if (alignmentValues)
254  return op->emitOpError() << "unexpected alignment values attribute";
255  return success();
256  }
257 
258  // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
259  DenseSet<Value> alignedItems;
260  for (auto it : alignedVariables)
261  if (!alignedItems.insert(it).second)
262  return op->emitOpError() << "aligned variable used more than once";
263 
264  if (!alignmentValues)
265  return success();
266 
267  // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
268  for (unsigned i = 0; i < (*alignmentValues).size(); ++i) {
269  if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
270  if (intAttr.getValue().sle(0))
271  return op->emitOpError() << "alignment should be greater than 0";
272  } else {
273  return op->emitOpError() << "expected integer alignment";
274  }
275  }
276 
277  return success();
278 }
279 
280 /// aligned ::= `aligned` `(` aligned-list `)`
281 /// aligned-list := aligned-val | aligned-val aligned-list
282 /// aligned-val := ssa-id-and-type `->` alignment
284  OpAsmParser &parser,
286  SmallVectorImpl<Type> &types, ArrayAttr &alignmentValues) {
287  SmallVector<Attribute> alignmentVec;
288  if (failed(parser.parseCommaSeparatedList([&]() {
289  if (parser.parseOperand(alignedItems.emplace_back()) ||
290  parser.parseColonType(types.emplace_back()) ||
291  parser.parseArrow() ||
292  parser.parseAttribute(alignmentVec.emplace_back())) {
293  return failure();
294  }
295  return success();
296  })))
297  return failure();
298  SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
299  alignmentValues = ArrayAttr::get(parser.getContext(), alignments);
300  return success();
301 }
302 
303 /// Print Aligned Clause
305  ValueRange alignedVars,
306  TypeRange alignedVarTypes,
307  std::optional<ArrayAttr> alignmentValues) {
308  for (unsigned i = 0; i < alignedVars.size(); ++i) {
309  if (i != 0)
310  p << ", ";
311  p << alignedVars[i] << " : " << alignedVars[i].getType();
312  p << " -> " << (*alignmentValues)[i];
313  }
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // Parser, printer and verifier for Schedule Clause
318 //===----------------------------------------------------------------------===//
319 
320 static ParseResult
322  SmallVectorImpl<SmallString<12>> &modifiers) {
323  if (modifiers.size() > 2)
324  return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
325  for (const auto &mod : modifiers) {
326  // Translate the string. If it has no value, then it was not a valid
327  // modifier!
328  auto symbol = symbolizeScheduleModifier(mod);
329  if (!symbol)
330  return parser.emitError(parser.getNameLoc())
331  << " unknown modifier type: " << mod;
332  }
333 
334  // If we have one modifier that is "simd", then stick a "none" modiifer in
335  // index 0.
336  if (modifiers.size() == 1) {
337  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
338  modifiers.push_back(modifiers[0]);
339  modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
340  }
341  } else if (modifiers.size() == 2) {
342  // If there are two modifier:
343  // First modifier should not be simd, second one should be simd
344  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
345  symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
346  return parser.emitError(parser.getNameLoc())
347  << " incorrect modifier order";
348  }
349  return success();
350 }
351 
352 /// schedule ::= `schedule` `(` sched-list `)`
353 /// sched-list ::= sched-val | sched-val sched-list |
354 /// sched-val `,` sched-modifier
355 /// sched-val ::= sched-with-chunk | sched-wo-chunk
356 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
357 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
358 /// sched-wo-chunk ::= `auto` | `runtime`
359 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
360 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
362  OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
363  ScheduleModifierAttr &scheduleModifier, UnitAttr &simdModifier,
364  std::optional<OpAsmParser::UnresolvedOperand> &chunkSize, Type &chunkType) {
365  StringRef keyword;
366  if (parser.parseKeyword(&keyword))
367  return failure();
368  std::optional<mlir::omp::ClauseScheduleKind> schedule =
369  symbolizeClauseScheduleKind(keyword);
370  if (!schedule)
371  return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
372 
373  scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
374  switch (*schedule) {
375  case ClauseScheduleKind::Static:
376  case ClauseScheduleKind::Dynamic:
377  case ClauseScheduleKind::Guided:
378  if (succeeded(parser.parseOptionalEqual())) {
379  chunkSize = OpAsmParser::UnresolvedOperand{};
380  if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
381  return failure();
382  } else {
383  chunkSize = std::nullopt;
384  }
385  break;
386  case ClauseScheduleKind::Auto:
388  chunkSize = std::nullopt;
389  }
390 
391  // If there is a comma, we have one or more modifiers..
392  SmallVector<SmallString<12>> modifiers;
393  while (succeeded(parser.parseOptionalComma())) {
394  StringRef mod;
395  if (parser.parseKeyword(&mod))
396  return failure();
397  modifiers.push_back(mod);
398  }
399 
400  if (verifyScheduleModifiers(parser, modifiers))
401  return failure();
402 
403  if (!modifiers.empty()) {
404  SMLoc loc = parser.getCurrentLocation();
405  if (std::optional<ScheduleModifier> mod =
406  symbolizeScheduleModifier(modifiers[0])) {
407  scheduleModifier = ScheduleModifierAttr::get(parser.getContext(), *mod);
408  } else {
409  return parser.emitError(loc, "invalid schedule modifier");
410  }
411  // Only SIMD attribute is allowed here!
412  if (modifiers.size() > 1) {
413  assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
414  simdModifier = UnitAttr::get(parser.getBuilder().getContext());
415  }
416  }
417 
418  return success();
419 }
420 
421 /// Print schedule clause
423  ClauseScheduleKindAttr schedAttr,
424  ScheduleModifierAttr modifier, UnitAttr simd,
425  Value scheduleChunkVar,
426  Type scheduleChunkType) {
427  p << stringifyClauseScheduleKind(schedAttr.getValue());
428  if (scheduleChunkVar)
429  p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType();
430  if (modifier)
431  p << ", " << stringifyScheduleModifier(modifier.getValue());
432  if (simd)
433  p << ", simd";
434 }
435 
436 //===----------------------------------------------------------------------===//
437 // Parser, printer and verifier for ReductionVarList
438 //===----------------------------------------------------------------------===//
439 
441  OpAsmParser &parser, Region &region,
443  SmallVectorImpl<Type> &types, DenseBoolArrayAttr &isByRef,
444  ArrayAttr &symbols,
445  SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs) {
446  SmallVector<SymbolRefAttr> reductionVec;
447  SmallVector<bool> isByRefVec;
448  unsigned regionArgOffset = regionPrivateArgs.size();
449 
450  if (failed(
452  ParseResult optionalByref = parser.parseOptionalKeyword("byref");
453  if (parser.parseAttribute(reductionVec.emplace_back()) ||
454  parser.parseOperand(operands.emplace_back()) ||
455  parser.parseArrow() ||
456  parser.parseArgument(regionPrivateArgs.emplace_back()) ||
457  parser.parseColonType(types.emplace_back()))
458  return failure();
459  isByRefVec.push_back(optionalByref.succeeded());
460  return success();
461  })))
462  return failure();
463  isByRef = DenseBoolArrayAttr::get(parser.getContext(), isByRefVec);
464 
465  auto *argsBegin = regionPrivateArgs.begin();
466  MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
467  argsBegin + regionArgOffset + types.size());
468  for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
469  prv.type = type;
470  }
471  SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
472  symbols = ArrayAttr::get(parser.getContext(), reductions);
473  return success();
474 }
475 
477  ValueRange argsSubrange,
478  StringRef clauseName, ValueRange operands,
479  TypeRange types, DenseBoolArrayAttr byRef,
480  ArrayAttr symbols) {
481  if (!clauseName.empty())
482  p << clauseName << "(";
483 
484  llvm::interleaveComma(llvm::zip_equal(symbols, operands, argsSubrange, types,
485  byRef.asArrayRef()),
486  p, [&p](auto t) {
487  auto [sym, op, arg, type, isByRef] = t;
488  p << (isByRef ? "byref " : "") << sym << " " << op
489  << " -> " << arg << " : " << type;
490  });
491 
492  if (!clauseName.empty())
493  p << ") ";
494 }
495 
497  OpAsmParser &parser, Region &region,
499  SmallVectorImpl<Type> &reductionVarTypes,
500  DenseBoolArrayAttr &reductionByRef, ArrayAttr &reductionSymbols,
502  llvm::SmallVectorImpl<Type> &privateVarsTypes,
503  ArrayAttr &privatizerSymbols) {
505 
506  if (succeeded(parser.parseOptionalKeyword("reduction"))) {
507  if (failed(parseClauseWithRegionArgs(parser, region, reductionVarOperands,
508  reductionVarTypes, reductionByRef,
509  reductionSymbols, regionPrivateArgs)))
510  return failure();
511  }
512 
513  if (succeeded(parser.parseOptionalKeyword("private"))) {
514  auto privateByRef = DenseBoolArrayAttr::get(parser.getContext(), {});
515  if (failed(parseClauseWithRegionArgs(parser, region, privateVarOperands,
516  privateVarsTypes, privateByRef,
517  privatizerSymbols, regionPrivateArgs)))
518  return failure();
519  if (llvm::any_of(privateByRef.asArrayRef(),
520  [](bool byref) { return byref; })) {
521  parser.emitError(parser.getCurrentLocation(),
522  "private clause cannot have byref attributes");
523  return failure();
524  }
525  }
526 
527  return parser.parseRegion(region, regionPrivateArgs);
528 }
529 
530 static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
531  ValueRange reductionVarOperands,
532  TypeRange reductionVarTypes,
533  DenseBoolArrayAttr reductionVarIsByRef,
534  ArrayAttr reductionSymbols,
535  ValueRange privateVarOperands,
536  TypeRange privateVarTypes,
537  ArrayAttr privatizerSymbols) {
538  if (reductionSymbols) {
539  auto *argsBegin = region.front().getArguments().begin();
540  MutableArrayRef argsSubrange(argsBegin,
541  argsBegin + reductionVarTypes.size());
542  printClauseWithRegionArgs(p, op, argsSubrange, "reduction",
543  reductionVarOperands, reductionVarTypes,
544  reductionVarIsByRef, reductionSymbols);
545  }
546 
547  if (privatizerSymbols) {
548  auto *argsBegin = region.front().getArguments().begin();
549  MutableArrayRef argsSubrange(argsBegin + reductionVarOperands.size(),
550  argsBegin + reductionVarOperands.size() +
551  privateVarTypes.size());
552  mlir::SmallVector<bool> isByRefVec;
553  isByRefVec.resize(privateVarTypes.size(), false);
554  DenseBoolArrayAttr isByRef =
555  DenseBoolArrayAttr::get(op->getContext(), isByRefVec);
556 
557  printClauseWithRegionArgs(p, op, argsSubrange, "private",
558  privateVarOperands, privateVarTypes, isByRef,
559  privatizerSymbols);
560  }
561 
562  p.printRegion(region, /*printEntryBlockArgs=*/false);
563 }
564 
565 /// reduction-entry-list ::= reduction-entry
566 /// | reduction-entry-list `,` reduction-entry
567 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
568 static ParseResult
571  SmallVectorImpl<Type> &types,
572  ArrayAttr &redcuctionSymbols) {
573  SmallVector<SymbolRefAttr> reductionVec;
574  if (failed(parser.parseCommaSeparatedList([&]() {
575  if (parser.parseAttribute(reductionVec.emplace_back()) ||
576  parser.parseArrow() ||
577  parser.parseOperand(operands.emplace_back()) ||
578  parser.parseColonType(types.emplace_back()))
579  return failure();
580  return success();
581  })))
582  return failure();
583  SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
584  redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
585  return success();
586 }
587 
588 /// Print Reduction clause
590  OperandRange reductionVars,
591  TypeRange reductionTypes,
592  std::optional<ArrayAttr> reductions) {
593  for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
594  if (i != 0)
595  p << ", ";
596  p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
597  << reductionVars[i].getType();
598  }
599 }
600 
601 /// Verifies Reduction Clause
602 static LogicalResult
603 verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductions,
604  OperandRange reductionVars,
605  std::optional<ArrayRef<bool>> byRef = std::nullopt) {
606  if (!reductionVars.empty()) {
607  if (!reductions || reductions->size() != reductionVars.size())
608  return op->emitOpError()
609  << "expected as many reduction symbol references "
610  "as reduction variables";
611  if (mlir::isa<omp::WsloopOp, omp::ParallelOp>(op))
612  assert(byRef);
613  else
614  assert(!byRef); // TODO: support byref reductions on other operations
615  if (byRef && byRef->size() != reductionVars.size())
616  return op->emitError() << "expected as many reduction variable by "
617  "reference attributes as reduction variables";
618  } else {
619  if (reductions)
620  return op->emitOpError() << "unexpected reduction symbol references";
621  return success();
622  }
623 
624  // TODO: The followings should be done in
625  // SymbolUserOpInterface::verifySymbolUses.
626  DenseSet<Value> accumulators;
627  for (auto args : llvm::zip(reductionVars, *reductions)) {
628  Value accum = std::get<0>(args);
629 
630  if (!accumulators.insert(accum).second)
631  return op->emitOpError() << "accumulator variable used more than once";
632 
633  Type varType = accum.getType();
634  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
635  auto decl =
636  SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
637  if (!decl)
638  return op->emitOpError() << "expected symbol reference " << symbolRef
639  << " to point to a reduction declaration";
640 
641  if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
642  return op->emitOpError()
643  << "expected accumulator (" << varType
644  << ") to be the same type as reduction declaration ("
645  << decl.getAccumulatorType() << ")";
646  }
647 
648  return success();
649 }
650 
651 //===----------------------------------------------------------------------===//
652 // Parser, printer and verifier for CopyPrivateVarList
653 //===----------------------------------------------------------------------===//
654 
655 /// copyprivate-entry-list ::= copyprivate-entry
656 /// | copyprivate-entry-list `,` copyprivate-entry
657 /// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
659  OpAsmParser &parser,
661  SmallVectorImpl<Type> &types, ArrayAttr &copyPrivateSymbols) {
662  SmallVector<SymbolRefAttr> copyPrivateFuncsVec;
663  if (failed(parser.parseCommaSeparatedList([&]() {
664  if (parser.parseOperand(operands.emplace_back()) ||
665  parser.parseArrow() ||
666  parser.parseAttribute(copyPrivateFuncsVec.emplace_back()) ||
667  parser.parseColonType(types.emplace_back()))
668  return failure();
669  return success();
670  })))
671  return failure();
672  SmallVector<Attribute> copyPrivateFuncs(copyPrivateFuncsVec.begin(),
673  copyPrivateFuncsVec.end());
674  copyPrivateSymbols = ArrayAttr::get(parser.getContext(), copyPrivateFuncs);
675  return success();
676 }
677 
678 /// Print CopyPrivate clause
680  OperandRange copyPrivateVars,
681  TypeRange copyPrivateTypes,
682  std::optional<ArrayAttr> copyPrivateFuncs) {
683  if (!copyPrivateFuncs.has_value())
684  return;
685  llvm::interleaveComma(
686  llvm::zip(copyPrivateVars, *copyPrivateFuncs, copyPrivateTypes), p,
687  [&](const auto &args) {
688  p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
689  << std::get<2>(args);
690  });
691 }
692 
693 /// Verifies CopyPrivate Clause
694 static LogicalResult
696  std::optional<ArrayAttr> copyPrivateFuncs) {
697  size_t copyPrivateFuncsSize =
698  copyPrivateFuncs.has_value() ? copyPrivateFuncs->size() : 0;
699  if (copyPrivateFuncsSize != copyPrivateVars.size())
700  return op->emitOpError() << "inconsistent number of copyPrivate vars (= "
701  << copyPrivateVars.size()
702  << ") and functions (= " << copyPrivateFuncsSize
703  << "), both must be equal";
704  if (!copyPrivateFuncs.has_value())
705  return success();
706 
707  for (auto copyPrivateVarAndFunc :
708  llvm::zip(copyPrivateVars, *copyPrivateFuncs)) {
709  auto symbolRef =
710  llvm::cast<SymbolRefAttr>(std::get<1>(copyPrivateVarAndFunc));
711  std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
712  funcOp;
713  if (mlir::func::FuncOp mlirFuncOp =
714  SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
715  symbolRef))
716  funcOp = mlirFuncOp;
717  else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
718  SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
719  op, symbolRef))
720  funcOp = llvmFuncOp;
721 
722  auto getNumArguments = [&] {
723  return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
724  };
725 
726  auto getArgumentType = [&](unsigned i) {
727  return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
728  *funcOp);
729  };
730 
731  if (!funcOp)
732  return op->emitOpError() << "expected symbol reference " << symbolRef
733  << " to point to a copy function";
734 
735  if (getNumArguments() != 2)
736  return op->emitOpError()
737  << "expected copy function " << symbolRef << " to have 2 operands";
738 
739  Type argTy = getArgumentType(0);
740  if (argTy != getArgumentType(1))
741  return op->emitOpError() << "expected copy function " << symbolRef
742  << " arguments to have the same type";
743 
744  Type varType = std::get<0>(copyPrivateVarAndFunc).getType();
745  if (argTy != varType)
746  return op->emitOpError()
747  << "expected copy function arguments' type (" << argTy
748  << ") to be the same as copyprivate variable's type (" << varType
749  << ")";
750  }
751 
752  return success();
753 }
754 
755 //===----------------------------------------------------------------------===//
756 // Parser, printer and verifier for DependVarList
757 //===----------------------------------------------------------------------===//
758 
759 /// depend-entry-list ::= depend-entry
760 /// | depend-entry-list `,` depend-entry
761 /// depend-entry ::= depend-kind `->` ssa-id `:` type
762 static ParseResult
765  SmallVectorImpl<Type> &types, ArrayAttr &dependsArray) {
767  if (failed(parser.parseCommaSeparatedList([&]() {
768  StringRef keyword;
769  if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
770  parser.parseOperand(operands.emplace_back()) ||
771  parser.parseColonType(types.emplace_back()))
772  return failure();
773  if (std::optional<ClauseTaskDepend> keywordDepend =
774  (symbolizeClauseTaskDepend(keyword)))
775  dependVec.emplace_back(
776  ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
777  else
778  return failure();
779  return success();
780  })))
781  return failure();
782  SmallVector<Attribute> depends(dependVec.begin(), dependVec.end());
783  dependsArray = ArrayAttr::get(parser.getContext(), depends);
784  return success();
785 }
786 
787 /// Print Depend clause
789  OperandRange dependVars, TypeRange dependTypes,
790  std::optional<ArrayAttr> depends) {
791 
792  for (unsigned i = 0, e = depends->size(); i < e; ++i) {
793  if (i != 0)
794  p << ", ";
795  p << stringifyClauseTaskDepend(
796  llvm::cast<mlir::omp::ClauseTaskDependAttr>((*depends)[i])
797  .getValue())
798  << " -> " << dependVars[i] << " : " << dependTypes[i];
799  }
800 }
801 
802 /// Verifies Depend clause
804  std::optional<ArrayAttr> depends,
805  OperandRange dependVars) {
806  if (!dependVars.empty()) {
807  if (!depends || depends->size() != dependVars.size())
808  return op->emitOpError() << "expected as many depend values"
809  " as depend variables";
810  } else {
811  if (depends && !depends->empty())
812  return op->emitOpError() << "unexpected depend values";
813  return success();
814  }
815 
816  return success();
817 }
818 
819 //===----------------------------------------------------------------------===//
820 // Parser, printer and verifier for Synchronization Hint (2.17.12)
821 //===----------------------------------------------------------------------===//
822 
823 /// Parses a Synchronization Hint clause. The value of hint is an integer
824 /// which is a combination of different hints from `omp_sync_hint_t`.
825 ///
826 /// hint-clause = `hint` `(` hint-value `)`
828  IntegerAttr &hintAttr) {
829  StringRef hintKeyword;
830  int64_t hint = 0;
831  if (succeeded(parser.parseOptionalKeyword("none"))) {
832  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
833  return success();
834  }
835  auto parseKeyword = [&]() -> ParseResult {
836  if (failed(parser.parseKeyword(&hintKeyword)))
837  return failure();
838  if (hintKeyword == "uncontended")
839  hint |= 1;
840  else if (hintKeyword == "contended")
841  hint |= 2;
842  else if (hintKeyword == "nonspeculative")
843  hint |= 4;
844  else if (hintKeyword == "speculative")
845  hint |= 8;
846  else
847  return parser.emitError(parser.getCurrentLocation())
848  << hintKeyword << " is not a valid hint";
849  return success();
850  };
851  if (parser.parseCommaSeparatedList(parseKeyword))
852  return failure();
853  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
854  return success();
855 }
856 
857 /// Prints a Synchronization Hint clause
859  IntegerAttr hintAttr) {
860  int64_t hint = hintAttr.getInt();
861 
862  if (hint == 0) {
863  p << "none";
864  return;
865  }
866 
867  // Helper function to get n-th bit from the right end of `value`
868  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
869 
870  bool uncontended = bitn(hint, 0);
871  bool contended = bitn(hint, 1);
872  bool nonspeculative = bitn(hint, 2);
873  bool speculative = bitn(hint, 3);
874 
876  if (uncontended)
877  hints.push_back("uncontended");
878  if (contended)
879  hints.push_back("contended");
880  if (nonspeculative)
881  hints.push_back("nonspeculative");
882  if (speculative)
883  hints.push_back("speculative");
884 
885  llvm::interleaveComma(hints, p);
886 }
887 
888 /// Verifies a synchronization hint clause
890 
891  // Helper function to get n-th bit from the right end of `value`
892  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
893 
894  bool uncontended = bitn(hint, 0);
895  bool contended = bitn(hint, 1);
896  bool nonspeculative = bitn(hint, 2);
897  bool speculative = bitn(hint, 3);
898 
899  if (uncontended && contended)
900  return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
901  "omp_sync_hint_contended cannot be combined";
902  if (nonspeculative && speculative)
903  return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
904  "omp_sync_hint_speculative cannot be combined.";
905  return success();
906 }
907 
908 //===----------------------------------------------------------------------===//
909 // Parser, printer and verifier for Target
910 //===----------------------------------------------------------------------===//
911 
912 // Helper function to get bitwise AND of `value` and 'flag'
913 uint64_t mapTypeToBitFlag(uint64_t value,
914  llvm::omp::OpenMPOffloadMappingFlags flag) {
915  return value & llvm::to_underlying(flag);
916 }
917 
918 /// Parses a map_entries map type from a string format back into its numeric
919 /// value.
920 ///
921 /// map-clause = `map_clauses ( ( `(` `always, `? `close, `? `present, `? (
922 /// `to` | `from` | `delete` `)` )+ `)` )
923 static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
924  llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
925  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
926 
927  // This simply verifies the correct keyword is read in, the
928  // keyword itself is stored inside of the operation
929  auto parseTypeAndMod = [&]() -> ParseResult {
930  StringRef mapTypeMod;
931  if (parser.parseKeyword(&mapTypeMod))
932  return failure();
933 
934  if (mapTypeMod == "always")
935  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
936 
937  if (mapTypeMod == "implicit")
938  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
939 
940  if (mapTypeMod == "close")
941  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
942 
943  if (mapTypeMod == "present")
944  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
945 
946  if (mapTypeMod == "to")
947  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
948 
949  if (mapTypeMod == "from")
950  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
951 
952  if (mapTypeMod == "tofrom")
953  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
954  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
955 
956  if (mapTypeMod == "delete")
957  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
958 
959  return success();
960  };
961 
962  if (parser.parseCommaSeparatedList(parseTypeAndMod))
963  return failure();
964 
965  mapType = parser.getBuilder().getIntegerAttr(
966  parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
967  llvm::to_underlying(mapTypeBits));
968 
969  return success();
970 }
971 
972 /// Prints a map_entries map type from its numeric value out into its string
973 /// format.
975  IntegerAttr mapType) {
976  uint64_t mapTypeBits = mapType.getUInt();
977 
978  bool emitAllocRelease = true;
980 
981  // handling of always, close, present placed at the beginning of the string
982  // to aid readability
983  if (mapTypeToBitFlag(mapTypeBits,
984  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
985  mapTypeStrs.push_back("always");
986  if (mapTypeToBitFlag(mapTypeBits,
987  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
988  mapTypeStrs.push_back("implicit");
989  if (mapTypeToBitFlag(mapTypeBits,
990  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
991  mapTypeStrs.push_back("close");
992  if (mapTypeToBitFlag(mapTypeBits,
993  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
994  mapTypeStrs.push_back("present");
995 
996  // special handling of to/from/tofrom/delete and release/alloc, release +
997  // alloc are the abscense of one of the other flags, whereas tofrom requires
998  // both the to and from flag to be set.
999  bool to = mapTypeToBitFlag(mapTypeBits,
1000  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1001  bool from = mapTypeToBitFlag(
1002  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1003  if (to && from) {
1004  emitAllocRelease = false;
1005  mapTypeStrs.push_back("tofrom");
1006  } else if (from) {
1007  emitAllocRelease = false;
1008  mapTypeStrs.push_back("from");
1009  } else if (to) {
1010  emitAllocRelease = false;
1011  mapTypeStrs.push_back("to");
1012  }
1013  if (mapTypeToBitFlag(mapTypeBits,
1014  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1015  emitAllocRelease = false;
1016  mapTypeStrs.push_back("delete");
1017  }
1018  if (emitAllocRelease)
1019  mapTypeStrs.push_back("exit_release_or_enter_alloc");
1020 
1021  for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1022  p << mapTypeStrs[i];
1023  if (i + 1 < mapTypeStrs.size()) {
1024  p << ", ";
1025  }
1026  }
1027 }
1028 
1030  DenseIntElementsAttr &membersIdx) {
1031  SmallVector<APInt> values;
1032  int64_t value;
1033  int64_t shape[2] = {0, 0};
1034  unsigned shapeTmp = 0;
1035  auto parseIndices = [&]() -> ParseResult {
1036  if (parser.parseInteger(value))
1037  return failure();
1038  shapeTmp++;
1039  values.push_back(APInt(32, value));
1040  return success();
1041  };
1042 
1043  do {
1044  if (failed(parser.parseLSquare()))
1045  return failure();
1046 
1047  if (parser.parseCommaSeparatedList(parseIndices))
1048  return failure();
1049 
1050  if (failed(parser.parseRSquare()))
1051  return failure();
1052 
1053  // Only set once, if any indices are not the same size
1054  // we error out in the next check as that's unsupported
1055  if (shape[1] == 0)
1056  shape[1] = shapeTmp;
1057 
1058  // Verify that the recently parsed list is equal to the
1059  // first one we parsed, they must be equal lengths to
1060  // keep the rectangular shape DenseIntElementsAttr
1061  // requires
1062  if (shapeTmp != shape[1])
1063  return failure();
1064 
1065  shapeTmp = 0;
1066  shape[0]++;
1067  } while (succeeded(parser.parseOptionalComma()));
1068 
1069  if (!values.empty()) {
1070  ShapedType valueType =
1071  VectorType::get(shape, IntegerType::get(parser.getContext(), 32));
1072  membersIdx = DenseIntElementsAttr::get(valueType, values);
1073  }
1074 
1075  return success();
1076 }
1077 
1078 static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1079  DenseIntElementsAttr membersIdx) {
1080  llvm::ArrayRef<int64_t> shape = membersIdx.getShapedType().getShape();
1081  assert(shape.size() <= 2);
1082 
1083  if (!membersIdx)
1084  return;
1085 
1086  for (int i = 0; i < shape[0]; ++i) {
1087  p << "[";
1088  int rowOffset = i * shape[1];
1089  for (int j = 0; j < shape[1]; ++j) {
1090  p << membersIdx.getValues<
1091  int32_t>()[rowOffset + j];
1092  if ((j + 1) < shape[1])
1093  p << ",";
1094  }
1095  p << "]";
1096 
1097  if ((i + 1) < shape[0])
1098  p << ", ";
1099  }
1100 }
1101 
1102 static ParseResult
1105  SmallVectorImpl<Type> &mapOperandTypes) {
1108  Type argType;
1109  auto parseEntries = [&]() -> ParseResult {
1110  if (parser.parseOperand(arg) || parser.parseArrow() ||
1111  parser.parseOperand(blockArg))
1112  return failure();
1113  mapOperands.push_back(arg);
1114  return success();
1115  };
1116 
1117  auto parseTypes = [&]() -> ParseResult {
1118  if (parser.parseType(argType))
1119  return failure();
1120  mapOperandTypes.push_back(argType);
1121  return success();
1122  };
1123 
1124  if (parser.parseCommaSeparatedList(parseEntries))
1125  return failure();
1126 
1127  if (parser.parseColon())
1128  return failure();
1129 
1130  if (parser.parseCommaSeparatedList(parseTypes))
1131  return failure();
1132 
1133  return success();
1134 }
1135 
1137  OperandRange mapOperands,
1138  TypeRange mapOperandTypes) {
1139  auto &region = op->getRegion(0);
1140  unsigned argIndex = 0;
1141 
1142  for (const auto &mapOp : mapOperands) {
1143  const auto &blockArg = region.front().getArgument(argIndex);
1144  p << mapOp << " -> " << blockArg;
1145  argIndex++;
1146  if (argIndex < mapOperands.size())
1147  p << ", ";
1148  }
1149  p << " : ";
1150 
1151  argIndex = 0;
1152  for (const auto &mapType : mapOperandTypes) {
1153  p << mapType;
1154  argIndex++;
1155  if (argIndex < mapOperands.size())
1156  p << ", ";
1157  }
1158 }
1159 
1161  OpAsmParser &parser,
1163  SmallVectorImpl<Type> &privateOperandTypes, ArrayAttr &privatizerSymbols) {
1164  SmallVector<SymbolRefAttr> privateSymRefs;
1165  SmallVector<OpAsmParser::Argument> regionPrivateArgs;
1166 
1167  if (failed(parser.parseCommaSeparatedList([&]() {
1168  if (parser.parseAttribute(privateSymRefs.emplace_back()) ||
1169  parser.parseOperand(privateOperands.emplace_back()) ||
1170  parser.parseArrow() ||
1171  parser.parseArgument(regionPrivateArgs.emplace_back()) ||
1172  parser.parseColonType(privateOperandTypes.emplace_back()))
1173  return failure();
1174  return success();
1175  })))
1176  return failure();
1177 
1178  SmallVector<Attribute> privateSymAttrs(privateSymRefs.begin(),
1179  privateSymRefs.end());
1180  privatizerSymbols = ArrayAttr::get(parser.getContext(), privateSymAttrs);
1181 
1182  return success();
1183 }
1184 
1186  ValueRange privateVarOperands,
1187  TypeRange privateVarTypes,
1188  ArrayAttr privatizerSymbols) {
1189  // TODO: Remove target-specific logic from this function.
1190  auto targetOp = mlir::dyn_cast<mlir::omp::TargetOp>(op);
1191  assert(targetOp);
1192 
1193  auto &region = op->getRegion(0);
1194  auto *argsBegin = region.front().getArguments().begin();
1195  MutableArrayRef argsSubrange(argsBegin + targetOp.getMapOperands().size(),
1196  argsBegin + targetOp.getMapOperands().size() +
1197  privateVarTypes.size());
1198  mlir::SmallVector<bool> isByRefVec;
1199  isByRefVec.resize(privateVarTypes.size(), false);
1200  DenseBoolArrayAttr isByRef =
1201  DenseBoolArrayAttr::get(op->getContext(), isByRefVec);
1202 
1204  p, op, argsSubrange, /*clauseName=*/llvm::StringRef{}, privateVarOperands,
1205  privateVarTypes, isByRef, privatizerSymbols);
1206 }
1207 
1209  VariableCaptureKindAttr mapCaptureType) {
1210  std::string typeCapStr;
1211  llvm::raw_string_ostream typeCap(typeCapStr);
1212  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1213  typeCap << "ByRef";
1214  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1215  typeCap << "ByCopy";
1216  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1217  typeCap << "VLAType";
1218  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1219  typeCap << "This";
1220  p << typeCap.str();
1221 }
1222 
1224  VariableCaptureKindAttr &mapCapture) {
1225  StringRef mapCaptureKey;
1226  if (parser.parseKeyword(&mapCaptureKey))
1227  return failure();
1228 
1229  if (mapCaptureKey == "This")
1231  parser.getContext(), mlir::omp::VariableCaptureKind::This);
1232  if (mapCaptureKey == "ByRef")
1234  parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1235  if (mapCaptureKey == "ByCopy")
1237  parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1238  if (mapCaptureKey == "VLAType")
1240  parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1241 
1242  return success();
1243 }
1244 
1248 
1249  for (auto mapOp : mapOperands) {
1250  if (!mapOp.getDefiningOp())
1251  emitError(op->getLoc(), "missing map operation");
1252 
1253  if (auto mapInfoOp =
1254  mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
1255  if (!mapInfoOp.getMapType().has_value())
1256  emitError(op->getLoc(), "missing map type for map operand");
1257 
1258  if (!mapInfoOp.getMapCaptureType().has_value())
1259  emitError(op->getLoc(), "missing map capture type for map operand");
1260 
1261  uint64_t mapTypeBits = mapInfoOp.getMapType().value();
1262 
1263  bool to = mapTypeToBitFlag(
1264  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1265  bool from = mapTypeToBitFlag(
1266  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1267  bool del = mapTypeToBitFlag(
1268  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1269 
1270  bool always = mapTypeToBitFlag(
1271  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1272  bool close = mapTypeToBitFlag(
1273  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1274  bool implicit = mapTypeToBitFlag(
1275  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1276 
1277  if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1278  return emitError(op->getLoc(),
1279  "to, from, tofrom and alloc map types are permitted");
1280 
1281  if (isa<TargetEnterDataOp>(op) && (from || del))
1282  return emitError(op->getLoc(), "to and alloc map types are permitted");
1283 
1284  if (isa<TargetExitDataOp>(op) && to)
1285  return emitError(op->getLoc(),
1286  "from, release and delete map types are permitted");
1287 
1288  if (isa<TargetUpdateOp>(op)) {
1289  if (del) {
1290  return emitError(op->getLoc(),
1291  "at least one of to or from map types must be "
1292  "specified, other map types are not permitted");
1293  }
1294 
1295  if (!to && !from) {
1296  return emitError(op->getLoc(),
1297  "at least one of to or from map types must be "
1298  "specified, other map types are not permitted");
1299  }
1300 
1301  auto updateVar = mapInfoOp.getVarPtr();
1302 
1303  if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1304  (from && updateToVars.contains(updateVar))) {
1305  return emitError(
1306  op->getLoc(),
1307  "either to or from map types can be specified, not both");
1308  }
1309 
1310  if (always || close || implicit) {
1311  return emitError(
1312  op->getLoc(),
1313  "present, mapper and iterator map type modifiers are permitted");
1314  }
1315 
1316  to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1317  }
1318  } else {
1319  emitError(op->getLoc(), "map argument is not a map entry operation");
1320  }
1321  }
1322 
1323  return success();
1324 }
1325 
1326 //===----------------------------------------------------------------------===//
1327 // TargetDataOp
1328 //===----------------------------------------------------------------------===//
1329 
1330 void TargetDataOp::build(OpBuilder &builder, OperationState &state,
1331  const TargetDataClauseOps &clauses) {
1332  TargetDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
1333  clauses.useDevicePtrVars, clauses.useDeviceAddrVars,
1334  clauses.mapVars);
1335 }
1336 
1338  if (getMapOperands().empty() && getUseDevicePtr().empty() &&
1339  getUseDeviceAddr().empty()) {
1340  return ::emitError(this->getLoc(), "At least one of map, useDevicePtr, or "
1341  "useDeviceAddr operand must be present");
1342  }
1343  return verifyMapClause(*this, getMapOperands());
1344 }
1345 
1346 //===----------------------------------------------------------------------===//
1347 // TargetEnterDataOp
1348 //===----------------------------------------------------------------------===//
1349 
1350 void TargetEnterDataOp::build(
1351  OpBuilder &builder, OperationState &state,
1352  const TargetEnterExitUpdateDataClauseOps &clauses) {
1353  MLIRContext *ctx = builder.getContext();
1354  TargetEnterDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
1355  makeArrayAttr(ctx, clauses.dependTypeAttrs),
1356  clauses.dependVars, clauses.nowaitAttr,
1357  clauses.mapVars);
1358 }
1359 
1361  LogicalResult verifyDependVars =
1362  verifyDependVarList(*this, getDepends(), getDependVars());
1363  return failed(verifyDependVars) ? verifyDependVars
1364  : verifyMapClause(*this, getMapOperands());
1365 }
1366 
1367 //===----------------------------------------------------------------------===//
1368 // TargetExitDataOp
1369 //===----------------------------------------------------------------------===//
1370 
1371 void TargetExitDataOp::build(
1372  OpBuilder &builder, OperationState &state,
1373  const TargetEnterExitUpdateDataClauseOps &clauses) {
1374  MLIRContext *ctx = builder.getContext();
1375  TargetExitDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
1376  makeArrayAttr(ctx, clauses.dependTypeAttrs),
1377  clauses.dependVars, clauses.nowaitAttr,
1378  clauses.mapVars);
1379 }
1380 
1382  LogicalResult verifyDependVars =
1383  verifyDependVarList(*this, getDepends(), getDependVars());
1384  return failed(verifyDependVars) ? verifyDependVars
1385  : verifyMapClause(*this, getMapOperands());
1386 }
1387 
1388 //===----------------------------------------------------------------------===//
1389 // TargetUpdateOp
1390 //===----------------------------------------------------------------------===//
1391 
1392 void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
1393  const TargetEnterExitUpdateDataClauseOps &clauses) {
1394  MLIRContext *ctx = builder.getContext();
1395  TargetUpdateOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
1396  makeArrayAttr(ctx, clauses.dependTypeAttrs),
1397  clauses.dependVars, clauses.nowaitAttr,
1398  clauses.mapVars);
1399 }
1400 
1402  LogicalResult verifyDependVars =
1403  verifyDependVarList(*this, getDepends(), getDependVars());
1404  return failed(verifyDependVars) ? verifyDependVars
1405  : verifyMapClause(*this, getMapOperands());
1406 }
1407 
1408 //===----------------------------------------------------------------------===//
1409 // TargetOp
1410 //===----------------------------------------------------------------------===//
1411 
1412 void TargetOp::build(OpBuilder &builder, OperationState &state,
1413  const TargetClauseOps &clauses) {
1414  MLIRContext *ctx = builder.getContext();
1415  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
1416  // inReductionDeclSymbols, reductionVars, reductionByRefAttr,
1417  // reductionDeclSymbols.
1418  TargetOp::build(
1419  builder, state, clauses.ifVar, clauses.deviceVar, clauses.threadLimitVar,
1420  makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
1421  clauses.nowaitAttr, clauses.isDevicePtrVars, clauses.hasDeviceAddrVars,
1422  clauses.mapVars, clauses.privateVars,
1423  makeArrayAttr(ctx, clauses.privatizers));
1424 }
1425 
1427  LogicalResult verifyDependVars =
1428  verifyDependVarList(*this, getDepends(), getDependVars());
1429  return failed(verifyDependVars) ? verifyDependVars
1430  : verifyMapClause(*this, getMapOperands());
1431 }
1432 
1433 //===----------------------------------------------------------------------===//
1434 // ParallelOp
1435 //===----------------------------------------------------------------------===//
1436 
1437 void ParallelOp::build(OpBuilder &builder, OperationState &state,
1438  ArrayRef<NamedAttribute> attributes) {
1439  ParallelOp::build(
1440  builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
1441  /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
1442  /*reduction_vars=*/ValueRange(), /*reduction_vars_byref=*/nullptr,
1443  /*reductions=*/nullptr,
1444  /*proc_bind_val=*/nullptr, /*private_vars=*/ValueRange(),
1445  /*privatizers=*/nullptr);
1446  state.addAttributes(attributes);
1447 }
1448 
1449 void ParallelOp::build(OpBuilder &builder, OperationState &state,
1450  const ParallelClauseOps &clauses) {
1451  MLIRContext *ctx = builder.getContext();
1452 
1453  ParallelOp::build(builder, state, clauses.ifVar, clauses.numThreadsVar,
1454  clauses.allocateVars, clauses.allocatorVars,
1455  clauses.reductionVars,
1456  DenseBoolArrayAttr::get(ctx, clauses.reduceVarByRef),
1457  makeArrayAttr(ctx, clauses.reductionDeclSymbols),
1458  clauses.procBindKindAttr, clauses.privateVars,
1459  makeArrayAttr(ctx, clauses.privatizers));
1460 }
1461 
1462 template <typename OpType>
1464  auto privateVars = op.getPrivateVars();
1465  auto privatizers = op.getPrivatizersAttr();
1466 
1467  if (privateVars.empty() && (privatizers == nullptr || privatizers.empty()))
1468  return success();
1469 
1470  auto numPrivateVars = privateVars.size();
1471  auto numPrivatizers = (privatizers == nullptr) ? 0 : privatizers.size();
1472 
1473  if (numPrivateVars != numPrivatizers)
1474  return op.emitError() << "inconsistent number of private variables and "
1475  "privatizer op symbols, private vars: "
1476  << numPrivateVars
1477  << " vs. privatizer op symbols: " << numPrivatizers;
1478 
1479  for (auto privateVarInfo : llvm::zip_equal(privateVars, privatizers)) {
1480  Type varType = std::get<0>(privateVarInfo).getType();
1481  SymbolRefAttr privatizerSym =
1482  cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
1483  PrivateClauseOp privatizerOp =
1484  SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
1485  privatizerSym);
1486 
1487  if (privatizerOp == nullptr)
1488  return op.emitError() << "failed to lookup privatizer op with symbol: '"
1489  << privatizerSym << "'";
1490 
1491  Type privatizerType = privatizerOp.getType();
1492 
1493  if (varType != privatizerType)
1494  return op.emitError()
1495  << "type mismatch between a "
1496  << (privatizerOp.getDataSharingType() ==
1497  DataSharingClauseType::Private
1498  ? "private"
1499  : "firstprivate")
1500  << " variable and its privatizer op, var type: " << varType
1501  << " vs. privatizer op type: " << privatizerType;
1502  }
1503 
1504  return success();
1505 }
1506 
1508  // Check that it is a valid loop wrapper if it's taking that role.
1509  if (isa<DistributeOp>((*this)->getParentOp())) {
1510  if (!isWrapper())
1511  return emitOpError() << "must take a loop wrapper role if nested inside "
1512  "of 'omp.distribute'";
1513 
1514  if (LoopWrapperInterface nested = getNestedWrapper()) {
1515  // Check for the allowed leaf constructs that may appear in a composite
1516  // construct directly after PARALLEL.
1517  if (!isa<WsloopOp>(nested))
1518  return emitError() << "only supported nested wrapper is 'omp.wsloop'";
1519  } else {
1520  return emitOpError() << "must not wrap an 'omp.loop_nest' directly";
1521  }
1522  }
1523 
1524  if (getAllocateVars().size() != getAllocatorsVars().size())
1525  return emitError(
1526  "expected equal sizes for allocate and allocator variables");
1527 
1528  if (failed(verifyPrivateVarList(*this)))
1529  return failure();
1530 
1531  return verifyReductionVarList(*this, getReductions(), getReductionVars(),
1532  getReductionVarsByref());
1533 }
1534 
1535 //===----------------------------------------------------------------------===//
1536 // TeamsOp
1537 //===----------------------------------------------------------------------===//
1538 
1540  while ((op = op->getParentOp()))
1541  if (isa<OpenMPDialect>(op->getDialect()))
1542  return false;
1543  return true;
1544 }
1545 
1546 void TeamsOp::build(OpBuilder &builder, OperationState &state,
1547  const TeamsClauseOps &clauses) {
1548  MLIRContext *ctx = builder.getContext();
1549  // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
1550  TeamsOp::build(builder, state, clauses.numTeamsLowerVar,
1551  clauses.numTeamsUpperVar, clauses.ifVar,
1552  clauses.threadLimitVar, clauses.allocateVars,
1553  clauses.allocatorVars, clauses.reductionVars,
1554  makeArrayAttr(ctx, clauses.reductionDeclSymbols));
1555 }
1556 
1558  // Check parent region
1559  // TODO If nested inside of a target region, also check that it does not
1560  // contain any statements, declarations or directives other than this
1561  // omp.teams construct. The issue is how to support the initialization of
1562  // this operation's own arguments (allow SSA values across omp.target?).
1563  Operation *op = getOperation();
1564  if (!isa<TargetOp>(op->getParentOp()) &&
1566  return emitError("expected to be nested inside of omp.target or not nested "
1567  "in any OpenMP dialect operations");
1568 
1569  // Check for num_teams clause restrictions
1570  if (auto numTeamsLowerBound = getNumTeamsLower()) {
1571  auto numTeamsUpperBound = getNumTeamsUpper();
1572  if (!numTeamsUpperBound)
1573  return emitError("expected num_teams upper bound to be defined if the "
1574  "lower bound is defined");
1575  if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
1576  return emitError(
1577  "expected num_teams upper bound and lower bound to be the same type");
1578  }
1579 
1580  // Check for allocate clause restrictions
1581  if (getAllocateVars().size() != getAllocatorsVars().size())
1582  return emitError(
1583  "expected equal sizes for allocate and allocator variables");
1584 
1585  return verifyReductionVarList(*this, getReductions(), getReductionVars());
1586 }
1587 
1588 //===----------------------------------------------------------------------===//
1589 // SectionsOp
1590 //===----------------------------------------------------------------------===//
1591 
1592 void SectionsOp::build(OpBuilder &builder, OperationState &state,
1593  const SectionsClauseOps &clauses) {
1594  MLIRContext *ctx = builder.getContext();
1595  // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
1596  SectionsOp::build(builder, state, clauses.reductionVars,
1597  makeArrayAttr(ctx, clauses.reductionDeclSymbols),
1598  clauses.allocateVars, clauses.allocatorVars,
1599  clauses.nowaitAttr);
1600 }
1601 
1603  if (getAllocateVars().size() != getAllocatorsVars().size())
1604  return emitError(
1605  "expected equal sizes for allocate and allocator variables");
1606 
1607  return verifyReductionVarList(*this, getReductions(), getReductionVars());
1608 }
1609 
1610 LogicalResult SectionsOp::verifyRegions() {
1611  for (auto &inst : *getRegion().begin()) {
1612  if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
1613  return emitOpError()
1614  << "expected omp.section op or terminator op inside region";
1615  }
1616  }
1617 
1618  return success();
1619 }
1620 
1621 //===----------------------------------------------------------------------===//
1622 // SingleOp
1623 //===----------------------------------------------------------------------===//
1624 
1625 void SingleOp::build(OpBuilder &builder, OperationState &state,
1626  const SingleClauseOps &clauses) {
1627  MLIRContext *ctx = builder.getContext();
1628  // TODO Store clauses in op: privateVars, privatizers.
1629  SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1630  clauses.copyprivateVars,
1631  makeArrayAttr(ctx, clauses.copyprivateFuncs),
1632  clauses.nowaitAttr);
1633 }
1634 
1636  // Check for allocate clause restrictions
1637  if (getAllocateVars().size() != getAllocatorsVars().size())
1638  return emitError(
1639  "expected equal sizes for allocate and allocator variables");
1640 
1641  return verifyCopyPrivateVarList(*this, getCopyprivateVars(),
1642  getCopyprivateFuncs());
1643 }
1644 
1645 //===----------------------------------------------------------------------===//
1646 // WsloopOp
1647 //===----------------------------------------------------------------------===//
1648 
1650 parseWsloop(OpAsmParser &parser, Region &region,
1652  SmallVectorImpl<Type> &reductionTypes,
1653  DenseBoolArrayAttr &reductionByRef, ArrayAttr &reductionSymbols) {
1654  // Parse an optional reduction clause
1656  if (succeeded(parser.parseOptionalKeyword("reduction"))) {
1657  if (failed(parseClauseWithRegionArgs(parser, region, reductionOperands,
1658  reductionTypes, reductionByRef,
1659  reductionSymbols, privates)))
1660  return failure();
1661  }
1662  return parser.parseRegion(region, privates);
1663 }
1664 
1666  ValueRange reductionOperands, TypeRange reductionTypes,
1667  DenseBoolArrayAttr isByRef, ArrayAttr reductionSymbols) {
1668  if (reductionSymbols) {
1669  auto reductionArgs = region.front().getArguments();
1670  printClauseWithRegionArgs(p, op, reductionArgs, "reduction",
1671  reductionOperands, reductionTypes, isByRef,
1672  reductionSymbols);
1673  }
1674  p.printRegion(region, /*printEntryBlockArgs=*/false);
1675 }
1676 
1677 void WsloopOp::build(OpBuilder &builder, OperationState &state,
1678  ArrayRef<NamedAttribute> attributes) {
1679  build(builder, state, /*linear_vars=*/ValueRange(),
1680  /*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(),
1681  /*reduction_vars_byref=*/nullptr,
1682  /*reductions=*/nullptr, /*schedule_val=*/nullptr,
1683  /*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr,
1684  /*simd_modifier=*/false, /*nowait=*/false,
1685  /*ordered_val=*/nullptr, /*order_val=*/nullptr);
1686  state.addAttributes(attributes);
1687 }
1688 
1689 void WsloopOp::build(OpBuilder &builder, OperationState &state,
1690  const WsloopClauseOps &clauses) {
1691  MLIRContext *ctx = builder.getContext();
1692  // TODO: Store clauses in op: allocateVars, allocatorVars, privateVars,
1693  // privatizers.
1694  WsloopOp::build(builder, state, clauses.linearVars, clauses.linearStepVars,
1695  clauses.reductionVars,
1696  DenseBoolArrayAttr::get(ctx, clauses.reduceVarByRef),
1697  makeArrayAttr(ctx, clauses.reductionDeclSymbols),
1698  clauses.scheduleValAttr, clauses.scheduleChunkVar,
1699  clauses.scheduleModAttr, clauses.scheduleSimdAttr,
1700  clauses.nowaitAttr, clauses.orderedAttr, clauses.orderAttr);
1701 }
1702 
1704  if (!isWrapper())
1705  return emitOpError() << "must be a loop wrapper";
1706 
1707  if (LoopWrapperInterface nested = getNestedWrapper()) {
1708  // Check for the allowed leaf constructs that may appear in a composite
1709  // construct directly after DO/FOR.
1710  if (!isa<SimdOp>(nested))
1711  return emitError() << "only supported nested wrapper is 'omp.simd'";
1712  }
1713 
1714  return verifyReductionVarList(*this, getReductions(), getReductionVars(),
1715  getReductionVarsByref());
1716 }
1717 
1718 //===----------------------------------------------------------------------===//
1719 // Simd construct [2.9.3.1]
1720 //===----------------------------------------------------------------------===//
1721 
1722 void SimdOp::build(OpBuilder &builder, OperationState &state,
1723  const SimdClauseOps &clauses) {
1724  MLIRContext *ctx = builder.getContext();
1725  // TODO Store clauses in op: privateVars, reductionByRefAttr, reductionVars,
1726  // privatizers, reductionDeclSymbols.
1727  SimdOp::build(builder, state, clauses.alignedVars,
1728  makeArrayAttr(ctx, clauses.alignmentAttrs), clauses.ifVar,
1729  clauses.nontemporalVars, clauses.orderAttr, clauses.simdlenAttr,
1730  clauses.safelenAttr);
1731 }
1732 
1734  if (getSimdlen().has_value() && getSafelen().has_value() &&
1735  getSimdlen().value() > getSafelen().value())
1736  return emitOpError()
1737  << "simdlen clause and safelen clause are both present, but the "
1738  "simdlen value is not less than or equal to safelen value";
1739 
1740  if (verifyAlignedClause(*this, getAlignmentValues(), getAlignedVars())
1741  .failed())
1742  return failure();
1743 
1744  if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
1745  return failure();
1746 
1747  if (!isWrapper())
1748  return emitOpError() << "must be a loop wrapper";
1749 
1750  if (getNestedWrapper())
1751  return emitOpError() << "must wrap an 'omp.loop_nest' directly";
1752 
1753  return success();
1754 }
1755 
1756 //===----------------------------------------------------------------------===//
1757 // Distribute construct [2.9.4.1]
1758 //===----------------------------------------------------------------------===//
1759 
1760 void DistributeOp::build(OpBuilder &builder, OperationState &state,
1761  const DistributeClauseOps &clauses) {
1762  // TODO Store clauses in op: privateVars, privatizers.
1763  DistributeOp::build(builder, state, clauses.distScheduleStaticAttr,
1764  clauses.distScheduleChunkSizeVar, clauses.allocateVars,
1765  clauses.allocatorVars, clauses.orderAttr);
1766 }
1767 
1769  if (this->getChunkSize() && !this->getDistScheduleStatic())
1770  return emitOpError() << "chunk size set without "
1771  "dist_schedule_static being present";
1772 
1773  if (getAllocateVars().size() != getAllocatorsVars().size())
1774  return emitError(
1775  "expected equal sizes for allocate and allocator variables");
1776 
1777  if (!isWrapper())
1778  return emitOpError() << "must be a loop wrapper";
1779 
1780  if (LoopWrapperInterface nested = getNestedWrapper()) {
1781  // Check for the allowed leaf constructs that may appear in a composite
1782  // construct directly after DISTRIBUTE.
1783  if (!isa<ParallelOp, SimdOp>(nested))
1784  return emitError() << "only supported nested wrappers are 'omp.parallel' "
1785  "and 'omp.simd'";
1786  }
1787 
1788  return success();
1789 }
1790 
1791 //===----------------------------------------------------------------------===//
1792 // DeclareReductionOp
1793 //===----------------------------------------------------------------------===//
1794 
1796  Region &region) {
1797  if (parser.parseOptionalKeyword("atomic"))
1798  return success();
1799  return parser.parseRegion(region);
1800 }
1801 
1803  DeclareReductionOp op, Region &region) {
1804  if (region.empty())
1805  return;
1806  printer << "atomic ";
1807  printer.printRegion(region);
1808 }
1809 
1811  Region &region) {
1812  if (parser.parseOptionalKeyword("cleanup"))
1813  return success();
1814  return parser.parseRegion(region);
1815 }
1816 
1818  DeclareReductionOp op, Region &region) {
1819  if (region.empty())
1820  return;
1821  printer << "cleanup ";
1822  printer.printRegion(region);
1823 }
1824 
1825 LogicalResult DeclareReductionOp::verifyRegions() {
1826  if (getInitializerRegion().empty())
1827  return emitOpError() << "expects non-empty initializer region";
1828  Block &initializerEntryBlock = getInitializerRegion().front();
1829  if (initializerEntryBlock.getNumArguments() != 1 ||
1830  initializerEntryBlock.getArgument(0).getType() != getType()) {
1831  return emitOpError() << "expects initializer region with one argument "
1832  "of the reduction type";
1833  }
1834 
1835  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
1836  if (yieldOp.getResults().size() != 1 ||
1837  yieldOp.getResults().getTypes()[0] != getType())
1838  return emitOpError() << "expects initializer region to yield a value "
1839  "of the reduction type";
1840  }
1841 
1842  if (getReductionRegion().empty())
1843  return emitOpError() << "expects non-empty reduction region";
1844  Block &reductionEntryBlock = getReductionRegion().front();
1845  if (reductionEntryBlock.getNumArguments() != 2 ||
1846  reductionEntryBlock.getArgumentTypes()[0] !=
1847  reductionEntryBlock.getArgumentTypes()[1] ||
1848  reductionEntryBlock.getArgumentTypes()[0] != getType())
1849  return emitOpError() << "expects reduction region with two arguments of "
1850  "the reduction type";
1851  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
1852  if (yieldOp.getResults().size() != 1 ||
1853  yieldOp.getResults().getTypes()[0] != getType())
1854  return emitOpError() << "expects reduction region to yield a value "
1855  "of the reduction type";
1856  }
1857 
1858  if (!getAtomicReductionRegion().empty()) {
1859  Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
1860  if (atomicReductionEntryBlock.getNumArguments() != 2 ||
1861  atomicReductionEntryBlock.getArgumentTypes()[0] !=
1862  atomicReductionEntryBlock.getArgumentTypes()[1])
1863  return emitOpError() << "expects atomic reduction region with two "
1864  "arguments of the same type";
1865  auto ptrType = llvm::dyn_cast<PointerLikeType>(
1866  atomicReductionEntryBlock.getArgumentTypes()[0]);
1867  if (!ptrType ||
1868  (ptrType.getElementType() && ptrType.getElementType() != getType()))
1869  return emitOpError() << "expects atomic reduction region arguments to "
1870  "be accumulators containing the reduction type";
1871  }
1872 
1873  if (getCleanupRegion().empty())
1874  return success();
1875  Block &cleanupEntryBlock = getCleanupRegion().front();
1876  if (cleanupEntryBlock.getNumArguments() != 1 ||
1877  cleanupEntryBlock.getArgument(0).getType() != getType())
1878  return emitOpError() << "expects cleanup region with one argument "
1879  "of the reduction type";
1880 
1881  return success();
1882 }
1883 
1884 //===----------------------------------------------------------------------===//
1885 // TaskOp
1886 //===----------------------------------------------------------------------===//
1887 
1888 void TaskOp::build(OpBuilder &builder, OperationState &state,
1889  const TaskClauseOps &clauses) {
1890  MLIRContext *ctx = builder.getContext();
1891  // TODO Store clauses in op: privateVars, privatizers.
1892  TaskOp::build(
1893  builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
1894  clauses.mergeableAttr, clauses.inReductionVars,
1895  makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.priorityVar,
1896  makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
1897  clauses.allocateVars, clauses.allocatorVars);
1898 }
1899 
1901  LogicalResult verifyDependVars =
1902  verifyDependVarList(*this, getDepends(), getDependVars());
1903  return failed(verifyDependVars)
1904  ? verifyDependVars
1905  : verifyReductionVarList(*this, getInReductions(),
1906  getInReductionVars());
1907 }
1908 
1909 //===----------------------------------------------------------------------===//
1910 // TaskgroupOp
1911 //===----------------------------------------------------------------------===//
1912 
1913 void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
1914  const TaskgroupClauseOps &clauses) {
1915  MLIRContext *ctx = builder.getContext();
1916  TaskgroupOp::build(builder, state, clauses.taskReductionVars,
1917  makeArrayAttr(ctx, clauses.taskReductionDeclSymbols),
1918  clauses.allocateVars, clauses.allocatorVars);
1919 }
1920 
1922  return verifyReductionVarList(*this, getTaskReductions(),
1923  getTaskReductionVars());
1924 }
1925 
1926 //===----------------------------------------------------------------------===//
1927 // TaskloopOp
1928 //===----------------------------------------------------------------------===//
1929 
1930 void TaskloopOp::build(OpBuilder &builder, OperationState &state,
1931  const TaskloopClauseOps &clauses) {
1932  MLIRContext *ctx = builder.getContext();
1933  // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
1934  TaskloopOp::build(
1935  builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
1936  clauses.mergeableAttr, clauses.inReductionVars,
1937  makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars,
1938  makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar,
1939  clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar,
1940  clauses.numTasksVar, clauses.nogroupAttr);
1941 }
1942 
1943 SmallVector<Value> TaskloopOp::getAllReductionVars() {
1944  SmallVector<Value> allReductionNvars(getInReductionVars().begin(),
1945  getInReductionVars().end());
1946  allReductionNvars.insert(allReductionNvars.end(), getReductionVars().begin(),
1947  getReductionVars().end());
1948  return allReductionNvars;
1949 }
1950 
1952  if (getAllocateVars().size() != getAllocatorsVars().size())
1953  return emitError(
1954  "expected equal sizes for allocate and allocator variables");
1955  if (failed(
1956  verifyReductionVarList(*this, getReductions(), getReductionVars())) ||
1957  failed(verifyReductionVarList(*this, getInReductions(),
1958  getInReductionVars())))
1959  return failure();
1960 
1961  if (!getReductionVars().empty() && getNogroup())
1962  return emitError("if a reduction clause is present on the taskloop "
1963  "directive, the nogroup clause must not be specified");
1964  for (auto var : getReductionVars()) {
1965  if (llvm::is_contained(getInReductionVars(), var))
1966  return emitError("the same list item cannot appear in both a reduction "
1967  "and an in_reduction clause");
1968  }
1969 
1970  if (getGrainSize() && getNumTasks()) {
1971  return emitError(
1972  "the grainsize clause and num_tasks clause are mutually exclusive and "
1973  "may not appear on the same taskloop directive");
1974  }
1975 
1976  if (!isWrapper())
1977  return emitOpError() << "must be a loop wrapper";
1978 
1979  if (LoopWrapperInterface nested = getNestedWrapper()) {
1980  // Check for the allowed leaf constructs that may appear in a composite
1981  // construct directly after TASKLOOP.
1982  if (!isa<SimdOp>(nested))
1983  return emitError() << "only supported nested wrapper is 'omp.simd'";
1984  }
1985  return success();
1986 }
1987 
1988 //===----------------------------------------------------------------------===//
1989 // LoopNestOp
1990 //===----------------------------------------------------------------------===//
1991 
1993  // Parse an opening `(` followed by induction variables followed by `)`
1996  Type loopVarType;
1998  parser.parseColonType(loopVarType) ||
1999  // Parse loop bounds.
2000  parser.parseEqual() ||
2001  parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
2002  parser.parseKeyword("to") ||
2003  parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
2004  return failure();
2005 
2006  for (auto &iv : ivs)
2007  iv.type = loopVarType;
2008 
2009  // Parse "inclusive" flag.
2010  if (succeeded(parser.parseOptionalKeyword("inclusive")))
2011  result.addAttribute("inclusive",
2012  UnitAttr::get(parser.getBuilder().getContext()));
2013 
2014  // Parse step values.
2016  if (parser.parseKeyword("step") ||
2017  parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
2018  return failure();
2019 
2020  // Parse the body.
2021  Region *region = result.addRegion();
2022  if (parser.parseRegion(*region, ivs))
2023  return failure();
2024 
2025  // Resolve operands.
2026  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
2027  parser.resolveOperands(ubs, loopVarType, result.operands) ||
2028  parser.resolveOperands(steps, loopVarType, result.operands))
2029  return failure();
2030 
2031  // Parse the optional attribute list.
2032  return parser.parseOptionalAttrDict(result.attributes);
2033 }
2034 
2036  Region &region = getRegion();
2037  auto args = region.getArguments();
2038  p << " (" << args << ") : " << args[0].getType() << " = (" << getLowerBound()
2039  << ") to (" << getUpperBound() << ") ";
2040  if (getInclusive())
2041  p << "inclusive ";
2042  p << "step (" << getStep() << ") ";
2043  p.printRegion(region, /*printEntryBlockArgs=*/false);
2044 }
2045 
2046 void LoopNestOp::build(OpBuilder &builder, OperationState &state,
2047  const LoopNestClauseOps &clauses) {
2048  LoopNestOp::build(builder, state, clauses.loopLBVar, clauses.loopUBVar,
2049  clauses.loopStepVar, clauses.loopInclusiveAttr);
2050 }
2051 
2053  if (getLowerBound().empty())
2054  return emitOpError() << "must represent at least one loop";
2055 
2056  if (getLowerBound().size() != getIVs().size())
2057  return emitOpError() << "number of range arguments and IVs do not match";
2058 
2059  for (auto [lb, iv] : llvm::zip_equal(getLowerBound(), getIVs())) {
2060  if (lb.getType() != iv.getType())
2061  return emitOpError()
2062  << "range argument type does not match corresponding IV type";
2063  }
2064 
2065  auto wrapper =
2066  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2067 
2068  if (!wrapper || !wrapper.isWrapper())
2069  return emitOpError() << "expects parent op to be a valid loop wrapper";
2070 
2071  return success();
2072 }
2073 
2074 void LoopNestOp::gatherWrappers(
2076  Operation *parent = (*this)->getParentOp();
2077  while (auto wrapper =
2078  llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
2079  if (!wrapper.isWrapper())
2080  break;
2081  wrappers.push_back(wrapper);
2082  parent = parent->getParentOp();
2083  }
2084 }
2085 
2086 //===----------------------------------------------------------------------===//
2087 // Critical construct (2.17.1)
2088 //===----------------------------------------------------------------------===//
2089 
2090 void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
2091  const CriticalClauseOps &clauses) {
2092  CriticalDeclareOp::build(builder, state, clauses.nameAttr, clauses.hintAttr);
2093 }
2094 
2096  return verifySynchronizationHint(*this, getHintVal());
2097 }
2098 
2099 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2100  if (getNameAttr()) {
2101  SymbolRefAttr symbolRef = getNameAttr();
2102  auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
2103  *this, symbolRef);
2104  if (!decl) {
2105  return emitOpError() << "expected symbol reference " << symbolRef
2106  << " to point to a critical declaration";
2107  }
2108  }
2109 
2110  return success();
2111 }
2112 
2113 //===----------------------------------------------------------------------===//
2114 // Ordered construct
2115 //===----------------------------------------------------------------------===//
2116 
2118  bool hasRegion = op.getNumRegions() > 0;
2119  auto loopOp = op.getParentOfType<LoopNestOp>();
2120  if (!loopOp) {
2121  if (hasRegion)
2122  return success();
2123 
2124  // TODO: Consider if this needs to be the case only for the standalone
2125  // variant of the ordered construct.
2126  return op.emitOpError() << "must be nested inside of a loop";
2127  }
2128 
2129  Operation *wrapper = loopOp->getParentOp();
2130  if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
2131  IntegerAttr orderedAttr = wsloopOp.getOrderedValAttr();
2132  if (!orderedAttr)
2133  return op.emitOpError() << "the enclosing worksharing-loop region must "
2134  "have an ordered clause";
2135 
2136  if (hasRegion && orderedAttr.getInt() != 0)
2137  return op.emitOpError() << "the enclosing loop's ordered clause must not "
2138  "have a parameter present";
2139 
2140  if (!hasRegion && orderedAttr.getInt() == 0)
2141  return op.emitOpError() << "the enclosing loop's ordered clause must "
2142  "have a parameter present";
2143  } else if (!isa<SimdOp>(wrapper)) {
2144  return op.emitOpError() << "must be nested inside of a worksharing, simd "
2145  "or worksharing simd loop";
2146  }
2147  return success();
2148 }
2149 
2150 void OrderedOp::build(OpBuilder &builder, OperationState &state,
2151  const OrderedOpClauseOps &clauses) {
2152  OrderedOp::build(builder, state, clauses.doacrossDependTypeAttr,
2153  clauses.doacrossNumLoopsAttr, clauses.doacrossVectorVars);
2154 }
2155 
2157  if (failed(verifyOrderedParent(**this)))
2158  return failure();
2159 
2160  auto wrapper = (*this)->getParentOfType<WsloopOp>();
2161  if (!wrapper || *wrapper.getOrderedVal() != *getNumLoopsVal())
2162  return emitOpError() << "number of variables in depend clause does not "
2163  << "match number of iteration variables in the "
2164  << "doacross loop";
2165 
2166  return success();
2167 }
2168 
2169 void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
2170  const OrderedRegionClauseOps &clauses) {
2171  OrderedRegionOp::build(builder, state, clauses.parLevelSimdAttr);
2172 }
2173 
2175  // TODO: The code generation for ordered simd directive is not supported yet.
2176  if (getSimd())
2177  return failure();
2178 
2179  return verifyOrderedParent(**this);
2180 }
2181 
2182 //===----------------------------------------------------------------------===//
2183 // TaskwaitOp
2184 //===----------------------------------------------------------------------===//
2185 
2186 void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
2187  const TaskwaitClauseOps &clauses) {
2188  // TODO Store clauses in op: dependTypeAttrs, dependVars, nowaitAttr.
2189  TaskwaitOp::build(builder, state);
2190 }
2191 
2192 //===----------------------------------------------------------------------===//
2193 // Verifier for AtomicReadOp
2194 //===----------------------------------------------------------------------===//
2195 
2197  if (verifyCommon().failed())
2198  return mlir::failure();
2199 
2200  if (auto mo = getMemoryOrderVal()) {
2201  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2202  *mo == ClauseMemoryOrderKind::Release) {
2203  return emitError(
2204  "memory-order must not be acq_rel or release for atomic reads");
2205  }
2206  }
2207  return verifySynchronizationHint(*this, getHintVal());
2208 }
2209 
2210 //===----------------------------------------------------------------------===//
2211 // Verifier for AtomicWriteOp
2212 //===----------------------------------------------------------------------===//
2213 
2215  if (verifyCommon().failed())
2216  return mlir::failure();
2217 
2218  if (auto mo = getMemoryOrderVal()) {
2219  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2220  *mo == ClauseMemoryOrderKind::Acquire) {
2221  return emitError(
2222  "memory-order must not be acq_rel or acquire for atomic writes");
2223  }
2224  }
2225  return verifySynchronizationHint(*this, getHintVal());
2226 }
2227 
2228 //===----------------------------------------------------------------------===//
2229 // Verifier for AtomicUpdateOp
2230 //===----------------------------------------------------------------------===//
2231 
2232 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2233  PatternRewriter &rewriter) {
2234  if (op.isNoOp()) {
2235  rewriter.eraseOp(op);
2236  return success();
2237  }
2238  if (Value writeVal = op.getWriteOpVal()) {
2239  rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
2240  op.getHintValAttr(),
2241  op.getMemoryOrderValAttr());
2242  return success();
2243  }
2244  return failure();
2245 }
2246 
2248  if (verifyCommon().failed())
2249  return mlir::failure();
2250 
2251  if (auto mo = getMemoryOrderVal()) {
2252  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2253  *mo == ClauseMemoryOrderKind::Acquire) {
2254  return emitError(
2255  "memory-order must not be acq_rel or acquire for atomic updates");
2256  }
2257  }
2258 
2259  return verifySynchronizationHint(*this, getHintVal());
2260 }
2261 
2262 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
2263 
2264 //===----------------------------------------------------------------------===//
2265 // Verifier for AtomicCaptureOp
2266 //===----------------------------------------------------------------------===//
2267 
2268 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2269  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2270  return op;
2271  return dyn_cast<AtomicReadOp>(getSecondOp());
2272 }
2273 
2274 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2275  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2276  return op;
2277  return dyn_cast<AtomicWriteOp>(getSecondOp());
2278 }
2279 
2280 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2281  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2282  return op;
2283  return dyn_cast<AtomicUpdateOp>(getSecondOp());
2284 }
2285 
2287  return verifySynchronizationHint(*this, getHintVal());
2288 }
2289 
2290 LogicalResult AtomicCaptureOp::verifyRegions() {
2291  if (verifyRegionsCommon().failed())
2292  return mlir::failure();
2293 
2294  if (getFirstOp()->getAttr("hint_val") || getSecondOp()->getAttr("hint_val"))
2295  return emitOpError(
2296  "operations inside capture region must not have hint clause");
2297 
2298  if (getFirstOp()->getAttr("memory_order_val") ||
2299  getSecondOp()->getAttr("memory_order_val"))
2300  return emitOpError(
2301  "operations inside capture region must not have memory_order clause");
2302  return success();
2303 }
2304 
2305 //===----------------------------------------------------------------------===//
2306 // Verifier for CancelOp
2307 //===----------------------------------------------------------------------===//
2308 
2310  ClauseCancellationConstructType cct = getCancellationConstructTypeVal();
2311  Operation *parentOp = (*this)->getParentOp();
2312 
2313  if (!parentOp) {
2314  return emitOpError() << "must be used within a region supporting "
2315  "cancel directive";
2316  }
2317 
2318  if ((cct == ClauseCancellationConstructType::Parallel) &&
2319  !isa<ParallelOp>(parentOp)) {
2320  return emitOpError() << "cancel parallel must appear "
2321  << "inside a parallel region";
2322  }
2323  if (cct == ClauseCancellationConstructType::Loop) {
2324  auto loopOp = dyn_cast<LoopNestOp>(parentOp);
2325  auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
2326  loopOp ? loopOp->getParentOp() : nullptr);
2327 
2328  if (!wsloopOp) {
2329  return emitOpError()
2330  << "cancel loop must appear inside a worksharing-loop region";
2331  }
2332  if (wsloopOp.getNowaitAttr()) {
2333  return emitError() << "A worksharing construct that is canceled "
2334  << "must not have a nowait clause";
2335  }
2336  if (wsloopOp.getOrderedValAttr()) {
2337  return emitError() << "A worksharing construct that is canceled "
2338  << "must not have an ordered clause";
2339  }
2340 
2341  } else if (cct == ClauseCancellationConstructType::Sections) {
2342  if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
2343  return emitOpError() << "cancel sections must appear "
2344  << "inside a sections region";
2345  }
2346  if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
2347  cast<SectionsOp>(parentOp->getParentOp()).getNowaitAttr()) {
2348  return emitError() << "A sections construct that is canceled "
2349  << "must not have a nowait clause";
2350  }
2351  }
2352  // TODO : Add more when we support taskgroup.
2353  return success();
2354 }
2355 //===----------------------------------------------------------------------===//
2356 // Verifier for CancellationPointOp
2357 //===----------------------------------------------------------------------===//
2358 
2360  ClauseCancellationConstructType cct = getCancellationConstructTypeVal();
2361  Operation *parentOp = (*this)->getParentOp();
2362 
2363  if (!parentOp) {
2364  return emitOpError() << "must be used within a region supporting "
2365  "cancellation point directive";
2366  }
2367 
2368  if ((cct == ClauseCancellationConstructType::Parallel) &&
2369  !(isa<ParallelOp>(parentOp))) {
2370  return emitOpError() << "cancellation point parallel must appear "
2371  << "inside a parallel region";
2372  }
2373  if ((cct == ClauseCancellationConstructType::Loop) &&
2374  (!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->getParentOp()))) {
2375  return emitOpError() << "cancellation point loop must appear "
2376  << "inside a worksharing-loop region";
2377  }
2378  if ((cct == ClauseCancellationConstructType::Sections) &&
2379  !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
2380  return emitOpError() << "cancellation point sections must appear "
2381  << "inside a sections region";
2382  }
2383  // TODO : Add more when we support taskgroup.
2384  return success();
2385 }
2386 
2387 //===----------------------------------------------------------------------===//
2388 // MapBoundsOp
2389 //===----------------------------------------------------------------------===//
2390 
2392  auto extent = getExtent();
2393  auto upperbound = getUpperBound();
2394  if (!extent && !upperbound)
2395  return emitError("expected extent or upperbound.");
2396  return success();
2397 }
2398 
2399 void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2400  TypeRange /*result_types*/, StringAttr symName,
2401  TypeAttr type) {
2402  PrivateClauseOp::build(
2403  odsBuilder, odsState, symName, type,
2405  DataSharingClauseType::Private));
2406 }
2407 
2409  Type symType = getType();
2410 
2411  auto verifyTerminator = [&](Operation *terminator,
2412  bool yieldsValue) -> LogicalResult {
2413  if (!terminator->getBlock()->getSuccessors().empty())
2414  return success();
2415 
2416  if (!llvm::isa<YieldOp>(terminator))
2417  return mlir::emitError(terminator->getLoc())
2418  << "expected exit block terminator to be an `omp.yield` op.";
2419 
2420  YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
2421  TypeRange yieldedTypes = yieldOp.getResults().getTypes();
2422 
2423  if (!yieldsValue) {
2424  if (yieldedTypes.empty())
2425  return success();
2426 
2427  return mlir::emitError(terminator->getLoc())
2428  << "Did not expect any values to be yielded.";
2429  }
2430 
2431  if (yieldedTypes.size() == 1 && yieldedTypes.front() == symType)
2432  return success();
2433 
2434  auto error = mlir::emitError(yieldOp.getLoc())
2435  << "Invalid yielded value. Expected type: " << symType
2436  << ", got: ";
2437 
2438  if (yieldedTypes.empty())
2439  error << "None";
2440  else
2441  error << yieldedTypes;
2442 
2443  return error;
2444  };
2445 
2446  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
2447  StringRef regionName,
2448  bool yieldsValue) -> LogicalResult {
2449  assert(!region.empty());
2450 
2451  if (region.getNumArguments() != expectedNumArgs)
2452  return mlir::emitError(region.getLoc())
2453  << "`" << regionName << "`: "
2454  << "expected " << expectedNumArgs
2455  << " region arguments, got: " << region.getNumArguments();
2456 
2457  for (Block &block : region) {
2458  // MLIR will verify the absence of the terminator for us.
2459  if (!block.mightHaveTerminator())
2460  continue;
2461 
2462  if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
2463  return failure();
2464  }
2465 
2466  return success();
2467  };
2468 
2469  if (failed(verifyRegion(getAllocRegion(), /*expectedNumArgs=*/1, "alloc",
2470  /*yieldsValue=*/true)))
2471  return failure();
2472 
2473  DataSharingClauseType dsType = getDataSharingType();
2474 
2475  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
2476  return emitError("`private` clauses require only an `alloc` region.");
2477 
2478  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
2479  return emitError(
2480  "`firstprivate` clauses require both `alloc` and `copy` regions.");
2481 
2482  if (dsType == DataSharingClauseType::FirstPrivate &&
2483  failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
2484  /*yieldsValue=*/true)))
2485  return failure();
2486 
2487  if (!getDeallocRegion().empty() &&
2488  failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
2489  /*yieldsValue=*/false)))
2490  return failure();
2491 
2492  return success();
2493 }
2494 
2495 #define GET_ATTRDEF_CLASSES
2496 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
2497 
2498 #define GET_OP_CLASSES
2499 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
2500 
2501 #define GET_TYPEDEF_CLASSES
2502 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:716
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:708
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:63
static MLIRContext * getContext(OpFoldResult val)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static void printPrivateList(OpAsmPrinter &p, Operation *op, ValueRange privateVarOperands, TypeRange privateVarTypes, ArrayAttr privatizerSymbols)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedItems, SmallVectorImpl< Type > &types, ArrayAttr &alignmentValues)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > depends)
Print Depend clause.
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCapture)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &vars, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &stepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange varsAllocate, TypeRange typesAllocate, OperandRange varsAllocator, TypeRange typesAllocator)
Print allocate clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignmentValues, OperandRange alignedVariables)
static ParseResult parseParallelRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVarOperands, SmallVectorImpl< Type > &reductionVarTypes, DenseBoolArrayAttr &reductionByRef, ArrayAttr &reductionSymbols, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVarOperands, llvm::SmallVectorImpl< Type > &privateVarsTypes, ArrayAttr &privatizerSymbols)
static void printReductionVarList(OpAsmPrinter &p, Operation *op, OperandRange reductionVars, TypeRange reductionTypes, std::optional< ArrayAttr > reductions)
Print Reduction clause.
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operandsAllocate, SmallVectorImpl< Type > &typesAllocate, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operandsAllocator, SmallVectorImpl< Type > &typesAllocator)
Parse an allocate clause with allocators and a list of operands with types.
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedVarTypes, std::optional< ArrayAttr > alignmentValues)
Print Aligned Clause.
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductions, OperandRange reductionVars, std::optional< ArrayRef< bool >> byRef=std::nullopt)
Verifies Reduction Clause.
static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands)
void printWsloop(OpAsmPrinter &p, Operation *op, Region &region, ValueRange reductionOperands, TypeRange reductionTypes, DenseBoolArrayAttr isByRef, ArrayAttr reductionSymbols)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > depends, OperandRange dependVars)
Verifies Depend clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, DenseBoolArrayAttr &isByRef, ArrayAttr &symbols, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs)
static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange reductionVarOperands, TypeRange reductionVarTypes, DenseBoolArrayAttr reductionVarIsByRef, ArrayAttr reductionSymbols, ValueRange privateVarOperands, TypeRange privateVarTypes, ArrayAttr privatizerSymbols)
static void printAtomicReductionRegion(OpAsmPrinter &printer, DeclareReductionOp op, Region &region)
ParseResult parseWsloop(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionOperands, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByRef, ArrayAttr &reductionSymbols)
static ParseResult parseCopyPrivateVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, ArrayAttr &copyPrivateSymbols)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static void printCopyPrivateVarList(OpAsmPrinter &p, Operation *op, OperandRange copyPrivateVars, TypeRange copyPrivateTypes, std::optional< ArrayAttr > copyPrivateFuncs)
Print CopyPrivate clause.
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearVarTypes, ValueRange linearStepVars)
Print Linear Clause.
static ParseResult parseReductionVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, ArrayAttr &redcuctionSymbols)
reduction-entry-list ::= reduction-entry | reduction-entry-list , reduction-entry reduction-entry ::=...
static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op, ValueRange argsSubrange, StringRef clauseName, ValueRange operands, TypeRange types, DenseBoolArrayAttr byRef, ArrayAttr symbols)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, DenseIntElementsAttr membersIdx)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleModifier, UnitAttr &simdModifier, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVariables)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr schedAttr, ScheduleModifierAttr modifier, UnitAttr simd, Value scheduleChunkVar, Type scheduleChunkType)
Print schedule clause.
static ParseResult parseCleanupReductionRegion(OpAsmParser &parser, Region &region)
static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, Region &region)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static ParseResult parsePrivateList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateOperands, SmallVectorImpl< Type > &privateOperandTypes, ArrayAttr &privatizerSymbols)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static LogicalResult verifyCopyPrivateVarList(Operation *op, OperandRange copyPrivateVars, std::optional< ArrayAttr > copyPrivateFuncs)
Verifies CopyPrivate Clause.
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, ArrayAttr &dependsArray)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseMapEntries(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapOperands, SmallVectorImpl< Type > &mapOperandTypes)
static void printMapEntries(OpAsmPrinter &p, Operation *op, OperandRange mapOperands, TypeRange mapOperandTypes)
static void printCleanupReductionRegion(OpAsmPrinter &printer, DeclareReductionOp op, Region &region)
static ParseResult parseMembersIndex(OpAsmParser &parser, DenseIntElementsAttr &membersIdx)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Block represents an ordered list of Operations.
Definition: Block.h:31
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
SuccessorRange getSuccessors()
Definition: Block.h:265
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
IntegerType getI64Type()
Definition: Builders.cpp:85
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
MLIRContext * getContext() const
Definition: Builders.h:55
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
Define a fold interface to allow for dialects to control specific aspects of the folding behavior for...
DialectFoldInterface(Dialect *dialect)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:209
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getType() const
Definition: ValueRange.cpp:30
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgListType getArguments()
Definition: Region.h:81
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
Block & front()
Definition: Region.h:65
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
Runtime
Potential runtimes for AMD GPU kernels.
Definition: Runtimes.h:15
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
Definition: MathExtras.h:45
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.