MLIR  20.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/IR/Attributes.h"
23 
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/BitVector.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/STLForwardCompat.h"
28 #include "llvm/ADT/SmallString.h"
29 #include "llvm/ADT/StringExtras.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Frontend/OpenMP/OMPConstants.h"
33 #include <cstddef>
34 #include <iterator>
35 #include <optional>
36 #include <variant>
37 
38 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
39 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
40 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
41 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
42 
43 using namespace mlir;
44 using namespace mlir::omp;
45 
/// Returns a null ArrayAttr when `attrs` is empty, otherwise an ArrayAttr in
/// `context` holding `attrs`. Callers must therefore be prepared for a null
/// result.
46 static ArrayAttr makeArrayAttr(MLIRContext *context,
48  return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
49 }
50 
/// Returns a null DenseBoolArrayAttr when `boolArray` is empty, otherwise a
/// DenseBoolArrayAttr in `ctx` holding `boolArray`. Because the result can be
/// null, callers must check it before using accessors such as asArrayRef().
51 static DenseBoolArrayAttr
53  return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
54 }
55 
56 namespace {
57 struct MemRefPointerLikeModel
58  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
59  MemRefType> {
60  Type getElementType(Type pointer) const {
61  return llvm::cast<MemRefType>(pointer).getElementType();
62  }
63 };
64 
65 struct LLVMPointerPointerLikeModel
66  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
67  LLVM::LLVMPointerType> {
68  Type getElementType(Type pointer) const { return Type(); }
69 };
70 } // namespace
71 
/// Registers the dialect's operations, attributes and types (from the
/// generated `.cpp.inc` lists), and attaches external interface models defined
/// here to types and ops owned by other dialects.
72 void OpenMPDialect::initialize() {
73  addOperations<
74 #define GET_OP_LIST
75 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
76  >();
77  addAttributes<
78 #define GET_ATTRDEF_LIST
79 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
80  >();
81  addTypes<
82 #define GET_TYPEDEF_LIST
83 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
84  >();
85 
// Let builtin memref types and LLVM pointer types be used wherever an OpenMP
// PointerLikeType is expected.
86  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
87  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
88  *getContext());
89 
90  // Attach the default offload module interface to the module op so that
91  // offload functionality can be accessed through the module.
92  mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
93  *getContext());
94 
95  // Attach default declare target interfaces to operations which can be marked
96  // as declare target (global operations and functions/subroutines in dialects
97  // that Fortran, or other languages that lower to MLIR, translate to).
98  mlir::LLVM::GlobalOp::attachInterface<
100  *getContext());
101  mlir::LLVM::LLVMFuncOp::attachInterface<
103  *getContext());
104  mlir::func::FuncOp::attachInterface<
106 }
107 
108 //===----------------------------------------------------------------------===//
109 // Parser and printer for Allocate Clause
110 //===----------------------------------------------------------------------===//
111 
112 /// Parse an allocate clause with allocators and a list of operands with types.
113 ///
114 /// allocate-operand-list ::= allocate-operand |
115 ///                           allocate-operand `,` allocate-operand-list
116 /// allocate-operand ::= ssa-id-and-type `->` ssa-id-and-type
117 /// ssa-id-and-type ::= ssa-id `:` type
///
/// In each allocate-operand the first ssa-id-and-type is the allocator and is
/// appended to `allocatorVars`/`allocatorTypes`. The second is the variable
/// being allocated and is appended to `allocateVars`/`allocateTypes`. On
/// success every entry adds one element to each of the four vectors.
118 static ParseResult parseAllocateAndAllocator(
119  OpAsmParser &parser,
121  SmallVectorImpl<Type> &allocateTypes,
123  SmallVectorImpl<Type> &allocatorTypes) {
124
125  return parser.parseCommaSeparatedList([&]() {
127  Type type;
// Allocator half of the entry: `%allocator : type`.
128  if (parser.parseOperand(operand) || parser.parseColonType(type))
129  return failure();
130  allocatorVars.push_back(operand);
131  allocatorTypes.push_back(type);
132  if (parser.parseArrow())
133  return failure();
// Allocated-variable half; the `operand`/`type` temporaries are reused.
134  if (parser.parseOperand(operand) || parser.parseColonType(type))
135  return failure();
136
137  allocateVars.push_back(operand);
138  allocateTypes.push_back(type);
139  return success();
140  });
141 }
142 
143 /// Print allocate clause as a comma-separated list of
/// `allocator : allocator-type -> var : var-type` entries (the form read by
/// parseAllocateAndAllocator). The loop runs over `allocateVars` and indexes
/// the other three ranges with the same index, so they are expected to have at
/// least as many elements.
145  OperandRange allocateVars,
146  TypeRange allocateTypes,
147  OperandRange allocatorVars,
148  TypeRange allocatorTypes) {
149  for (unsigned i = 0; i < allocateVars.size(); ++i) {
// No separator after the last entry.
150  std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
151  p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
152  p << allocateVars[i] << " : " << allocateTypes[i] << separator;
153  }
154 }
155 
156 //===----------------------------------------------------------------------===//
157 // Parser and printer for a clause attribute (StringEnumAttr)
158 //===----------------------------------------------------------------------===//
159 
160 template <typename ClauseAttr>
161 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
162  using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
163  StringRef enumStr;
164  SMLoc loc = parser.getCurrentLocation();
165  if (parser.parseKeyword(&enumStr))
166  return failure();
167  if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
168  attr = ClauseAttr::get(parser.getContext(), *enumValue);
169  return success();
170  }
171  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
172 }
173 
174 template <typename ClauseAttr>
175 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
176  p << stringifyEnum(attr.getValue());
177 }
178 
179 //===----------------------------------------------------------------------===//
180 // Parser and printer for Linear Clause
181 //===----------------------------------------------------------------------===//
182 
183 /// linear ::= `linear` `(` linear-list `)`
184 /// linear-list := linear-val | linear-val `,` linear-list
185 /// linear-val := ssa-id `=` ssa-id `:` type
///
/// Only linear-list is parsed here; the keyword and parentheses are not
/// consumed by this function. For each entry the first ssa-id is appended to
/// `linearVars`, the second (the step) to `linearStepVars`, and the single
/// trailing type to `linearTypes`.
186 static ParseResult parseLinearClause(
187  OpAsmParser &parser,
189  SmallVectorImpl<Type> &linearTypes,
191  return parser.parseCommaSeparatedList([&]() {
193  Type type;
195  if (parser.parseOperand(var) || parser.parseEqual() ||
196  parser.parseOperand(stepVar) || parser.parseColonType(type))
197  return failure();
198
199  linearVars.push_back(var);
200  linearTypes.push_back(type);
201  linearStepVars.push_back(stepVar);
202  return success();
203  });
204 }
205 
206 /// Print Linear Clause
///
/// Prints comma-separated `var = step : type` entries. The printed type is
/// taken from each variable itself (`linearVars[i].getType()`); the
/// `linearTypes` argument is not read. The ` = step` part is printed only
/// when `linearStepVars` has an element at the same index.
208  ValueRange linearVars, TypeRange linearTypes,
209  ValueRange linearStepVars) {
210  size_t linearVarsSize = linearVars.size();
211  for (unsigned i = 0; i < linearVarsSize; ++i) {
212  std::string separator = i == linearVarsSize - 1 ? "" : ", ";
213  p << linearVars[i];
214  if (linearStepVars.size() > i)
215  p << " = " << linearStepVars[i];
216  p << " : " << linearVars[i].getType() << separator;
217  }
218 }
219 
220 //===----------------------------------------------------------------------===//
221 // Verifier for Nontemporal Clause
222 //===----------------------------------------------------------------------===//
223 
224 static LogicalResult verifyNontemporalClause(Operation *op,
225  OperandRange nontemporalVars) {
226 
227  // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
228  DenseSet<Value> nontemporalItems;
229  for (const auto &it : nontemporalVars)
230  if (!nontemporalItems.insert(it).second)
231  return op->emitOpError() << "nontemporal variable used more than once";
232 
233  return success();
234 }
235 
236 //===----------------------------------------------------------------------===//
237 // Parser, verifier and printer for Aligned Clause
238 //===----------------------------------------------------------------------===//
239 static LogicalResult verifyAlignedClause(Operation *op,
240  std::optional<ArrayAttr> alignments,
241  OperandRange alignedVars) {
242  // Check if number of alignment values equals to number of aligned variables
243  if (!alignedVars.empty()) {
244  if (!alignments || alignments->size() != alignedVars.size())
245  return op->emitOpError()
246  << "expected as many alignment values as aligned variables";
247  } else {
248  if (alignments)
249  return op->emitOpError() << "unexpected alignment values attribute";
250  return success();
251  }
252 
253  // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
254  DenseSet<Value> alignedItems;
255  for (auto it : alignedVars)
256  if (!alignedItems.insert(it).second)
257  return op->emitOpError() << "aligned variable used more than once";
258 
259  if (!alignments)
260  return success();
261 
262  // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
263  for (unsigned i = 0; i < (*alignments).size(); ++i) {
264  if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
265  if (intAttr.getValue().sle(0))
266  return op->emitOpError() << "alignment should be greater than 0";
267  } else {
268  return op->emitOpError() << "expected integer alignment";
269  }
270  }
271 
272  return success();
273 }
274 
275 /// aligned ::= `aligned` `(` aligned-list `)`
276 /// aligned-list := aligned-val | aligned-val `,` aligned-list
277 /// aligned-val := ssa-id `:` type `->` alignment
///
/// Only aligned-list is parsed here. Operands are appended to `alignedVars`
/// and their types to `alignedTypes`. On success, `alignmentsAttr` is set to
/// an ArrayAttr holding the alignment attributes in entry order.
278 static ParseResult
281  SmallVectorImpl<Type> &alignedTypes,
282  ArrayAttr &alignmentsAttr) {
283  SmallVector<Attribute> alignmentVec;
284  if (failed(parser.parseCommaSeparatedList([&]() {
285  if (parser.parseOperand(alignedVars.emplace_back()) ||
286  parser.parseColonType(alignedTypes.emplace_back()) ||
287  parser.parseArrow() ||
288  parser.parseAttribute(alignmentVec.emplace_back())) {
289  return failure();
290  }
291  return success();
292  })))
293  return failure();
// NOTE(review): `alignmentVec` is already a SmallVector<Attribute>, so this
// copy looks redundant; it could be passed to ArrayAttr::get directly.
294  SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
295  alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
296  return success();
297 }
298 
299 /// Print Aligned Clause
///
/// Prints comma-separated `var : type -> alignment` entries, taking each type
/// from the variable itself. `alignments` is dereferenced whenever
/// `alignedVars` is non-empty, so it must be present and at least as long;
/// this presumably relies on verifyAlignedClause, which enforces equal sizes.
301  ValueRange alignedVars, TypeRange alignedTypes,
302  std::optional<ArrayAttr> alignments) {
303  for (unsigned i = 0; i < alignedVars.size(); ++i) {
304  if (i != 0)
305  p << ", ";
306  p << alignedVars[i] << " : " << alignedVars[i].getType();
307  p << " -> " << (*alignments)[i];
308  }
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // Parser, printer and verifier for Schedule Clause
313 //===----------------------------------------------------------------------===//
314 
/// Validates the schedule modifiers collected by parseScheduleClause and
/// normalizes them in place:
///  - at most two modifiers are allowed, and each must name a ScheduleModifier;
///  - a single `simd` modifier is moved to index 1 and `none` is inserted at
///    index 0;
///  - with two modifiers, the first must not be `simd` and the second must be.
/// Diagnostics are reported at the parser's name location.
315 static ParseResult
317  SmallVectorImpl<SmallString<12>> &modifiers) {
318  if (modifiers.size() > 2)
319  return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
320  for (const auto &mod : modifiers) {
321  // Translate the string. If it has no value, then it was not a valid
322  // modifier!
323  auto symbol = symbolizeScheduleModifier(mod);
324  if (!symbol)
325  return parser.emitError(parser.getNameLoc())
326  << " unknown modifier type: " << mod;
327  }
328
329  // If we have one modifier that is "simd", then stick a "none" modifier in
330  // index 0.
331  if (modifiers.size() == 1) {
332  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
333  modifiers.push_back(modifiers[0]);
334  modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
335  }
336  } else if (modifiers.size() == 2) {
337  // If there are two modifiers:
338  // First modifier should not be simd, second one should be simd
339  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
340  symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
341  return parser.emitError(parser.getNameLoc())
342  << " incorrect modifier order";
343  }
344  return success();
345 }
346 
347 /// schedule ::= `schedule` `(` sched-list `)`
348 /// sched-list ::= sched-val | sched-val `,` sched-modifier
349 /// (only sched-list is parsed here; the enclosing parentheses are not)
350 /// sched-val ::= sched-with-chunk | sched-wo-chunk
351 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id `:` type)?
352 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
353 /// sched-wo-chunk ::= `auto` | `runtime`
354 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
355 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
///
/// Sets `scheduleAttr` to the schedule kind and `chunkSize`/`chunkType` to the
/// optional chunk operand. Modifiers are validated and normalized by
/// verifyScheduleModifiers. After that, the first modifier is stored in
/// `scheduleMod`, and a second one (always `simd`) sets `scheduleSimd`.
356 static ParseResult
357 parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
358  ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
359  std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
360  Type &chunkType) {
361  StringRef keyword;
362  if (parser.parseKeyword(&keyword))
363  return failure();
364  std::optional<mlir::omp::ClauseScheduleKind> schedule =
365  symbolizeClauseScheduleKind(keyword);
366  if (!schedule)
367  return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
368
369  scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
370  switch (*schedule) {
371  case ClauseScheduleKind::Static:
372  case ClauseScheduleKind::Dynamic:
373  case ClauseScheduleKind::Guided:
// These kinds accept an optional `= %chunk : type`.
374  if (succeeded(parser.parseOptionalEqual())) {
375  chunkSize = OpAsmParser::UnresolvedOperand{};
376  if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
377  return failure();
378  } else {
379  chunkSize = std::nullopt;
380  }
381  break;
// Kinds without a chunk size (`auto` and, per the grammar above, `runtime`).
382  case ClauseScheduleKind::Auto:
384  chunkSize = std::nullopt;
385  }
386
387  // If there is a comma, we have one or more modifiers.
388  SmallVector<SmallString<12>> modifiers;
389  while (succeeded(parser.parseOptionalComma())) {
390  StringRef mod;
391  if (parser.parseKeyword(&mod))
392  return failure();
393  modifiers.push_back(mod);
394  }
395
396  if (verifyScheduleModifiers(parser, modifiers))
397  return failure();
398
// verifyScheduleModifiers has already rejected unknown modifiers, so the
// "invalid schedule modifier" branch below is not expected to trigger.
399  if (!modifiers.empty()) {
400  SMLoc loc = parser.getCurrentLocation();
401  if (std::optional<ScheduleModifier> mod =
402  symbolizeScheduleModifier(modifiers[0])) {
403  scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
404  } else {
405  return parser.emitError(loc, "invalid schedule modifier");
406  }
407  // Only SIMD attribute is allowed here (guaranteed by verifyScheduleModifiers)!
408  if (modifiers.size() > 1) {
409  assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
410  scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
411  }
412  }
413
414  return success();
415 }
416 
417 /// Print schedule clause
///
/// Prints `kind[ = chunk : chunk-type][, modifier][, simd]`, the form read by
/// parseScheduleClause. The chunk type is taken from `scheduleChunk` itself;
/// the `scheduleChunkType` argument is not read.
419  ClauseScheduleKindAttr scheduleKind,
420  ScheduleModifierAttr scheduleMod,
421  UnitAttr scheduleSimd, Value scheduleChunk,
422  Type scheduleChunkType) {
423  p << stringifyClauseScheduleKind(scheduleKind.getValue());
424  if (scheduleChunk)
425  p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
426  if (scheduleMod)
427  p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
428  if (scheduleSimd)
429  p << ", simd";
430 }
431 
432 //===----------------------------------------------------------------------===//
433 // Parser and printer for Order Clause
434 //===----------------------------------------------------------------------===//
435 
436 // order ::= `order` `(` [order-modifier ':'] concurrent `)`
437 // order-modifier ::= reproducible | unconstrained
438 static ParseResult parseOrderClause(OpAsmParser &parser,
439  ClauseOrderKindAttr &order,
440  OrderModifierAttr &orderMod) {
441  StringRef enumStr;
442  SMLoc loc = parser.getCurrentLocation();
443  if (parser.parseKeyword(&enumStr))
444  return failure();
445  if (std::optional<OrderModifier> enumValue =
446  symbolizeOrderModifier(enumStr)) {
447  orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
448  if (parser.parseOptionalColon())
449  return failure();
450  loc = parser.getCurrentLocation();
451  if (parser.parseKeyword(&enumStr))
452  return failure();
453  }
454  if (std::optional<ClauseOrderKind> enumValue =
455  symbolizeClauseOrderKind(enumStr)) {
456  order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
457  return success();
458  }
459  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
460 }
461 
463  ClauseOrderKindAttr order,
464  OrderModifierAttr orderMod) {
// Prints `[modifier:]kind`, the form read by parseOrderClause. Each part is
// omitted when its attribute is null.
465  if (orderMod)
466  p << stringifyOrderModifier(orderMod.getValue()) << ":";
467  if (order)
468  p << stringifyClauseOrderKind(order.getValue());
469 }
470 
471 //===----------------------------------------------------------------------===//
472 // Parser, printer and verifier for ReductionVarList
473 //===----------------------------------------------------------------------===//
474 
/// Parses a list of clause entries, each binding an outer operand to a new
/// region (entry block) argument:
///   entry ::= (`byref`)? symbol-ref ssa-id `->` region-arg `:` type
/// Operands are appended to `operands`, types to `types`, and region
/// arguments to `regionPrivateArgs` (after any arguments already there). On
/// success `byref` holds the per-entry `byref` flags (null when no entries
/// were parsed, via makeDenseBoolArrayAttr) and `symbols` holds the entry
/// symbols. `region` is not used by the visible body.
475 static ParseResult parseClauseWithRegionArgs(
476  OpAsmParser &parser, Region &region,
478  SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref, ArrayAttr &symbols,
479  SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs) {
480  SmallVector<SymbolRefAttr> reductionVec;
481  SmallVector<bool> isByRefVec;
// Index of the first region argument added by this clause.
482  unsigned regionArgOffset = regionPrivateArgs.size();
483
484  if (failed(
486  ParseResult optionalByref = parser.parseOptionalKeyword("byref");
487  if (parser.parseAttribute(reductionVec.emplace_back()) ||
488  parser.parseOperand(operands.emplace_back()) ||
489  parser.parseArrow() ||
490  parser.parseArgument(regionPrivateArgs.emplace_back()) ||
491  parser.parseColonType(types.emplace_back()))
492  return failure();
493  isByRefVec.push_back(optionalByref.succeeded());
494  return success();
495  })))
496  return failure();
497  byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
498
// Give the region arguments added by this clause their parsed types.
// NOTE(review): the range end uses `types.size()`, which is only correct if
// `types` was empty on entry; confirm every caller guarantees this.
499  auto *argsBegin = regionPrivateArgs.begin();
500  MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
501  argsBegin + regionArgOffset + types.size());
502  for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
503  prv.type = type;
504  }
505  SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
506  symbols = ArrayAttr::get(parser.getContext(), reductions);
507  return success();
508 }
509 
511  ValueRange argsSubrange,
512  StringRef clauseName, ValueRange operands,
513  TypeRange types, DenseBoolArrayAttr byref,
514  ArrayAttr symbols) {
// Prints `clauseName(entry, ...) ` with entries of the form
// `[byref ]symbol operand -> region-arg : type` (the form read by
// parseClauseWithRegionArgs). With an empty `clauseName`, only the entries are
// printed. All five ranges go through zip_equal, so they must have equal
// lengths, and `byref` must be non-null for asArrayRef().
515  if (!clauseName.empty())
516  p << clauseName << "(";
517
518  llvm::interleaveComma(llvm::zip_equal(symbols, operands, argsSubrange, types,
519  byref.asArrayRef()),
520  p, [&p](auto t) {
521  auto [sym, op, arg, type, isByRef] = t;
522  p << (isByRef ? "byref " : "") << sym << " " << op
523  << " -> " << arg << " : " << type;
524  });
525
526  if (!clauseName.empty())
527  p << ") ";
528 }
529 
/// Parses the optional `reduction(...)` and `private(...)` clauses of a
/// parallel-like op, then its region. Region arguments declared by both
/// clauses accumulate in `regionPrivateArgs` (reductions first, then
/// privates) and are passed to parseRegion as the entry block arguments.
/// `byref` is rejected on private entries.
530 static ParseResult parseParallelRegion(
531  OpAsmParser &parser, Region &region,
533  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
534  ArrayAttr &reductionSyms,
536  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
538
539  if (succeeded(parser.parseOptionalKeyword("reduction"))) {
540  if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
541  reductionTypes, reductionByref,
542  reductionSyms, regionPrivateArgs)))
543  return failure();
544  }
545
546  if (succeeded(parser.parseOptionalKeyword("private"))) {
547  auto privateByref = DenseBoolArrayAttr::get(parser.getContext(), {});
548  if (failed(parseClauseWithRegionArgs(parser, region, privateVars,
549  privateTypes, privateByref,
550  privateSyms, regionPrivateArgs)))
551  return failure();
// NOTE(review): parseClauseWithRegionArgs overwrites `privateByref` through
// makeDenseBoolArrayAttr, which returns null for an empty list; asArrayRef()
// on a null attribute would then be invalid. Confirm an empty `private` list
// cannot reach this point.
552  if (llvm::any_of(privateByref.asArrayRef(),
553  [](bool byref) { return byref; })) {
// Reported at the current location, i.e. after the whole private clause.
554  parser.emitError(parser.getCurrentLocation(),
555  "private clause cannot have byref attributes");
556  return failure();
557  }
558  }
559
560  return parser.parseRegion(region, regionPrivateArgs);
561 }
562 
563 static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
564  ValueRange reductionVars,
565  TypeRange reductionTypes,
566  DenseBoolArrayAttr reductionByref,
567  ArrayAttr reductionSyms, ValueRange privateVars,
568  TypeRange privateTypes, ArrayAttr privateSyms) {
569  if (reductionSyms) {
570  auto *argsBegin = region.front().getArguments().begin();
571  MutableArrayRef argsSubrange(argsBegin, argsBegin + reductionTypes.size());
572  printClauseWithRegionArgs(p, op, argsSubrange, "reduction", reductionVars,
573  reductionTypes, reductionByref, reductionSyms);
574  }
575 
576  if (privateSyms) {
577  auto *argsBegin = region.front().getArguments().begin();
578  MutableArrayRef argsSubrange(argsBegin + reductionVars.size(),
579  argsBegin + reductionVars.size() +
580  privateTypes.size());
581  mlir::SmallVector<bool> isByRefVec;
582  isByRefVec.resize(privateTypes.size(), false);
583  DenseBoolArrayAttr isByRef =
584  makeDenseBoolArrayAttr(op->getContext(), isByRefVec);
585 
586  printClauseWithRegionArgs(p, op, argsSubrange, "private", privateVars,
587  privateTypes, isByRef, privateSyms);
588  }
589 
590  p.printRegion(region, /*printEntryBlockArgs=*/false);
591 }
592 
593 /// reduction-entry-list ::= reduction-entry
594 /// | reduction-entry-list `,` reduction-entry
595 /// reduction-entry ::= (`byref`)? symbol-ref `->` ssa-id `:` type
///
/// Operands and types are appended to `reductionVars`/`reductionTypes`. On
/// success `reductionSyms` holds the entry symbols and `reductionByref` holds
/// the per-entry `byref` flags (null when no entries were parsed, via
/// makeDenseBoolArrayAttr).
596 static ParseResult parseReductionVarList(
597  OpAsmParser &parser,
599  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
600  ArrayAttr &reductionSyms) {
601  SmallVector<SymbolRefAttr> reductionVec;
602  SmallVector<bool> isByRefVec;
603  if (failed(parser.parseCommaSeparatedList([&]() {
// `byref` is optional; remember whether it was present for this entry.
604  ParseResult optionalByref = parser.parseOptionalKeyword("byref");
605  if (parser.parseAttribute(reductionVec.emplace_back()) ||
606  parser.parseArrow() ||
607  parser.parseOperand(reductionVars.emplace_back()) ||
608  parser.parseColonType(reductionTypes.emplace_back()))
609  return failure();
610  isByRefVec.push_back(optionalByref.succeeded());
611  return success();
612  })))
613  return failure();
614  reductionByref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
615  SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
616  reductionSyms = ArrayAttr::get(parser.getContext(), reductions);
617  return success();
618 }
619 
620 /// Print Reduction clause
///
/// Prints comma-separated `[byref ]symbol -> var : type` entries. The type is
/// taken from each variable; `reductionTypes` is not read. `reductionSyms` is
/// dereferenced whenever there are reduction variables, so it must then be
/// present. An absent, null or empty `reductionByref` prints no `byref`.
621 static void
623  OperandRange reductionVars, TypeRange reductionTypes,
624  std::optional<DenseBoolArrayAttr> reductionByref,
625  std::optional<ArrayAttr> reductionSyms) {
626  auto getByRef = [&](unsigned i) -> const char * {
627  if (!reductionByref || !*reductionByref)
628  return "";
// An empty by-ref array means no entry is by reference.
629  assert(reductionByref->empty() || i < reductionByref->size());
630  if (!reductionByref->empty() && (*reductionByref)[i])
631  return "byref ";
632  return "";
633  };
634
635  for (unsigned i = 0, e = reductionVars.size(); i < e; ++i) {
636  if (i != 0)
637  p << ", ";
638  p << getByRef(i) << (*reductionSyms)[i] << " -> " << reductionVars[i]
639  << " : " << reductionVars[i].getType();
640  }
641 }
642 
643 /// Verifies Reduction Clause
644 static LogicalResult
645 verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
646  OperandRange reductionVars,
647  std::optional<ArrayRef<bool>> reductionByref) {
648  if (!reductionVars.empty()) {
649  if (!reductionSyms || reductionSyms->size() != reductionVars.size())
650  return op->emitOpError()
651  << "expected as many reduction symbol references "
652  "as reduction variables";
653  if (reductionByref && reductionByref->size() != reductionVars.size())
654  return op->emitError() << "expected as many reduction variable by "
655  "reference attributes as reduction variables";
656  } else {
657  if (reductionSyms)
658  return op->emitOpError() << "unexpected reduction symbol references";
659  return success();
660  }
661 
662  // TODO: The followings should be done in
663  // SymbolUserOpInterface::verifySymbolUses.
664  DenseSet<Value> accumulators;
665  for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
666  Value accum = std::get<0>(args);
667 
668  if (!accumulators.insert(accum).second)
669  return op->emitOpError() << "accumulator variable used more than once";
670 
671  Type varType = accum.getType();
672  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
673  auto decl =
674  SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
675  if (!decl)
676  return op->emitOpError() << "expected symbol reference " << symbolRef
677  << " to point to a reduction declaration";
678 
679  if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
680  return op->emitOpError()
681  << "expected accumulator (" << varType
682  << ") to be the same type as reduction declaration ("
683  << decl.getAccumulatorType() << ")";
684  }
685 
686  return success();
687 }
688 
689 //===----------------------------------------------------------------------===//
690 // Parser, printer and verifier for Copyprivate
691 //===----------------------------------------------------------------------===//
692 
693 /// copyprivate-entry-list ::= copyprivate-entry
694 /// | copyprivate-entry-list `,` copyprivate-entry
695 /// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
///
/// Operands and types are appended to `copyprivateVars`/`copyprivateTypes`.
/// On success `copyprivateSyms` holds the copy-function symbols in entry
/// order.
696 static ParseResult parseCopyprivate(
697  OpAsmParser &parser,
699  SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
701  if (failed(parser.parseCommaSeparatedList([&]() {
702  if (parser.parseOperand(copyprivateVars.emplace_back()) ||
703  parser.parseArrow() ||
704  parser.parseAttribute(symsVec.emplace_back()) ||
705  parser.parseColonType(copyprivateTypes.emplace_back()))
706  return failure();
707  return success();
708  })))
709  return failure();
710  SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
711  copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
712  return success();
713 }
714 
715 /// Print Copyprivate clause
///
/// Prints comma-separated `var -> symbol : type` entries, or nothing when
/// `copyprivateSyms` is absent. llvm::zip stops at the shortest of the three
/// ranges.
717  OperandRange copyprivateVars,
718  TypeRange copyprivateTypes,
719  std::optional<ArrayAttr> copyprivateSyms) {
720  if (!copyprivateSyms.has_value())
721  return;
722  llvm::interleaveComma(
723  llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
724  [&](const auto &args) {
725  p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
726  << std::get<2>(args);
727  });
728 }
729 
730 /// Verifies CopyPrivate Clause
///
/// Requires one copy-function symbol per copyprivate variable (an absent
/// symbol list counts as zero). Each symbol must resolve, via the nearest
/// symbol table, to a func::FuncOp or LLVM::LLVMFuncOp that takes exactly two
/// arguments of the same type, and that type must equal the variable's type.
731 static LogicalResult
733  std::optional<ArrayAttr> copyprivateSyms) {
734  size_t copyprivateSymsSize =
735  copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
736  if (copyprivateSymsSize != copyprivateVars.size())
737  return op->emitOpError() << "inconsistent number of copyprivate vars (= "
738  << copyprivateVars.size()
739  << ") and functions (= " << copyprivateSymsSize
740  << "), both must be equal";
741  if (!copyprivateSyms.has_value())
742  return success();
743
744  for (auto copyprivateVarAndSym :
745  llvm::zip(copyprivateVars, *copyprivateSyms)) {
746  auto symbolRef =
747  llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
// Resolve the symbol to either a func::FuncOp or an LLVM::LLVMFuncOp.
748  std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
749  funcOp;
750  if (mlir::func::FuncOp mlirFuncOp =
751  SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
752  symbolRef))
753  funcOp = mlirFuncOp;
754  else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
755  SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
756  op, symbolRef))
757  funcOp = llvmFuncOp;
758
// These helpers dereference `funcOp`; they are only called after the
// `!funcOp` check below.
759  auto getNumArguments = [&] {
760  return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
761  };
762
763  auto getArgumentType = [&](unsigned i) {
764  return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
765  *funcOp);
766  };
767
768  if (!funcOp)
769  return op->emitOpError() << "expected symbol reference " << symbolRef
770  << " to point to a copy function";
771
772  if (getNumArguments() != 2)
773  return op->emitOpError()
774  << "expected copy function " << symbolRef << " to have 2 operands";
775
776  Type argTy = getArgumentType(0);
777  if (argTy != getArgumentType(1))
778  return op->emitOpError() << "expected copy function " << symbolRef
779  << " arguments to have the same type";
780
781  Type varType = std::get<0>(copyprivateVarAndSym).getType();
782  if (argTy != varType)
783  return op->emitOpError()
784  << "expected copy function arguments' type (" << argTy
785  << ") to be the same as copyprivate variable's type (" << varType
786  << ")";
787  }
788
789  return success();
790 }
791 
792 //===----------------------------------------------------------------------===//
793 // Parser, printer and verifier for DependVarList
794 //===----------------------------------------------------------------------===//
795 
796 /// depend-entry-list ::= depend-entry
797 /// | depend-entry-list `,` depend-entry
798 /// depend-entry ::= depend-kind `->` ssa-id `:` type
///
/// Operands and types are appended to `dependVars`/`dependTypes`. On success
/// `dependKinds` holds one ClauseTaskDependAttr per entry.
799 static ParseResult
802  SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
804  if (failed(parser.parseCommaSeparatedList([&]() {
805  StringRef keyword;
806  if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
807  parser.parseOperand(dependVars.emplace_back()) ||
808  parser.parseColonType(dependTypes.emplace_back()))
809  return failure();
// NOTE(review): an unknown depend kind returns failure() without emitting
// a diagnostic.
810  if (std::optional<ClauseTaskDepend> keywordDepend =
811  (symbolizeClauseTaskDepend(keyword)))
812  kindsVec.emplace_back(
813  ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
814  else
815  return failure();
816  return success();
817  })))
818  return failure();
819  SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
820  dependKinds = ArrayAttr::get(parser.getContext(), kinds);
821  return success();
822 }
823 
824 /// Print Depend clause
///
/// Prints comma-separated `kind -> var : type` entries. `dependKinds` is
/// dereferenced unconditionally, so it must be present. The loop is bounded by
/// its size and indexes `dependVars`/`dependTypes` with the same index.
826  OperandRange dependVars, TypeRange dependTypes,
827  std::optional<ArrayAttr> dependKinds) {
828
829  for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
830  if (i != 0)
831  p << ", ";
832  p << stringifyClauseTaskDepend(
833  llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
834  .getValue())
835  << " -> " << dependVars[i] << " : " << dependTypes[i];
836  }
837 }
838 
839 /// Verifies Depend clause
840 static LogicalResult verifyDependVarList(Operation *op,
841  std::optional<ArrayAttr> dependKinds,
842  OperandRange dependVars) {
843  if (!dependVars.empty()) {
844  if (!dependKinds || dependKinds->size() != dependVars.size())
845  return op->emitOpError() << "expected as many depend values"
846  " as depend variables";
847  } else {
848  if (dependKinds && !dependKinds->empty())
849  return op->emitOpError() << "unexpected depend values";
850  return success();
851  }
852 
853  return success();
854 }
855 
856 //===----------------------------------------------------------------------===//
857 // Parser, printer and verifier for Synchronization Hint (2.17.12)
858 //===----------------------------------------------------------------------===//
859 
860 /// Parses a Synchronization Hint clause. The value of hint is an integer
861 /// which is a combination of different hints from `omp_sync_hint_t`.
862 ///
863 /// hint-clause = `hint` `(` hint-value `)`
864 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
865  IntegerAttr &hintAttr) {
866  StringRef hintKeyword;
867  int64_t hint = 0;
868  if (succeeded(parser.parseOptionalKeyword("none"))) {
869  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
870  return success();
871  }
872  auto parseKeyword = [&]() -> ParseResult {
873  if (failed(parser.parseKeyword(&hintKeyword)))
874  return failure();
875  if (hintKeyword == "uncontended")
876  hint |= 1;
877  else if (hintKeyword == "contended")
878  hint |= 2;
879  else if (hintKeyword == "nonspeculative")
880  hint |= 4;
881  else if (hintKeyword == "speculative")
882  hint |= 8;
883  else
884  return parser.emitError(parser.getCurrentLocation())
885  << hintKeyword << " is not a valid hint";
886  return success();
887  };
888  if (parser.parseCommaSeparatedList(parseKeyword))
889  return failure();
890  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
891  return success();
892 }
893 
/// Prints a Synchronization Hint clause.
///
/// Prints `none` when the hint value is 0. Otherwise prints a comma-separated
/// list of the keywords whose bits are set: bit 0 `uncontended`, bit 1
/// `contended`, bit 2 `nonspeculative`, bit 3 `speculative`. Any other set
/// bits are not printed.
                                     IntegerAttr hintAttr) {
  int64_t hint = hintAttr.getInt();

  // A zero hint round-trips through the `none` keyword.
  if (hint == 0) {
    p << "none";
    return;
  }

  // Helper function to get n-th bit from the right end of `value`
  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };

  bool uncontended = bitn(hint, 0);
  bool contended = bitn(hint, 1);
  bool nonspeculative = bitn(hint, 2);
  bool speculative = bitn(hint, 3);

  // Collect the keywords in a fixed order, then print them comma-separated.
  if (uncontended)
    hints.push_back("uncontended");
  if (contended)
    hints.push_back("contended");
  if (nonspeculative)
    hints.push_back("nonspeculative");
  if (speculative)
    hints.push_back("speculative");

  llvm::interleaveComma(hints, p);
}
924 
925 /// Verifies a synchronization hint clause
926 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
927 
928  // Helper function to get n-th bit from the right end of `value`
929  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
930 
931  bool uncontended = bitn(hint, 0);
932  bool contended = bitn(hint, 1);
933  bool nonspeculative = bitn(hint, 2);
934  bool speculative = bitn(hint, 3);
935 
936  if (uncontended && contended)
937  return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
938  "omp_sync_hint_contended cannot be combined";
939  if (nonspeculative && speculative)
940  return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
941  "omp_sync_hint_speculative cannot be combined.";
942  return success();
943 }
944 
945 //===----------------------------------------------------------------------===//
946 // Parser, printer and verifier for Target
947 //===----------------------------------------------------------------------===//
948 
949 // Helper function to get bitwise AND of `value` and 'flag'
950 uint64_t mapTypeToBitFlag(uint64_t value,
951  llvm::omp::OpenMPOffloadMappingFlags flag) {
952  return value & llvm::to_underlying(flag);
953 }
954 
955 /// Parses a map_entries map type from a string format back into its numeric
956 /// value.
957 ///
958 /// map-clause = `map_clauses ( ( `(` `always, `? `close, `? `present, `? (
959 /// `to` | `from` | `delete` `)` )+ `)` )
960 static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
961  llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
962  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
963 
964  // This simply verifies the correct keyword is read in, the
965  // keyword itself is stored inside of the operation
966  auto parseTypeAndMod = [&]() -> ParseResult {
967  StringRef mapTypeMod;
968  if (parser.parseKeyword(&mapTypeMod))
969  return failure();
970 
971  if (mapTypeMod == "always")
972  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
973 
974  if (mapTypeMod == "implicit")
975  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
976 
977  if (mapTypeMod == "close")
978  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
979 
980  if (mapTypeMod == "present")
981  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
982 
983  if (mapTypeMod == "to")
984  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
985 
986  if (mapTypeMod == "from")
987  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
988 
989  if (mapTypeMod == "tofrom")
990  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
991  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
992 
993  if (mapTypeMod == "delete")
994  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
995 
996  return success();
997  };
998 
999  if (parser.parseCommaSeparatedList(parseTypeAndMod))
1000  return failure();
1001 
1002  mapType = parser.getBuilder().getIntegerAttr(
1003  parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
1004  llvm::to_underlying(mapTypeBits));
1005 
1006  return success();
1007 }
1008 
/// Prints a map_entries map type from its numeric value out into its string
/// format.
///
/// Emits a comma-separated keyword list in this order:
///  1. the modifiers `always`, `implicit`, `close` and `present`, for each
///     bit that is set;
///  2. `tofrom`, `from` or `to`;
///  3. `delete`.
/// If none of the TO, FROM or DELETE bits is set,
/// `exit_release_or_enter_alloc` is emitted instead.
                           IntegerAttr mapType) {
  uint64_t mapTypeBits = mapType.getUInt();

  // Cleared as soon as a to/from/tofrom/delete keyword is emitted.
  bool emitAllocRelease = true;

  // handling of always, close, present placed at the beginning of the string
  // to aid readability
  if (mapTypeToBitFlag(mapTypeBits,
                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
    mapTypeStrs.push_back("always");
  if (mapTypeToBitFlag(mapTypeBits,
                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
    mapTypeStrs.push_back("implicit");
  if (mapTypeToBitFlag(mapTypeBits,
                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
    mapTypeStrs.push_back("close");
  if (mapTypeToBitFlag(mapTypeBits,
                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
    mapTypeStrs.push_back("present");

  // special handling of to/from/tofrom/delete and release/alloc, release +
  // alloc are the absence of one of the other flags, whereas tofrom requires
  // both the to and from flag to be set.
  bool to = mapTypeToBitFlag(mapTypeBits,
                             llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
  bool from = mapTypeToBitFlag(
      mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
  if (to && from) {
    emitAllocRelease = false;
    mapTypeStrs.push_back("tofrom");
  } else if (from) {
    emitAllocRelease = false;
    mapTypeStrs.push_back("from");
  } else if (to) {
    emitAllocRelease = false;
    mapTypeStrs.push_back("to");
  }
  if (mapTypeToBitFlag(mapTypeBits,
                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
    emitAllocRelease = false;
    mapTypeStrs.push_back("delete");
  }
  if (emitAllocRelease)
    mapTypeStrs.push_back("exit_release_or_enter_alloc");

  // Print the collected keywords separated by ", ".
  for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
    p << mapTypeStrs[i];
    if (i + 1 < mapTypeStrs.size()) {
      p << ", ";
    }
  }
}
1065 
1066 static ParseResult parseMembersIndex(OpAsmParser &parser,
1067  DenseIntElementsAttr &membersIdx) {
1068  SmallVector<APInt> values;
1069  int64_t value;
1070  int64_t shape[2] = {0, 0};
1071  unsigned shapeTmp = 0;
1072  auto parseIndices = [&]() -> ParseResult {
1073  if (parser.parseInteger(value))
1074  return failure();
1075  shapeTmp++;
1076  values.push_back(APInt(32, value));
1077  return success();
1078  };
1079 
1080  do {
1081  if (failed(parser.parseLSquare()))
1082  return failure();
1083 
1084  if (parser.parseCommaSeparatedList(parseIndices))
1085  return failure();
1086 
1087  if (failed(parser.parseRSquare()))
1088  return failure();
1089 
1090  // Only set once, if any indices are not the same size
1091  // we error out in the next check as that's unsupported
1092  if (shape[1] == 0)
1093  shape[1] = shapeTmp;
1094 
1095  // Verify that the recently parsed list is equal to the
1096  // first one we parsed, they must be equal lengths to
1097  // keep the rectangular shape DenseIntElementsAttr
1098  // requires
1099  if (shapeTmp != shape[1])
1100  return failure();
1101 
1102  shapeTmp = 0;
1103  shape[0]++;
1104  } while (succeeded(parser.parseOptionalComma()));
1105 
1106  if (!values.empty()) {
1107  ShapedType valueType =
1108  VectorType::get(shape, IntegerType::get(parser.getContext(), 32));
1109  membersIdx = DenseIntElementsAttr::get(valueType, values);
1110  }
1111 
1112  return success();
1113 }
1114 
1115 static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1116  DenseIntElementsAttr membersIdx) {
1117  llvm::ArrayRef<int64_t> shape = membersIdx.getShapedType().getShape();
1118  assert(shape.size() <= 2);
1119 
1120  if (!membersIdx)
1121  return;
1122 
1123  for (int i = 0; i < shape[0]; ++i) {
1124  p << "[";
1125  int rowOffset = i * shape[1];
1126  for (int j = 0; j < shape[1]; ++j) {
1127  p << membersIdx.getValues<int32_t>()[rowOffset + j];
1128  if ((j + 1) < shape[1])
1129  p << ",";
1130  }
1131  p << "]";
1132 
1133  if ((i + 1) < shape[0])
1134  p << ", ";
1135  }
1136 }
1137 
static ParseResult
              SmallVectorImpl<Type> &mapTypes) {
  // Parses `%var (-> %blockArg)?` entries separated by commas, then `:`, then
  // a comma-separated list of types. Each parsed operand is appended to
  // mapVars and each parsed type to mapTypes.
  // NOTE(review): the optional `-> %blockArg` operand is parsed into blockArg
  // but never recorded or returned; confirm this is intended.
  Type argType;
  auto parseEntries = [&]() -> ParseResult {
    if (parser.parseOperand(arg))
      return failure();
    if (succeeded(parser.parseOptionalArrow()) && parser.parseOperand(blockArg))
      return failure();
    mapVars.push_back(arg);
    return success();
  };

  auto parseTypes = [&]() -> ParseResult {
    if (parser.parseType(argType))
      return failure();
    mapTypes.push_back(argType);
    return success();
  };

  if (parser.parseCommaSeparatedList(parseEntries))
    return failure();

  if (parser.parseColon())
    return failure();

  if (parser.parseCommaSeparatedList(parseTypes))
    return failure();

  return success();
}
1172 
                            OperandRange mapVars, TypeRange mapTypes) {
  // Prints `%var, ... : type, ...`. For a TargetOp each variable is printed as
  // `%var -> %blockArg`, using the entry-block argument with the same index.
  // NOTE(review): this assumes the map block arguments are the leading
  // entry-block arguments of the TargetOp region; confirm against the op
  // definition.

  // Get pointer to the region if this is an omp.target, because printing map
  // clauses for that operation has to also show the correspondence of each
  // variable to the corresponding block argument.
  Block *entryBlock = isa<TargetOp>(op) ? &op->getRegion(0).front() : nullptr;
  unsigned argIndex = 0;

  for (const auto &mapOp : mapVars) {
    p << mapOp;
    if (entryBlock) {
      const auto &blockArg = entryBlock->getArgument(argIndex);
      p << " -> " << blockArg;
    }
    argIndex++;
    if (argIndex < mapVars.size())
      p << ", ";
  }
  p << " : ";

  // NOTE(review): the separator check below compares against mapVars.size()
  // rather than mapTypes.size(). This is only correct when both ranges have
  // the same length.
  argIndex = 0;
  for (const auto &mapType : mapTypes) {
    p << mapType;
    argIndex++;
    if (argIndex < mapVars.size())
      p << ", ";
  }
}
1201 
static ParseResult
                 SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
  // Parses a comma-separated list of `@sym %var -> %arg : type` entries.
  // Operands are appended to privateVars and types to privateTypes; the
  // symbols are collected into the ArrayAttr privateSyms.
  // NOTE(review): the parsed region arguments (`%arg`) only live in the local
  // regionPrivateArgs and are not returned to the caller.
  SmallVector<SymbolRefAttr> privateSymRefs;
  SmallVector<OpAsmParser::Argument> regionPrivateArgs;

  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseAttribute(privateSymRefs.emplace_back()) ||
            parser.parseOperand(privateVars.emplace_back()) ||
            parser.parseArrow() ||
            parser.parseArgument(regionPrivateArgs.emplace_back()) ||
            parser.parseColonType(privateTypes.emplace_back()))
          return failure();
        return success();
      })))
    return failure();

  // ArrayAttr::get needs generic Attributes, so convert the symbol refs.
  SmallVector<Attribute> privateSymAttrs(privateSymRefs.begin(),
                                         privateSymRefs.end());
  privateSyms = ArrayAttr::get(parser.getContext(), privateSymAttrs);

  return success();
}
1226 
                             ValueRange privateVars, TypeRange privateTypes,
                             ArrayAttr privateSyms) {
  // Prints the private clause entries of a TargetOp (asserted below). The
  // private block arguments are taken to be the entry-block arguments that
  // immediately follow the map-variable arguments. Every private variable is
  // printed as not by-reference.
  // TODO: Remove target-specific logic from this function.
  auto targetOp = mlir::dyn_cast<mlir::omp::TargetOp>(op);
  assert(targetOp);

  auto &region = op->getRegion(0);
  auto *argsBegin = region.front().getArguments().begin();
  // Skip the map-variable block arguments; the next privateTypes.size()
  // arguments belong to the private clause.
  MutableArrayRef argsSubrange(argsBegin + targetOp.getMapVars().size(),
                               argsBegin + targetOp.getMapVars().size() +
                                   privateTypes.size());
  mlir::SmallVector<bool> isByRefVec;
  isByRefVec.resize(privateTypes.size(), false);
  DenseBoolArrayAttr isByRef =
      DenseBoolArrayAttr::get(op->getContext(), isByRefVec);

  printClauseWithRegionArgs(p, op, argsSubrange,
                            /*clauseName=*/llvm::StringRef{}, privateVars,
                            privateTypes, isByRef, privateSyms);
}
1248 
                             VariableCaptureKindAttr mapCaptureType) {
  // Prints the capture kind of a map entry as one of the keywords `ByRef`,
  // `ByCopy`, `VLAType` or `This` (the keywords parseCaptureType accepts). If
  // the value matches none of them, an empty string is printed.
  std::string typeCapStr;
  llvm::raw_string_ostream typeCap(typeCapStr);
  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
    typeCap << "ByRef";
  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
    typeCap << "ByCopy";
  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
    typeCap << "VLAType";
  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
    typeCap << "This";
  p << typeCap.str();
}
1263 
1264 static ParseResult parseCaptureType(OpAsmParser &parser,
1265  VariableCaptureKindAttr &mapCaptureType) {
1266  StringRef mapCaptureKey;
1267  if (parser.parseKeyword(&mapCaptureKey))
1268  return failure();
1269 
1270  if (mapCaptureKey == "This")
1271  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1272  parser.getContext(), mlir::omp::VariableCaptureKind::This);
1273  if (mapCaptureKey == "ByRef")
1274  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1275  parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1276  if (mapCaptureKey == "ByCopy")
1277  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1278  parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1279  if (mapCaptureKey == "VLAType")
1280  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1281  parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1282 
1283  return success();
1284 }
1285 
1286 static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
1289 
1290  for (auto mapOp : mapVars) {
1291  if (!mapOp.getDefiningOp())
1292  emitError(op->getLoc(), "missing map operation");
1293 
1294  if (auto mapInfoOp =
1295  mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
1296  if (!mapInfoOp.getMapType().has_value())
1297  emitError(op->getLoc(), "missing map type for map operand");
1298 
1299  if (!mapInfoOp.getMapCaptureType().has_value())
1300  emitError(op->getLoc(), "missing map capture type for map operand");
1301 
1302  uint64_t mapTypeBits = mapInfoOp.getMapType().value();
1303 
1304  bool to = mapTypeToBitFlag(
1305  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1306  bool from = mapTypeToBitFlag(
1307  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1308  bool del = mapTypeToBitFlag(
1309  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1310 
1311  bool always = mapTypeToBitFlag(
1312  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1313  bool close = mapTypeToBitFlag(
1314  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1315  bool implicit = mapTypeToBitFlag(
1316  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1317 
1318  if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1319  return emitError(op->getLoc(),
1320  "to, from, tofrom and alloc map types are permitted");
1321 
1322  if (isa<TargetEnterDataOp>(op) && (from || del))
1323  return emitError(op->getLoc(), "to and alloc map types are permitted");
1324 
1325  if (isa<TargetExitDataOp>(op) && to)
1326  return emitError(op->getLoc(),
1327  "from, release and delete map types are permitted");
1328 
1329  if (isa<TargetUpdateOp>(op)) {
1330  if (del) {
1331  return emitError(op->getLoc(),
1332  "at least one of to or from map types must be "
1333  "specified, other map types are not permitted");
1334  }
1335 
1336  if (!to && !from) {
1337  return emitError(op->getLoc(),
1338  "at least one of to or from map types must be "
1339  "specified, other map types are not permitted");
1340  }
1341 
1342  auto updateVar = mapInfoOp.getVarPtr();
1343 
1344  if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1345  (from && updateToVars.contains(updateVar))) {
1346  return emitError(
1347  op->getLoc(),
1348  "either to or from map types can be specified, not both");
1349  }
1350 
1351  if (always || close || implicit) {
1352  return emitError(
1353  op->getLoc(),
1354  "present, mapper and iterator map type modifiers are permitted");
1355  }
1356 
1357  to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1358  }
1359  } else {
1360  emitError(op->getLoc(), "map argument is not a map entry operation");
1361  }
1362  }
1363 
1364  return success();
1365 }
1366 
1367 //===----------------------------------------------------------------------===//
1368 // TargetDataOp
1369 //===----------------------------------------------------------------------===//
1370 
1371 void TargetDataOp::build(OpBuilder &builder, OperationState &state,
1372  const TargetDataOperands &clauses) {
1373  TargetDataOp::build(builder, state, clauses.device, clauses.ifVar,
1374  clauses.mapVars, clauses.useDeviceAddrVars,
1375  clauses.useDevicePtrVars);
1376 }
1377 
1378 LogicalResult TargetDataOp::verify() {
1379  if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
1380  getUseDeviceAddrVars().empty()) {
1381  return ::emitError(this->getLoc(),
1382  "At least one of map, use_device_ptr_vars, or "
1383  "use_device_addr_vars operand must be present");
1384  }
1385  return verifyMapClause(*this, getMapVars());
1386 }
1387 
1388 //===----------------------------------------------------------------------===//
1389 // TargetEnterDataOp
1390 //===----------------------------------------------------------------------===//
1391 
1392 void TargetEnterDataOp::build(
1393  OpBuilder &builder, OperationState &state,
1394  const TargetEnterExitUpdateDataOperands &clauses) {
1395  MLIRContext *ctx = builder.getContext();
1396  TargetEnterDataOp::build(builder, state,
1397  makeArrayAttr(ctx, clauses.dependKinds),
1398  clauses.dependVars, clauses.device, clauses.ifVar,
1399  clauses.mapVars, clauses.nowait);
1400 }
1401 
1402 LogicalResult TargetEnterDataOp::verify() {
1403  LogicalResult verifyDependVars =
1404  verifyDependVarList(*this, getDependKinds(), getDependVars());
1405  return failed(verifyDependVars) ? verifyDependVars
1406  : verifyMapClause(*this, getMapVars());
1407 }
1408 
1409 //===----------------------------------------------------------------------===//
1410 // TargetExitDataOp
1411 //===----------------------------------------------------------------------===//
1412 
1413 void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
1414  const TargetEnterExitUpdateDataOperands &clauses) {
1415  MLIRContext *ctx = builder.getContext();
1416  TargetExitDataOp::build(builder, state,
1417  makeArrayAttr(ctx, clauses.dependKinds),
1418  clauses.dependVars, clauses.device, clauses.ifVar,
1419  clauses.mapVars, clauses.nowait);
1420 }
1421 
1422 LogicalResult TargetExitDataOp::verify() {
1423  LogicalResult verifyDependVars =
1424  verifyDependVarList(*this, getDependKinds(), getDependVars());
1425  return failed(verifyDependVars) ? verifyDependVars
1426  : verifyMapClause(*this, getMapVars());
1427 }
1428 
1429 //===----------------------------------------------------------------------===//
1430 // TargetUpdateOp
1431 //===----------------------------------------------------------------------===//
1432 
1433 void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
1434  const TargetEnterExitUpdateDataOperands &clauses) {
1435  MLIRContext *ctx = builder.getContext();
1436  TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
1437  clauses.dependVars, clauses.device, clauses.ifVar,
1438  clauses.mapVars, clauses.nowait);
1439 }
1440 
1441 LogicalResult TargetUpdateOp::verify() {
1442  LogicalResult verifyDependVars =
1443  verifyDependVarList(*this, getDependKinds(), getDependVars());
1444  return failed(verifyDependVars) ? verifyDependVars
1445  : verifyMapClause(*this, getMapVars());
1446 }
1447 
1448 //===----------------------------------------------------------------------===//
1449 // TargetOp
1450 //===----------------------------------------------------------------------===//
1451 
1452 void TargetOp::build(OpBuilder &builder, OperationState &state,
1453  const TargetOperands &clauses) {
1454  MLIRContext *ctx = builder.getContext();
1455  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
1456  // inReductionByref, inReductionSyms.
1457  TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
1458  makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
1459  clauses.device, clauses.hasDeviceAddrVars, clauses.ifVar,
1460  /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
1461  /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
1462  clauses.mapVars, clauses.nowait, clauses.privateVars,
1463  makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit);
1464 }
1465 
1466 LogicalResult TargetOp::verify() {
1467  LogicalResult verifyDependVars =
1468  verifyDependVarList(*this, getDependKinds(), getDependVars());
1469  return failed(verifyDependVars) ? verifyDependVars
1470  : verifyMapClause(*this, getMapVars());
1471 }
1472 
1473 //===----------------------------------------------------------------------===//
1474 // ParallelOp
1475 //===----------------------------------------------------------------------===//
1476 
1477 void ParallelOp::build(OpBuilder &builder, OperationState &state,
1478  ArrayRef<NamedAttribute> attributes) {
1479  ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
1480  /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
1481  /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
1482  /*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr,
1483  /*reduction_vars=*/ValueRange(),
1484  /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
1485  state.addAttributes(attributes);
1486 }
1487 
1488 void ParallelOp::build(OpBuilder &builder, OperationState &state,
1489  const ParallelOperands &clauses) {
1490  MLIRContext *ctx = builder.getContext();
1491 
1492  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1493  clauses.ifVar, clauses.numThreads, clauses.privateVars,
1494  makeArrayAttr(ctx, clauses.privateSyms),
1495  clauses.procBindKind, clauses.reductionVars,
1496  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
1497  makeArrayAttr(ctx, clauses.reductionSyms));
1498 }
1499 
1500 template <typename OpType>
1501 static LogicalResult verifyPrivateVarList(OpType &op) {
1502  auto privateVars = op.getPrivateVars();
1503  auto privateSyms = op.getPrivateSymsAttr();
1504 
1505  if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
1506  return success();
1507 
1508  auto numPrivateVars = privateVars.size();
1509  auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
1510 
1511  if (numPrivateVars != numPrivateSyms)
1512  return op.emitError() << "inconsistent number of private variables and "
1513  "privatizer op symbols, private vars: "
1514  << numPrivateVars
1515  << " vs. privatizer op symbols: " << numPrivateSyms;
1516 
1517  for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
1518  Type varType = std::get<0>(privateVarInfo).getType();
1519  SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
1520  PrivateClauseOp privatizerOp =
1521  SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
1522 
1523  if (privatizerOp == nullptr)
1524  return op.emitError() << "failed to lookup privatizer op with symbol: '"
1525  << privateSym << "'";
1526 
1527  Type privatizerType = privatizerOp.getType();
1528 
1529  if (varType != privatizerType)
1530  return op.emitError()
1531  << "type mismatch between a "
1532  << (privatizerOp.getDataSharingType() ==
1533  DataSharingClauseType::Private
1534  ? "private"
1535  : "firstprivate")
1536  << " variable and its privatizer op, var type: " << varType
1537  << " vs. privatizer op type: " << privatizerType;
1538  }
1539 
1540  return success();
1541 }
1542 
1543 LogicalResult ParallelOp::verify() {
1544  auto distributeChildOps = getOps<DistributeOp>();
1545  if (!distributeChildOps.empty()) {
1546  if (!isComposite())
1547  return emitError()
1548  << "'omp.composite' attribute missing from composite operation";
1549 
1550  auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
1551  Operation &distributeOp = **distributeChildOps.begin();
1552  for (Operation &childOp : getOps()) {
1553  if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
1554  continue;
1555 
1556  if (!childOp.hasTrait<OpTrait::IsTerminator>())
1557  return emitError() << "unexpected OpenMP operation inside of composite "
1558  "'omp.parallel'";
1559  }
1560  } else if (isComposite()) {
1561  return emitError()
1562  << "'omp.composite' attribute present in non-composite operation";
1563  }
1564 
1565  if (getAllocateVars().size() != getAllocatorVars().size())
1566  return emitError(
1567  "expected equal sizes for allocate and allocator variables");
1568 
1569  if (failed(verifyPrivateVarList(*this)))
1570  return failure();
1571 
1572  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
1573  getReductionByref());
1574 }
1575 
1576 //===----------------------------------------------------------------------===//
1577 // TeamsOp
1578 //===----------------------------------------------------------------------===//
1579 
  // Walk the chain of enclosing operations. Any ancestor belonging to the
  // OpenMP dialect means `op` is not in the global implicit parallel region.
  // NOTE(review): isa<> on op->getDialect() would assert if an ancestor is an
  // unregistered op (null dialect); confirm this cannot happen here.
  while ((op = op->getParentOp()))
    if (isa<OpenMPDialect>(op->getDialect()))
      return false;
  // No OpenMP ancestor was found.
  return true;
}
1586 
1587 void TeamsOp::build(OpBuilder &builder, OperationState &state,
1588  const TeamsOperands &clauses) {
1589  MLIRContext *ctx = builder.getContext();
1590  // TODO Store clauses in op: privateVars, privateSyms.
1591  TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1592  clauses.ifVar, clauses.numTeamsLower, clauses.numTeamsUpper,
1593  /*private_vars=*/{},
1594  /*private_syms=*/nullptr, clauses.reductionVars,
1595  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
1596  makeArrayAttr(ctx, clauses.reductionSyms),
1597  clauses.threadLimit);
1598 }
1599 
/// Verifies a TeamsOp.
///
/// Checks, in order:
///  - The op is placed directly inside a TargetOp, or outside any OpenMP op.
///  - A num_teams lower bound is accompanied by an upper bound of the same
///    type.
///  - The allocate and allocator lists have the same size.
///  - The reduction clause is well-formed.
LogicalResult TeamsOp::verify() {
  // Check parent region
  // TODO If nested inside of a target region, also check that it does not
  // contain any statements, declarations or directives other than this
  // omp.teams construct. The issue is how to support the initialization of
  // this operation's own arguments (allow SSA values across omp.target?).
  // NOTE(review): isa<> asserts on a null pointer, so this assumes the op
  // always has a parent op; confirm.
  Operation *op = getOperation();
  if (!isa<TargetOp>(op->getParentOp()) &&
    return emitError("expected to be nested inside of omp.target or not nested "
                     "in any OpenMP dialect operations");

  // Check for num_teams clause restrictions
  if (auto numTeamsLowerBound = getNumTeamsLower()) {
    auto numTeamsUpperBound = getNumTeamsUpper();
    if (!numTeamsUpperBound)
      return emitError("expected num_teams upper bound to be defined if the "
                       "lower bound is defined");
    if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
      return emitError(
          "expected num_teams upper bound and lower bound to be the same type");
  }

  // Check for allocate clause restrictions
  if (getAllocateVars().size() != getAllocatorVars().size())
    return emitError(
        "expected equal sizes for allocate and allocator variables");

  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                getReductionByref());
}
1631 
1632 //===----------------------------------------------------------------------===//
1633 // SectionsOp
1634 //===----------------------------------------------------------------------===//
1635 
1636 void SectionsOp::build(OpBuilder &builder, OperationState &state,
1637  const SectionsOperands &clauses) {
1638  MLIRContext *ctx = builder.getContext();
1639  // TODO Store clauses in op: privateVars, privateSyms.
1640  SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1641  clauses.nowait, /*private_vars=*/{},
1642  /*private_syms=*/nullptr, clauses.reductionVars,
1643  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
1644  makeArrayAttr(ctx, clauses.reductionSyms));
1645 }
1646 
1647 LogicalResult SectionsOp::verify() {
1648  if (getAllocateVars().size() != getAllocatorVars().size())
1649  return emitError(
1650  "expected equal sizes for allocate and allocator variables");
1651 
1652  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
1653  getReductionByref());
1654 }
1655 
1656 LogicalResult SectionsOp::verifyRegions() {
1657  for (auto &inst : *getRegion().begin()) {
1658  if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
1659  return emitOpError()
1660  << "expected omp.section op or terminator op inside region";
1661  }
1662  }
1663 
1664  return success();
1665 }
1666 
1667 //===----------------------------------------------------------------------===//
1668 // SingleOp
1669 //===----------------------------------------------------------------------===//
1670 
1671 void SingleOp::build(OpBuilder &builder, OperationState &state,
1672  const SingleOperands &clauses) {
1673  MLIRContext *ctx = builder.getContext();
1674  // TODO Store clauses in op: privateVars, privateSyms.
1675  SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1676  clauses.copyprivateVars,
1677  makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
1678  /*private_vars=*/{}, /*private_syms=*/nullptr);
1679 }
1680 
1681 LogicalResult SingleOp::verify() {
1682  // Check for allocate clause restrictions
1683  if (getAllocateVars().size() != getAllocatorVars().size())
1684  return emitError(
1685  "expected equal sizes for allocate and allocator variables");
1686 
1687  return verifyCopyprivateVarList(*this, getCopyprivateVars(),
1688  getCopyprivateSyms());
1689 }
1690 
1691 //===----------------------------------------------------------------------===//
1692 // WsloopOp
1693 //===----------------------------------------------------------------------===//
1694 
/// Parses the custom assembly of a worksharing loop. This is an optional
/// `reduction` clause, whose entries also declare region arguments, followed
/// by the region itself. The region receives those arguments as its
/// entry-block arguments.
ParseResult
parseWsloop(OpAsmParser &parser, Region &region,
            SmallVectorImpl<Type> &reductionTypes,
            DenseBoolArrayAttr &reductionByRef, ArrayAttr &reductionSymbols) {
  // Parse an optional reduction clause
  if (succeeded(parser.parseOptionalKeyword("reduction"))) {
    if (failed(parseClauseWithRegionArgs(parser, region, reductionOperands,
                                         reductionTypes, reductionByRef,
                                         reductionSymbols, privates)))
      return failure();
  }
  // The reduction region arguments (if any) become the entry-block arguments.
  return parser.parseRegion(region, privates);
}
1710 
                 ValueRange reductionOperands, TypeRange reductionTypes,
                 DenseBoolArrayAttr isByRef, ArrayAttr reductionSymbols) {
  // Inverse of parseWsloop. Prints the reduction clause only when reduction
  // symbols are present, pairing each reduction operand with an entry-block
  // argument. Then prints the region without its entry-block argument list,
  // since those arguments are already shown in the clause.
  if (reductionSymbols) {
    auto reductionArgs = region.front().getArguments();
    printClauseWithRegionArgs(p, op, reductionArgs, "reduction",
                              reductionOperands, reductionTypes, isByRef,
                              reductionSymbols);
  }
  p.printRegion(region, /*printEntryBlockArgs=*/false);
}
1722 
1723 static LogicalResult verifyLoopWrapperInterface(Operation *op) {
1724  if (op->getNumRegions() != 1)
1725  return op->emitOpError() << "loop wrapper contains multiple regions";
1726 
1727  Region &region = op->getRegion(0);
1728  if (!region.hasOneBlock())
1729  return op->emitOpError() << "loop wrapper contains multiple blocks";
1730 
1731  if (::llvm::range_size(region.getOps()) != 2)
1732  return op->emitOpError()
1733  << "loop wrapper does not contain exactly two nested ops";
1734 
1735  Operation &firstOp = *region.op_begin();
1736  Operation &secondOp = *(std::next(region.op_begin()));
1737 
1738  if (!secondOp.hasTrait<OpTrait::IsTerminator>())
1739  return op->emitOpError()
1740  << "second nested op in loop wrapper is not a terminator";
1741 
1742  if (!::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp))
1743  return op->emitOpError() << "first nested op in loop wrapper is not "
1744  "another loop wrapper or `omp.loop_nest`";
1745 
1746  return success();
1747 }
1748 
1749 void WsloopOp::build(OpBuilder &builder, OperationState &state,
1750  ArrayRef<NamedAttribute> attributes) {
1751  build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
1752  /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
1753  /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
1754  /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
1755  /*reduction_vars=*/ValueRange(), /*reduction_byref=*/nullptr,
1756  /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
1757  /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
1758  /*schedule_simd=*/false);
1759  state.addAttributes(attributes);
1760 }
1761 
1762 void WsloopOp::build(OpBuilder &builder, OperationState &state,
1763  const WsloopOperands &clauses) {
1764  MLIRContext *ctx = builder.getContext();
1765  // TODO: Store clauses in op: allocateVars, allocatorVars, privateVars,
1766  // privateSyms.
1767  WsloopOp::build(
1768  builder, state,
1769  /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
1770  clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
1771  clauses.ordered, /*private_vars=*/{}, /*private_syms=*/nullptr,
1772  clauses.reductionVars,
1773  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
1774  makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
1775  clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
1776 }
1777 
1778 LogicalResult WsloopOp::verify() {
1779  if (verifyLoopWrapperInterface(*this).failed())
1780  return failure();
1781 
1782  bool isCompositeChildLeaf =
1783  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
1784 
1785  if (LoopWrapperInterface nested = getNestedWrapper()) {
1786  if (!isComposite())
1787  return emitError()
1788  << "'omp.composite' attribute missing from composite wrapper";
1789 
1790  // Check for the allowed leaf constructs that may appear in a composite
1791  // construct directly after DO/FOR.
1792  if (!isa<SimdOp>(nested))
1793  return emitError() << "only supported nested wrapper is 'omp.simd'";
1794 
1795  } else if (isComposite() && !isCompositeChildLeaf) {
1796  return emitError()
1797  << "'omp.composite' attribute present in non-composite wrapper";
1798  } else if (!isComposite() && isCompositeChildLeaf) {
1799  return emitError()
1800  << "'omp.composite' attribute missing from composite wrapper";
1801  }
1802 
1803  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
1804  getReductionByref());
1805 }
1806 
1807 //===----------------------------------------------------------------------===//
1808 // Simd construct [2.9.3.1]
1809 //===----------------------------------------------------------------------===//
1810 
1811 void SimdOp::build(OpBuilder &builder, OperationState &state,
1812  const SimdOperands &clauses) {
1813  MLIRContext *ctx = builder.getContext();
1814  // TODO Store clauses in op: linearVars, linearStepVars, privateVars,
1815  // privateSyms, reductionVars, reductionByref, reductionSyms.
1816  SimdOp::build(builder, state, clauses.alignedVars,
1817  makeArrayAttr(ctx, clauses.alignments), clauses.ifVar,
1818  /*linear_vars=*/{}, /*linear_step_vars=*/{},
1819  clauses.nontemporalVars, clauses.order, clauses.orderMod,
1820  /*private_vars=*/{}, /*private_syms=*/nullptr,
1821  /*reduction_vars=*/{}, /*reduction_byref=*/nullptr,
1822  /*reduction_syms=*/nullptr, clauses.safelen, clauses.simdlen);
1823 }
1824 
1825 LogicalResult SimdOp::verify() {
1826  if (getSimdlen().has_value() && getSafelen().has_value() &&
1827  getSimdlen().value() > getSafelen().value())
1828  return emitOpError()
1829  << "simdlen clause and safelen clause are both present, but the "
1830  "simdlen value is not less than or equal to safelen value";
1831 
1832  if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
1833  return failure();
1834 
1835  if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
1836  return failure();
1837 
1838  if (verifyLoopWrapperInterface(*this).failed())
1839  return failure();
1840 
1841  if (getNestedWrapper())
1842  return emitOpError() << "must wrap an 'omp.loop_nest' directly";
1843 
1844  bool isCompositeChildLeaf =
1845  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
1846 
1847  if (!isComposite() && isCompositeChildLeaf)
1848  return emitError()
1849  << "'omp.composite' attribute missing from composite wrapper";
1850 
1851  if (isComposite() && !isCompositeChildLeaf)
1852  return emitError()
1853  << "'omp.composite' attribute present in non-composite wrapper";
1854 
1855  return success();
1856 }
1857 
1858 //===----------------------------------------------------------------------===//
1859 // Distribute construct [2.9.4.1]
1860 //===----------------------------------------------------------------------===//
1861 
1862 void DistributeOp::build(OpBuilder &builder, OperationState &state,
1863  const DistributeOperands &clauses) {
1864  // TODO Store clauses in op: privateVars, privateSyms.
1865  DistributeOp::build(
1866  builder, state, clauses.allocateVars, clauses.allocatorVars,
1867  clauses.distScheduleStatic, clauses.distScheduleChunkSize, clauses.order,
1868  clauses.orderMod, /*private_vars=*/{}, /*private_syms=*/nullptr);
1869 }
1870 
1871 LogicalResult DistributeOp::verify() {
1872  if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
1873  return emitOpError() << "chunk size set without "
1874  "dist_schedule_static being present";
1875 
1876  if (getAllocateVars().size() != getAllocatorVars().size())
1877  return emitError(
1878  "expected equal sizes for allocate and allocator variables");
1879 
1880  if (verifyLoopWrapperInterface(*this).failed())
1881  return failure();
1882 
1883  if (LoopWrapperInterface nested = getNestedWrapper()) {
1884  if (!isComposite())
1885  return emitError()
1886  << "'omp.composite' attribute missing from composite wrapper";
1887  // Check for the allowed leaf constructs that may appear in a composite
1888  // construct directly after DISTRIBUTE.
1889  if (isa<WsloopOp>(nested)) {
1890  if (!llvm::dyn_cast_if_present<ParallelOp>((*this)->getParentOp()))
1891  return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
1892  "when 'omp.parallel' is the direct parent";
1893  } else if (!isa<SimdOp>(nested))
1894  return emitError() << "only supported nested wrappers are 'omp.simd' and "
1895  "'omp.wsloop'";
1896  } else if (isComposite()) {
1897  return emitError()
1898  << "'omp.composite' attribute present in non-composite wrapper";
1899  }
1900 
1901  return success();
1902 }
1903 
1904 //===----------------------------------------------------------------------===//
1905 // DeclareReductionOp
1906 //===----------------------------------------------------------------------===//
1907 
1908 LogicalResult DeclareReductionOp::verifyRegions() {
1909  if (!getAllocRegion().empty()) {
1910  for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
1911  if (yieldOp.getResults().size() != 1 ||
1912  yieldOp.getResults().getTypes()[0] != getType())
1913  return emitOpError() << "expects alloc region to yield a value "
1914  "of the reduction type";
1915  }
1916  }
1917 
1918  if (getInitializerRegion().empty())
1919  return emitOpError() << "expects non-empty initializer region";
1920  Block &initializerEntryBlock = getInitializerRegion().front();
1921 
1922  if (initializerEntryBlock.getNumArguments() == 1) {
1923  if (!getAllocRegion().empty())
1924  return emitOpError() << "expects two arguments to the initializer region "
1925  "when an allocation region is used";
1926  } else if (initializerEntryBlock.getNumArguments() == 2) {
1927  if (getAllocRegion().empty())
1928  return emitOpError() << "expects one argument to the initializer region "
1929  "when no allocation region is used";
1930  } else {
1931  return emitOpError()
1932  << "expects one or two arguments to the initializer region";
1933  }
1934 
1935  for (mlir::Value arg : initializerEntryBlock.getArguments())
1936  if (arg.getType() != getType())
1937  return emitOpError() << "expects initializer region argument to match "
1938  "the reduction type";
1939 
1940  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
1941  if (yieldOp.getResults().size() != 1 ||
1942  yieldOp.getResults().getTypes()[0] != getType())
1943  return emitOpError() << "expects initializer region to yield a value "
1944  "of the reduction type";
1945  }
1946 
1947  if (getReductionRegion().empty())
1948  return emitOpError() << "expects non-empty reduction region";
1949  Block &reductionEntryBlock = getReductionRegion().front();
1950  if (reductionEntryBlock.getNumArguments() != 2 ||
1951  reductionEntryBlock.getArgumentTypes()[0] !=
1952  reductionEntryBlock.getArgumentTypes()[1] ||
1953  reductionEntryBlock.getArgumentTypes()[0] != getType())
1954  return emitOpError() << "expects reduction region with two arguments of "
1955  "the reduction type";
1956  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
1957  if (yieldOp.getResults().size() != 1 ||
1958  yieldOp.getResults().getTypes()[0] != getType())
1959  return emitOpError() << "expects reduction region to yield a value "
1960  "of the reduction type";
1961  }
1962 
1963  if (!getAtomicReductionRegion().empty()) {
1964  Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
1965  if (atomicReductionEntryBlock.getNumArguments() != 2 ||
1966  atomicReductionEntryBlock.getArgumentTypes()[0] !=
1967  atomicReductionEntryBlock.getArgumentTypes()[1])
1968  return emitOpError() << "expects atomic reduction region with two "
1969  "arguments of the same type";
1970  auto ptrType = llvm::dyn_cast<PointerLikeType>(
1971  atomicReductionEntryBlock.getArgumentTypes()[0]);
1972  if (!ptrType ||
1973  (ptrType.getElementType() && ptrType.getElementType() != getType()))
1974  return emitOpError() << "expects atomic reduction region arguments to "
1975  "be accumulators containing the reduction type";
1976  }
1977 
1978  if (getCleanupRegion().empty())
1979  return success();
1980  Block &cleanupEntryBlock = getCleanupRegion().front();
1981  if (cleanupEntryBlock.getNumArguments() != 1 ||
1982  cleanupEntryBlock.getArgument(0).getType() != getType())
1983  return emitOpError() << "expects cleanup region with one argument "
1984  "of the reduction type";
1985 
1986  return success();
1987 }
1988 
1989 //===----------------------------------------------------------------------===//
1990 // TaskOp
1991 //===----------------------------------------------------------------------===//
1992 
1993 void TaskOp::build(OpBuilder &builder, OperationState &state,
1994  const TaskOperands &clauses) {
1995  MLIRContext *ctx = builder.getContext();
1996  // TODO Store clauses in op: privateVars, privateSyms.
1997  TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1998  makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
1999  clauses.final, clauses.ifVar, clauses.inReductionVars,
2000  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2001  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2002  clauses.priority, /*private_vars=*/{}, /*private_syms=*/nullptr,
2003  clauses.untied);
2004 }
2005 
2006 LogicalResult TaskOp::verify() {
2007  LogicalResult verifyDependVars =
2008  verifyDependVarList(*this, getDependKinds(), getDependVars());
2009  return failed(verifyDependVars)
2010  ? verifyDependVars
2011  : verifyReductionVarList(*this, getInReductionSyms(),
2012  getInReductionVars(),
2013  getInReductionByref());
2014 }
2015 
2016 //===----------------------------------------------------------------------===//
2017 // TaskgroupOp
2018 //===----------------------------------------------------------------------===//
2019 
2020 void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
2021  const TaskgroupOperands &clauses) {
2022  MLIRContext *ctx = builder.getContext();
2023  TaskgroupOp::build(builder, state, clauses.allocateVars,
2024  clauses.allocatorVars, clauses.taskReductionVars,
2025  makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
2026  makeArrayAttr(ctx, clauses.taskReductionSyms));
2027 }
2028 
2029 LogicalResult TaskgroupOp::verify() {
2030  return verifyReductionVarList(*this, getTaskReductionSyms(),
2031  getTaskReductionVars(),
2032  getTaskReductionByref());
2033 }
2034 
2035 //===----------------------------------------------------------------------===//
2036 // TaskloopOp
2037 //===----------------------------------------------------------------------===//
2038 
2039 void TaskloopOp::build(OpBuilder &builder, OperationState &state,
2040  const TaskloopOperands &clauses) {
2041  MLIRContext *ctx = builder.getContext();
2042  // TODO Store clauses in op: privateVars, privateSyms.
2043  TaskloopOp::build(
2044  builder, state, clauses.allocateVars, clauses.allocatorVars,
2045  clauses.final, clauses.grainsize, clauses.ifVar, clauses.inReductionVars,
2046  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2047  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2048  clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
2049  /*private_syms=*/nullptr, clauses.reductionVars,
2050  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2051  makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
2052 }
2053 
2054 SmallVector<Value> TaskloopOp::getAllReductionVars() {
2055  SmallVector<Value> allReductionNvars(getInReductionVars().begin(),
2056  getInReductionVars().end());
2057  allReductionNvars.insert(allReductionNvars.end(), getReductionVars().begin(),
2058  getReductionVars().end());
2059  return allReductionNvars;
2060 }
2061 
2062 LogicalResult TaskloopOp::verify() {
2063  if (getAllocateVars().size() != getAllocatorVars().size())
2064  return emitError(
2065  "expected equal sizes for allocate and allocator variables");
2066  if (failed(verifyReductionVarList(*this, getReductionSyms(),
2067  getReductionVars(), getReductionByref())) ||
2068  failed(verifyReductionVarList(*this, getInReductionSyms(),
2069  getInReductionVars(),
2070  getInReductionByref())))
2071  return failure();
2072 
2073  if (!getReductionVars().empty() && getNogroup())
2074  return emitError("if a reduction clause is present on the taskloop "
2075  "directive, the nogroup clause must not be specified");
2076  for (auto var : getReductionVars()) {
2077  if (llvm::is_contained(getInReductionVars(), var))
2078  return emitError("the same list item cannot appear in both a reduction "
2079  "and an in_reduction clause");
2080  }
2081 
2082  if (getGrainsize() && getNumTasks()) {
2083  return emitError(
2084  "the grainsize clause and num_tasks clause are mutually exclusive and "
2085  "may not appear on the same taskloop directive");
2086  }
2087 
2088  if (verifyLoopWrapperInterface(*this).failed())
2089  return failure();
2090 
2091  if (LoopWrapperInterface nested = getNestedWrapper()) {
2092  if (!isComposite())
2093  return emitError()
2094  << "'omp.composite' attribute missing from composite wrapper";
2095 
2096  // Check for the allowed leaf constructs that may appear in a composite
2097  // construct directly after TASKLOOP.
2098  if (!isa<SimdOp>(nested))
2099  return emitError() << "only supported nested wrapper is 'omp.simd'";
2100  } else if (isComposite()) {
2101  return emitError()
2102  << "'omp.composite' attribute present in non-composite wrapper";
2103  }
2104 
2105  return success();
2106 }
2107 
2108 //===----------------------------------------------------------------------===//
2109 // LoopNestOp
2110 //===----------------------------------------------------------------------===//
2111 
2112 ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
2113  // Parse an opening `(` followed by induction variables followed by `)`
2116  Type loopVarType;
2118  parser.parseColonType(loopVarType) ||
2119  // Parse loop bounds.
2120  parser.parseEqual() ||
2121  parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
2122  parser.parseKeyword("to") ||
2123  parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
2124  return failure();
2125 
2126  for (auto &iv : ivs)
2127  iv.type = loopVarType;
2128 
2129  // Parse "inclusive" flag.
2130  if (succeeded(parser.parseOptionalKeyword("inclusive")))
2131  result.addAttribute("loop_inclusive",
2132  UnitAttr::get(parser.getBuilder().getContext()));
2133 
2134  // Parse step values.
2136  if (parser.parseKeyword("step") ||
2137  parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
2138  return failure();
2139 
2140  // Parse the body.
2141  Region *region = result.addRegion();
2142  if (parser.parseRegion(*region, ivs))
2143  return failure();
2144 
2145  // Resolve operands.
2146  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
2147  parser.resolveOperands(ubs, loopVarType, result.operands) ||
2148  parser.resolveOperands(steps, loopVarType, result.operands))
2149  return failure();
2150 
2151  // Parse the optional attribute list.
2152  return parser.parseOptionalAttrDict(result.attributes);
2153 }
2154 
2156  Region &region = getRegion();
2157  auto args = region.getArguments();
2158  p << " (" << args << ") : " << args[0].getType() << " = ("
2159  << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
2160  if (getLoopInclusive())
2161  p << "inclusive ";
2162  p << "step (" << getLoopSteps() << ") ";
2163  p.printRegion(region, /*printEntryBlockArgs=*/false);
2164 }
2165 
2166 void LoopNestOp::build(OpBuilder &builder, OperationState &state,
2167  const LoopNestOperands &clauses) {
2168  LoopNestOp::build(builder, state, clauses.loopLowerBounds,
2169  clauses.loopUpperBounds, clauses.loopSteps,
2170  clauses.loopInclusive);
2171 }
2172 
2173 LogicalResult LoopNestOp::verify() {
2174  if (getLoopLowerBounds().empty())
2175  return emitOpError() << "must represent at least one loop";
2176 
2177  if (getLoopLowerBounds().size() != getIVs().size())
2178  return emitOpError() << "number of range arguments and IVs do not match";
2179 
2180  for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
2181  if (lb.getType() != iv.getType())
2182  return emitOpError()
2183  << "range argument type does not match corresponding IV type";
2184  }
2185 
2186  if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
2187  return emitOpError() << "expects parent op to be a loop wrapper";
2188 
2189  return success();
2190 }
2191 
2192 void LoopNestOp::gatherWrappers(
2194  Operation *parent = (*this)->getParentOp();
2195  while (auto wrapper =
2196  llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
2197  wrappers.push_back(wrapper);
2198  parent = parent->getParentOp();
2199  }
2200 }
2201 
2202 //===----------------------------------------------------------------------===//
2203 // Critical construct (2.17.1)
2204 //===----------------------------------------------------------------------===//
2205 
2206 void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
2207  const CriticalDeclareOperands &clauses) {
2208  CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
2209 }
2210 
2211 LogicalResult CriticalDeclareOp::verify() {
2212  return verifySynchronizationHint(*this, getHint());
2213 }
2214 
2215 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2216  if (getNameAttr()) {
2217  SymbolRefAttr symbolRef = getNameAttr();
2218  auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
2219  *this, symbolRef);
2220  if (!decl) {
2221  return emitOpError() << "expected symbol reference " << symbolRef
2222  << " to point to a critical declaration";
2223  }
2224  }
2225 
2226  return success();
2227 }
2228 
2229 //===----------------------------------------------------------------------===//
2230 // Ordered construct
2231 //===----------------------------------------------------------------------===//
2232 
2233 static LogicalResult verifyOrderedParent(Operation &op) {
2234  bool hasRegion = op.getNumRegions() > 0;
2235  auto loopOp = op.getParentOfType<LoopNestOp>();
2236  if (!loopOp) {
2237  if (hasRegion)
2238  return success();
2239 
2240  // TODO: Consider if this needs to be the case only for the standalone
2241  // variant of the ordered construct.
2242  return op.emitOpError() << "must be nested inside of a loop";
2243  }
2244 
2245  Operation *wrapper = loopOp->getParentOp();
2246  if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
2247  IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
2248  if (!orderedAttr)
2249  return op.emitOpError() << "the enclosing worksharing-loop region must "
2250  "have an ordered clause";
2251 
2252  if (hasRegion && orderedAttr.getInt() != 0)
2253  return op.emitOpError() << "the enclosing loop's ordered clause must not "
2254  "have a parameter present";
2255 
2256  if (!hasRegion && orderedAttr.getInt() == 0)
2257  return op.emitOpError() << "the enclosing loop's ordered clause must "
2258  "have a parameter present";
2259  } else if (!isa<SimdOp>(wrapper)) {
2260  return op.emitOpError() << "must be nested inside of a worksharing, simd "
2261  "or worksharing simd loop";
2262  }
2263  return success();
2264 }
2265 
2266 void OrderedOp::build(OpBuilder &builder, OperationState &state,
2267  const OrderedOperands &clauses) {
2268  OrderedOp::build(builder, state, clauses.doacrossDependType,
2269  clauses.doacrossNumLoops, clauses.doacrossDependVars);
2270 }
2271 
2272 LogicalResult OrderedOp::verify() {
2273  if (failed(verifyOrderedParent(**this)))
2274  return failure();
2275 
2276  auto wrapper = (*this)->getParentOfType<WsloopOp>();
2277  if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
2278  return emitOpError() << "number of variables in depend clause does not "
2279  << "match number of iteration variables in the "
2280  << "doacross loop";
2281 
2282  return success();
2283 }
2284 
2285 void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
2286  const OrderedRegionOperands &clauses) {
2287  OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
2288 }
2289 
2290 LogicalResult OrderedRegionOp::verify() {
2291  // TODO: The code generation for ordered simd directive is not supported yet.
2292  if (getParLevelSimd())
2293  return failure();
2294 
2295  return verifyOrderedParent(**this);
2296 }
2297 
2298 //===----------------------------------------------------------------------===//
2299 // TaskwaitOp
2300 //===----------------------------------------------------------------------===//
2301 
2302 void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
2303  const TaskwaitOperands &clauses) {
2304  // TODO Store clauses in op: dependKinds, dependVars, nowait.
2305  TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
2306  /*depend_vars=*/{}, /*nowait=*/nullptr);
2307 }
2308 
2309 //===----------------------------------------------------------------------===//
2310 // Verifier for AtomicReadOp
2311 //===----------------------------------------------------------------------===//
2312 
2313 LogicalResult AtomicReadOp::verify() {
2314  if (verifyCommon().failed())
2315  return mlir::failure();
2316 
2317  if (auto mo = getMemoryOrder()) {
2318  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2319  *mo == ClauseMemoryOrderKind::Release) {
2320  return emitError(
2321  "memory-order must not be acq_rel or release for atomic reads");
2322  }
2323  }
2324  return verifySynchronizationHint(*this, getHint());
2325 }
2326 
2327 //===----------------------------------------------------------------------===//
2328 // Verifier for AtomicWriteOp
2329 //===----------------------------------------------------------------------===//
2330 
2331 LogicalResult AtomicWriteOp::verify() {
2332  if (verifyCommon().failed())
2333  return mlir::failure();
2334 
2335  if (auto mo = getMemoryOrder()) {
2336  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2337  *mo == ClauseMemoryOrderKind::Acquire) {
2338  return emitError(
2339  "memory-order must not be acq_rel or acquire for atomic writes");
2340  }
2341  }
2342  return verifySynchronizationHint(*this, getHint());
2343 }
2344 
2345 //===----------------------------------------------------------------------===//
2346 // Verifier for AtomicUpdateOp
2347 //===----------------------------------------------------------------------===//
2348 
2349 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2350  PatternRewriter &rewriter) {
2351  if (op.isNoOp()) {
2352  rewriter.eraseOp(op);
2353  return success();
2354  }
2355  if (Value writeVal = op.getWriteOpVal()) {
2356  rewriter.replaceOpWithNewOp<AtomicWriteOp>(
2357  op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
2358  return success();
2359  }
2360  return failure();
2361 }
2362 
2363 LogicalResult AtomicUpdateOp::verify() {
2364  if (verifyCommon().failed())
2365  return mlir::failure();
2366 
2367  if (auto mo = getMemoryOrder()) {
2368  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2369  *mo == ClauseMemoryOrderKind::Acquire) {
2370  return emitError(
2371  "memory-order must not be acq_rel or acquire for atomic updates");
2372  }
2373  }
2374 
2375  return verifySynchronizationHint(*this, getHint());
2376 }
2377 
2378 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
2379 
2380 //===----------------------------------------------------------------------===//
2381 // Verifier for AtomicCaptureOp
2382 //===----------------------------------------------------------------------===//
2383 
2384 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2385  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2386  return op;
2387  return dyn_cast<AtomicReadOp>(getSecondOp());
2388 }
2389 
2390 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2391  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2392  return op;
2393  return dyn_cast<AtomicWriteOp>(getSecondOp());
2394 }
2395 
2396 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2397  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2398  return op;
2399  return dyn_cast<AtomicUpdateOp>(getSecondOp());
2400 }
2401 
2402 LogicalResult AtomicCaptureOp::verify() {
2403  return verifySynchronizationHint(*this, getHint());
2404 }
2405 
2406 LogicalResult AtomicCaptureOp::verifyRegions() {
2407  if (verifyRegionsCommon().failed())
2408  return mlir::failure();
2409 
2410  if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
2411  return emitOpError(
2412  "operations inside capture region must not have hint clause");
2413 
2414  if (getFirstOp()->getAttr("memory_order") ||
2415  getSecondOp()->getAttr("memory_order"))
2416  return emitOpError(
2417  "operations inside capture region must not have memory_order clause");
2418  return success();
2419 }
2420 
2421 //===----------------------------------------------------------------------===//
2422 // CancelOp
2423 //===----------------------------------------------------------------------===//
2424 
2425 void CancelOp::build(OpBuilder &builder, OperationState &state,
2426  const CancelOperands &clauses) {
2427  CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifVar);
2428 }
2429 
2430 LogicalResult CancelOp::verify() {
2431  ClauseCancellationConstructType cct = getCancelDirective();
2432  Operation *parentOp = (*this)->getParentOp();
2433 
2434  if (!parentOp) {
2435  return emitOpError() << "must be used within a region supporting "
2436  "cancel directive";
2437  }
2438 
2439  if ((cct == ClauseCancellationConstructType::Parallel) &&
2440  !isa<ParallelOp>(parentOp)) {
2441  return emitOpError() << "cancel parallel must appear "
2442  << "inside a parallel region";
2443  }
2444  if (cct == ClauseCancellationConstructType::Loop) {
2445  auto loopOp = dyn_cast<LoopNestOp>(parentOp);
2446  auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
2447  loopOp ? loopOp->getParentOp() : nullptr);
2448 
2449  if (!wsloopOp) {
2450  return emitOpError()
2451  << "cancel loop must appear inside a worksharing-loop region";
2452  }
2453  if (wsloopOp.getNowaitAttr()) {
2454  return emitError() << "A worksharing construct that is canceled "
2455  << "must not have a nowait clause";
2456  }
2457  if (wsloopOp.getOrderedAttr()) {
2458  return emitError() << "A worksharing construct that is canceled "
2459  << "must not have an ordered clause";
2460  }
2461 
2462  } else if (cct == ClauseCancellationConstructType::Sections) {
2463  if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
2464  return emitOpError() << "cancel sections must appear "
2465  << "inside a sections region";
2466  }
2467  if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
2468  cast<SectionsOp>(parentOp->getParentOp()).getNowaitAttr()) {
2469  return emitError() << "A sections construct that is canceled "
2470  << "must not have a nowait clause";
2471  }
2472  }
2473  // TODO : Add more when we support taskgroup.
2474  return success();
2475 }
2476 
2477 //===----------------------------------------------------------------------===//
2478 // CancellationPointOp
2479 //===----------------------------------------------------------------------===//
2480 
2481 void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
2482  const CancellationPointOperands &clauses) {
2483  CancellationPointOp::build(builder, state, clauses.cancelDirective);
2484 }
2485 
2486 LogicalResult CancellationPointOp::verify() {
2487  ClauseCancellationConstructType cct = getCancelDirective();
2488  Operation *parentOp = (*this)->getParentOp();
2489 
2490  if (!parentOp) {
2491  return emitOpError() << "must be used within a region supporting "
2492  "cancellation point directive";
2493  }
2494 
2495  if ((cct == ClauseCancellationConstructType::Parallel) &&
2496  !(isa<ParallelOp>(parentOp))) {
2497  return emitOpError() << "cancellation point parallel must appear "
2498  << "inside a parallel region";
2499  }
2500  if ((cct == ClauseCancellationConstructType::Loop) &&
2501  (!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->getParentOp()))) {
2502  return emitOpError() << "cancellation point loop must appear "
2503  << "inside a worksharing-loop region";
2504  }
2505  if ((cct == ClauseCancellationConstructType::Sections) &&
2506  !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
2507  return emitOpError() << "cancellation point sections must appear "
2508  << "inside a sections region";
2509  }
2510  // TODO : Add more when we support taskgroup.
2511  return success();
2512 }
2513 
2514 //===----------------------------------------------------------------------===//
2515 // MapBoundsOp
2516 //===----------------------------------------------------------------------===//
2517 
2518 LogicalResult MapBoundsOp::verify() {
2519  auto extent = getExtent();
2520  auto upperbound = getUpperBound();
2521  if (!extent && !upperbound)
2522  return emitError("expected extent or upperbound.");
2523  return success();
2524 }
2525 
2526 void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2527  TypeRange /*result_types*/, StringAttr symName,
2528  TypeAttr type) {
2529  PrivateClauseOp::build(
2530  odsBuilder, odsState, symName, type,
2532  DataSharingClauseType::Private));
2533 }
2534 
2535 LogicalResult PrivateClauseOp::verify() {
2536  Type symType = getType();
2537 
2538  auto verifyTerminator = [&](Operation *terminator,
2539  bool yieldsValue) -> LogicalResult {
2540  if (!terminator->getBlock()->getSuccessors().empty())
2541  return success();
2542 
2543  if (!llvm::isa<YieldOp>(terminator))
2544  return mlir::emitError(terminator->getLoc())
2545  << "expected exit block terminator to be an `omp.yield` op.";
2546 
2547  YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
2548  TypeRange yieldedTypes = yieldOp.getResults().getTypes();
2549 
2550  if (!yieldsValue) {
2551  if (yieldedTypes.empty())
2552  return success();
2553 
2554  return mlir::emitError(terminator->getLoc())
2555  << "Did not expect any values to be yielded.";
2556  }
2557 
2558  if (yieldedTypes.size() == 1 && yieldedTypes.front() == symType)
2559  return success();
2560 
2561  auto error = mlir::emitError(yieldOp.getLoc())
2562  << "Invalid yielded value. Expected type: " << symType
2563  << ", got: ";
2564 
2565  if (yieldedTypes.empty())
2566  error << "None";
2567  else
2568  error << yieldedTypes;
2569 
2570  return error;
2571  };
2572 
2573  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
2574  StringRef regionName,
2575  bool yieldsValue) -> LogicalResult {
2576  assert(!region.empty());
2577 
2578  if (region.getNumArguments() != expectedNumArgs)
2579  return mlir::emitError(region.getLoc())
2580  << "`" << regionName << "`: "
2581  << "expected " << expectedNumArgs
2582  << " region arguments, got: " << region.getNumArguments();
2583 
2584  for (Block &block : region) {
2585  // MLIR will verify the absence of the terminator for us.
2586  if (!block.mightHaveTerminator())
2587  continue;
2588 
2589  if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
2590  return failure();
2591  }
2592 
2593  return success();
2594  };
2595 
2596  if (failed(verifyRegion(getAllocRegion(), /*expectedNumArgs=*/1, "alloc",
2597  /*yieldsValue=*/true)))
2598  return failure();
2599 
2600  DataSharingClauseType dsType = getDataSharingType();
2601 
2602  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
2603  return emitError("`private` clauses require only an `alloc` region.");
2604 
2605  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
2606  return emitError(
2607  "`firstprivate` clauses require both `alloc` and `copy` regions.");
2608 
2609  if (dsType == DataSharingClauseType::FirstPrivate &&
2610  failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
2611  /*yieldsValue=*/true)))
2612  return failure();
2613 
2614  if (!getDeallocRegion().empty() &&
2615  failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
2616  /*yieldsValue=*/false)))
2617  return failure();
2618 
2619  return success();
2620 }
2621 
2622 //===----------------------------------------------------------------------===//
2623 // Spec 5.2: Masked construct (10.5)
2624 //===----------------------------------------------------------------------===//
2625 
2626 void MaskedOp::build(OpBuilder &builder, OperationState &state,
2627  const MaskedOperands &clauses) {
2628  MaskedOp::build(builder, state, clauses.filteredThreadId);
2629 }
2630 
2631 #define GET_ATTRDEF_CLASSES
2632 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
2633 
2634 #define GET_OP_CLASSES
2635 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
2636 
2637 #define GET_TYPEDEF_CLASSES
2638 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:720
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:63
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1191
static MLIRContext * getContext(OpFoldResult val)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static ParseResult parseReductionVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
reduction-entry-list ::= reduction-entry | reduction-entry-list , reduction-entry reduction-entry ::=...
static ParseResult parsePrivateList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static void printMapEntries(OpAsmPrinter &p, Operation *op, OperandRange mapVars, TypeRange mapTypes)
static void printReductionVarList(OpAsmPrinter &p, Operation *op, OperandRange reductionVars, TypeRange reductionTypes, std::optional< DenseBoolArrayAttr > reductionByref, std::optional< ArrayAttr > reductionSyms)
Print Reduction clause.
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static ParseResult parseParallelRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
void printWsloop(OpAsmPrinter &p, Operation *op, Region &region, ValueRange reductionOperands, TypeRange reductionTypes, DenseBoolArrayAttr isByRef, ArrayAttr reductionSymbols)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static LogicalResult verifyLoopWrapperInterface(Operation *op)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, DenseBoolArrayAttr &byref, ArrayAttr &symbols, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs)
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
ParseResult parseWsloop(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionOperands, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByRef, ArrayAttr &reductionSymbols)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, DenseIntElementsAttr membersIdx)
static ParseResult parseMapEntries(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op, ValueRange argsSubrange, StringRef clauseName, ValueRange operands, TypeRange types, DenseBoolArrayAttr byref, ArrayAttr symbols)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printPrivateList(OpAsmPrinter &p, Operation *op, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static LogicalResult verifyPrivateVarList(OpType &op)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseMembersIndex(OpAsmParser &parser, DenseIntElementsAttr &membersIdx)
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Block represents an ordered list of Operations.
Definition: Block.h:31
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
SuccessorRange getSuccessors()
Definition: Block.h:265
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:246
IntegerType getI64Type()
Definition: Builders.cpp:93
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:95
MLIRContext * getContext() const
Definition: Builders.h:55
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:211
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getType() const
Definition: ValueRange.cpp:30
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
BlockArgListType getArguments()
Definition: Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
Block & front()
Definition: Region.h:65
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
Runtime
Potential runtimes for AMD GPU kernels.
Definition: Runtimes.h:15
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.