MLIR  21.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Attributes.h"
24 
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/BitVector.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/STLForwardCompat.h"
29 #include "llvm/ADT/SmallString.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Frontend/OpenMP/OMPConstants.h"
34 #include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
35 #include <cstddef>
36 #include <iterator>
37 #include <optional>
38 #include <variant>
39 
40 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
41 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
42 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
43 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
44 
45 using namespace mlir;
46 using namespace mlir::omp;
47 
48 static ArrayAttr makeArrayAttr(MLIRContext *context,
50  return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
51 }
52 
53 static DenseBoolArrayAttr
55  return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
56 }
57 
58 namespace {
59 struct MemRefPointerLikeModel
60  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
61  MemRefType> {
62  Type getElementType(Type pointer) const {
63  return llvm::cast<MemRefType>(pointer).getElementType();
64  }
65 };
66 
67 struct LLVMPointerPointerLikeModel
68  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
69  LLVM::LLVMPointerType> {
70  Type getElementType(Type pointer) const { return Type(); }
71 };
72 } // namespace
73 
74 void OpenMPDialect::initialize() {
75  addOperations<
76 #define GET_OP_LIST
77 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
78  >();
79  addAttributes<
80 #define GET_ATTRDEF_LIST
81 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
82  >();
83  addTypes<
84 #define GET_TYPEDEF_LIST
85 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
86  >();
87 
88  declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
89 
90  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
91  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
92  *getContext());
93 
94  // Attach the default offload module interface to the module op so that
95  // offload functionality can be accessed through it.
96  mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
97  *getContext());
98 
99  // Attach default declare target interfaces to operations which can be marked
100  // as declare target (Global Operations and Functions/Subroutines in dialects
101  // that Fortran (or other languages that lower to MLIR) translates to).
102  mlir::LLVM::GlobalOp::attachInterface<
104  *getContext());
105  mlir::LLVM::LLVMFuncOp::attachInterface<
107  *getContext());
108  mlir::func::FuncOp::attachInterface<
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // Parser and printer for Allocate Clause
114 //===----------------------------------------------------------------------===//
115 
116 /// Parse an allocate clause with allocators and a list of operands with types.
117 ///
118 /// allocate-operand-list ::= allocate-operand |
119 ///                           allocate-operand `,` allocate-operand-list
120 /// allocate-operand ::= ssa-id-and-type `->` ssa-id-and-type
121 /// ssa-id-and-type ::= ssa-id `:` type
122 static ParseResult parseAllocateAndAllocator(
123  OpAsmParser &parser,
125  SmallVectorImpl<Type> &allocateTypes,
127  SmallVectorImpl<Type> &allocatorTypes) {
128 
129  return parser.parseCommaSeparatedList([&]() {
131  Type type;
132  if (parser.parseOperand(operand) || parser.parseColonType(type))
133  return failure();
134  allocatorVars.push_back(operand);
135  allocatorTypes.push_back(type);
136  if (parser.parseArrow())
137  return failure();
138  if (parser.parseOperand(operand) || parser.parseColonType(type))
139  return failure();
140 
141  allocateVars.push_back(operand);
142  allocateTypes.push_back(type);
143  return success();
144  });
145 }
146 
147 /// Print allocate clause
149  OperandRange allocateVars,
150  TypeRange allocateTypes,
151  OperandRange allocatorVars,
152  TypeRange allocatorTypes) {
153  for (unsigned i = 0; i < allocateVars.size(); ++i) {
154  std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
155  p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
156  p << allocateVars[i] << " : " << allocateTypes[i] << separator;
157  }
158 }
159 
160 //===----------------------------------------------------------------------===//
161 // Parser and printer for a clause attribute (StringEnumAttr)
162 //===----------------------------------------------------------------------===//
163 
164 template <typename ClauseAttr>
165 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
166  using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
167  StringRef enumStr;
168  SMLoc loc = parser.getCurrentLocation();
169  if (parser.parseKeyword(&enumStr))
170  return failure();
171  if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
172  attr = ClauseAttr::get(parser.getContext(), *enumValue);
173  return success();
174  }
175  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
176 }
177 
178 template <typename ClauseAttr>
179 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
180  p << stringifyEnum(attr.getValue());
181 }
182 
183 //===----------------------------------------------------------------------===//
184 // Parser and printer for Linear Clause
185 //===----------------------------------------------------------------------===//
186 
187 /// linear ::= `linear` `(` linear-list `)`
188 /// linear-list := linear-val | linear-val `,` linear-list
189 /// linear-val := ssa-id `=` ssa-id `:` type
190 static ParseResult parseLinearClause(
191  OpAsmParser &parser,
193  SmallVectorImpl<Type> &linearTypes,
195  return parser.parseCommaSeparatedList([&]() {
197  Type type;
199  if (parser.parseOperand(var) || parser.parseEqual() ||
200  parser.parseOperand(stepVar) || parser.parseColonType(type))
201  return failure();
202 
203  linearVars.push_back(var);
204  linearTypes.push_back(type);
205  linearStepVars.push_back(stepVar);
206  return success();
207  });
208 }
209 
210 /// Print Linear Clause
212  ValueRange linearVars, TypeRange linearTypes,
213  ValueRange linearStepVars) {
214  size_t linearVarsSize = linearVars.size();
215  for (unsigned i = 0; i < linearVarsSize; ++i) {
216  std::string separator = i == linearVarsSize - 1 ? "" : ", ";
217  p << linearVars[i];
218  if (linearStepVars.size() > i)
219  p << " = " << linearStepVars[i];
220  p << " : " << linearVars[i].getType() << separator;
221  }
222 }
223 
224 //===----------------------------------------------------------------------===//
225 // Verifier for Nontemporal Clause
226 //===----------------------------------------------------------------------===//
227 
228 static LogicalResult verifyNontemporalClause(Operation *op,
229  OperandRange nontemporalVars) {
230 
231  // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
232  DenseSet<Value> nontemporalItems;
233  for (const auto &it : nontemporalVars)
234  if (!nontemporalItems.insert(it).second)
235  return op->emitOpError() << "nontemporal variable used more than once";
236 
237  return success();
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // Parser, verifier and printer for Aligned Clause
242 //===----------------------------------------------------------------------===//
243 static LogicalResult verifyAlignedClause(Operation *op,
244  std::optional<ArrayAttr> alignments,
245  OperandRange alignedVars) {
246  // Check if number of alignment values equals to number of aligned variables
247  if (!alignedVars.empty()) {
248  if (!alignments || alignments->size() != alignedVars.size())
249  return op->emitOpError()
250  << "expected as many alignment values as aligned variables";
251  } else {
252  if (alignments)
253  return op->emitOpError() << "unexpected alignment values attribute";
254  return success();
255  }
256 
257  // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
258  DenseSet<Value> alignedItems;
259  for (auto it : alignedVars)
260  if (!alignedItems.insert(it).second)
261  return op->emitOpError() << "aligned variable used more than once";
262 
263  if (!alignments)
264  return success();
265 
266  // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
267  for (unsigned i = 0; i < (*alignments).size(); ++i) {
268  if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
269  if (intAttr.getValue().sle(0))
270  return op->emitOpError() << "alignment should be greater than 0";
271  } else {
272  return op->emitOpError() << "expected integer alignment";
273  }
274  }
275 
276  return success();
277 }
278 
279 /// aligned ::= `aligned` `(` aligned-list `)`
280 /// aligned-list := aligned-val | aligned-val `,` aligned-list
281 /// aligned-val := ssa-id-and-type `->` alignment
282 static ParseResult
285  SmallVectorImpl<Type> &alignedTypes,
286  ArrayAttr &alignmentsAttr) {
287  SmallVector<Attribute> alignmentVec;
288  if (failed(parser.parseCommaSeparatedList([&]() {
289  if (parser.parseOperand(alignedVars.emplace_back()) ||
290  parser.parseColonType(alignedTypes.emplace_back()) ||
291  parser.parseArrow() ||
292  parser.parseAttribute(alignmentVec.emplace_back())) {
293  return failure();
294  }
295  return success();
296  })))
297  return failure();
298  SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
299  alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
300  return success();
301 }
302 
303 /// Print Aligned Clause
305  ValueRange alignedVars, TypeRange alignedTypes,
306  std::optional<ArrayAttr> alignments) {
307  for (unsigned i = 0; i < alignedVars.size(); ++i) {
308  if (i != 0)
309  p << ", ";
310  p << alignedVars[i] << " : " << alignedVars[i].getType();
311  p << " -> " << (*alignments)[i];
312  }
313 }
314 
315 //===----------------------------------------------------------------------===//
316 // Parser, printer and verifier for Schedule Clause
317 //===----------------------------------------------------------------------===//
318 
319 static ParseResult
321  SmallVectorImpl<SmallString<12>> &modifiers) {
322  if (modifiers.size() > 2)
323  return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
324  for (const auto &mod : modifiers) {
325  // Translate the string. If it has no value, then it was not a valid
326  // modifier!
327  auto symbol = symbolizeScheduleModifier(mod);
328  if (!symbol)
329  return parser.emitError(parser.getNameLoc())
330  << " unknown modifier type: " << mod;
331  }
332 
333  // If we have one modifier that is "simd", then stick a "none" modifier in
334  // index 0.
335  if (modifiers.size() == 1) {
336  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
337  modifiers.push_back(modifiers[0]);
338  modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
339  }
340  } else if (modifiers.size() == 2) {
341  // If there are two modifiers:
342  // the first modifier must not be simd and the second one must be simd.
343  if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
344  symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
345  return parser.emitError(parser.getNameLoc())
346  << " incorrect modifier order";
347  }
348  return success();
349 }
350 
351 /// schedule ::= `schedule` `(` sched-list `)`
352 /// sched-list ::= sched-val | sched-val sched-list |
353 /// sched-val `,` sched-modifier
354 /// sched-val ::= sched-with-chunk | sched-wo-chunk
355 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
356 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
357 /// sched-wo-chunk ::= `auto` | `runtime`
358 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
359 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
360 static ParseResult
361 parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
362  ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
363  std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
364  Type &chunkType) {
365  StringRef keyword;
366  if (parser.parseKeyword(&keyword))
367  return failure();
368  std::optional<mlir::omp::ClauseScheduleKind> schedule =
369  symbolizeClauseScheduleKind(keyword);
370  if (!schedule)
371  return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
372 
373  scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
374  switch (*schedule) {
375  case ClauseScheduleKind::Static:
376  case ClauseScheduleKind::Dynamic:
377  case ClauseScheduleKind::Guided:
378  if (succeeded(parser.parseOptionalEqual())) {
379  chunkSize = OpAsmParser::UnresolvedOperand{};
380  if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
381  return failure();
382  } else {
383  chunkSize = std::nullopt;
384  }
385  break;
386  case ClauseScheduleKind::Auto:
388  chunkSize = std::nullopt;
389  }
390 
391  // If there is a comma, we have one or more modifiers..
392  SmallVector<SmallString<12>> modifiers;
393  while (succeeded(parser.parseOptionalComma())) {
394  StringRef mod;
395  if (parser.parseKeyword(&mod))
396  return failure();
397  modifiers.push_back(mod);
398  }
399 
400  if (verifyScheduleModifiers(parser, modifiers))
401  return failure();
402 
403  if (!modifiers.empty()) {
404  SMLoc loc = parser.getCurrentLocation();
405  if (std::optional<ScheduleModifier> mod =
406  symbolizeScheduleModifier(modifiers[0])) {
407  scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
408  } else {
409  return parser.emitError(loc, "invalid schedule modifier");
410  }
411  // Only SIMD attribute is allowed here!
412  if (modifiers.size() > 1) {
413  assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
414  scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
415  }
416  }
417 
418  return success();
419 }
420 
421 /// Print schedule clause
423  ClauseScheduleKindAttr scheduleKind,
424  ScheduleModifierAttr scheduleMod,
425  UnitAttr scheduleSimd, Value scheduleChunk,
426  Type scheduleChunkType) {
427  p << stringifyClauseScheduleKind(scheduleKind.getValue());
428  if (scheduleChunk)
429  p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
430  if (scheduleMod)
431  p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
432  if (scheduleSimd)
433  p << ", simd";
434 }
435 
436 //===----------------------------------------------------------------------===//
437 // Parser and printer for Order Clause
438 //===----------------------------------------------------------------------===//
439 
440 // order ::= `order` `(` [order-modifier ':'] concurrent `)`
441 // order-modifier ::= reproducible | unconstrained
442 static ParseResult parseOrderClause(OpAsmParser &parser,
443  ClauseOrderKindAttr &order,
444  OrderModifierAttr &orderMod) {
445  StringRef enumStr;
446  SMLoc loc = parser.getCurrentLocation();
447  if (parser.parseKeyword(&enumStr))
448  return failure();
449  if (std::optional<OrderModifier> enumValue =
450  symbolizeOrderModifier(enumStr)) {
451  orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
452  if (parser.parseOptionalColon())
453  return failure();
454  loc = parser.getCurrentLocation();
455  if (parser.parseKeyword(&enumStr))
456  return failure();
457  }
458  if (std::optional<ClauseOrderKind> enumValue =
459  symbolizeClauseOrderKind(enumStr)) {
460  order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
461  return success();
462  }
463  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
464 }
465 
467  ClauseOrderKindAttr order,
468  OrderModifierAttr orderMod) {
469  if (orderMod)
470  p << stringifyOrderModifier(orderMod.getValue()) << ":";
471  if (order)
472  p << stringifyClauseOrderKind(order.getValue());
473 }
474 
475 template <typename ClauseTypeAttr, typename ClauseType>
476 static ParseResult
477 parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
478  std::optional<OpAsmParser::UnresolvedOperand> &operand,
479  Type &operandType,
480  std::optional<ClauseType> (*symbolizeClause)(StringRef),
481  StringRef clauseName) {
482  StringRef enumStr;
483  if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
484  if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
485  prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
486  if (parser.parseComma())
487  return failure();
488  } else {
489  return parser.emitError(parser.getCurrentLocation())
490  << "invalid " << clauseName << " modifier : '" << enumStr << "'";
491  ;
492  }
493  }
494 
496  if (succeeded(parser.parseOperand(var))) {
497  operand = var;
498  } else {
499  return parser.emitError(parser.getCurrentLocation())
500  << "expected " << clauseName << " operand";
501  }
502 
503  if (operand.has_value()) {
504  if (parser.parseColonType(operandType))
505  return failure();
506  }
507 
508  return success();
509 }
510 
511 template <typename ClauseTypeAttr, typename ClauseType>
512 static void
514  ClauseTypeAttr prescriptiveness, Value operand,
515  mlir::Type operandType,
516  StringRef (*stringifyClauseType)(ClauseType)) {
517 
518  if (prescriptiveness)
519  p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
520 
521  if (operand)
522  p << operand << ": " << operandType;
523 }
524 
525 //===----------------------------------------------------------------------===//
526 // Parser and printer for grainsize Clause
527 //===----------------------------------------------------------------------===//
528 
529 // grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
530 static ParseResult
531 parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
532  std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
533  Type &grainsizeType) {
534  return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
535  parser, grainsizeMod, grainsize, grainsizeType,
536  &symbolizeClauseGrainsizeType, "grainsize");
537 }
538 
540  ClauseGrainsizeTypeAttr grainsizeMod,
541  Value grainsize, mlir::Type grainsizeType) {
542  printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
543  p, op, grainsizeMod, grainsize, grainsizeType,
544  &stringifyClauseGrainsizeType);
545 }
546 
547 //===----------------------------------------------------------------------===//
548 // Parser and printer for num_tasks Clause
549 //===----------------------------------------------------------------------===//
550 
551 // numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
552 static ParseResult
553 parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
554  std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
555  Type &numTasksType) {
556  return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
557  parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
558  "num_tasks");
559 }
560 
562  ClauseNumTasksTypeAttr numTasksMod,
563  Value numTasks, mlir::Type numTasksType) {
564  printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
565  p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
566 }
567 
568 //===----------------------------------------------------------------------===//
569 // Parsers for operations including clauses that define entry block arguments.
570 //===----------------------------------------------------------------------===//
571 
572 namespace {
573 struct MapParseArgs {
575  SmallVectorImpl<Type> &types;
577  SmallVectorImpl<Type> &types)
578  : vars(vars), types(types) {}
579 };
580 struct PrivateParseArgs {
583  ArrayAttr &syms;
584  DenseI64ArrayAttr *mapIndices;
586  SmallVectorImpl<Type> &types, ArrayAttr &syms,
587  DenseI64ArrayAttr *mapIndices = nullptr)
588  : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
589 };
590 
591 struct ReductionParseArgs {
593  SmallVectorImpl<Type> &types;
594  DenseBoolArrayAttr &byref;
595  ArrayAttr &syms;
596  ReductionModifierAttr *modifier;
597  ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
599  ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
600  : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
601 };
602 
603 struct AllRegionParseArgs {
604  std::optional<MapParseArgs> hasDeviceAddrArgs;
605  std::optional<MapParseArgs> hostEvalArgs;
606  std::optional<ReductionParseArgs> inReductionArgs;
607  std::optional<MapParseArgs> mapArgs;
608  std::optional<PrivateParseArgs> privateArgs;
609  std::optional<ReductionParseArgs> reductionArgs;
610  std::optional<ReductionParseArgs> taskReductionArgs;
611  std::optional<MapParseArgs> useDeviceAddrArgs;
612  std::optional<MapParseArgs> useDevicePtrArgs;
613 };
614 } // namespace
615 
616 static ParseResult parseClauseWithRegionArgs(
617  OpAsmParser &parser,
619  SmallVectorImpl<Type> &types,
620  SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
621  ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
622  DenseBoolArrayAttr *byref = nullptr,
623  ReductionModifierAttr *modifier = nullptr) {
624  SmallVector<SymbolRefAttr> symbolVec;
625  SmallVector<int64_t> mapIndicesVec;
626  SmallVector<bool> isByRefVec;
627  unsigned regionArgOffset = regionPrivateArgs.size();
628 
629  if (parser.parseLParen())
630  return failure();
631 
632  if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
633  StringRef enumStr;
634  if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
635  parser.parseComma())
636  return failure();
637  std::optional<ReductionModifier> enumValue =
638  symbolizeReductionModifier(enumStr);
639  if (!enumValue.has_value())
640  return failure();
641  *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
642  if (!*modifier)
643  return failure();
644  }
645 
646  if (parser.parseCommaSeparatedList([&]() {
647  if (byref)
648  isByRefVec.push_back(
649  parser.parseOptionalKeyword("byref").succeeded());
650 
651  if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
652  return failure();
653 
654  if (parser.parseOperand(operands.emplace_back()) ||
655  parser.parseArrow() ||
656  parser.parseArgument(regionPrivateArgs.emplace_back()))
657  return failure();
658 
659  if (mapIndices) {
660  if (parser.parseOptionalLSquare().succeeded()) {
661  if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
662  parser.parseInteger(mapIndicesVec.emplace_back()) ||
663  parser.parseRSquare())
664  return failure();
665  } else {
666  mapIndicesVec.push_back(-1);
667  }
668  }
669 
670  return success();
671  }))
672  return failure();
673 
674  if (parser.parseColon())
675  return failure();
676 
677  if (parser.parseCommaSeparatedList([&]() {
678  if (parser.parseType(types.emplace_back()))
679  return failure();
680 
681  return success();
682  }))
683  return failure();
684 
685  if (operands.size() != types.size())
686  return failure();
687 
688  if (parser.parseRParen())
689  return failure();
690 
691  auto *argsBegin = regionPrivateArgs.begin();
692  MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
693  argsBegin + regionArgOffset + types.size());
694  for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
695  prv.type = type;
696  }
697 
698  if (symbols) {
699  SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
700  *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
701  }
702 
703  if (!mapIndicesVec.empty())
704  *mapIndices =
705  mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
706 
707  if (byref)
708  *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
709 
710  return success();
711 }
712 
713 static ParseResult parseBlockArgClause(
714  OpAsmParser &parser,
716  StringRef keyword, std::optional<MapParseArgs> mapArgs) {
717  if (succeeded(parser.parseOptionalKeyword(keyword))) {
718  if (!mapArgs)
719  return failure();
720 
721  if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
722  entryBlockArgs)))
723  return failure();
724  }
725  return success();
726 }
727 
728 static ParseResult parseBlockArgClause(
729  OpAsmParser &parser,
731  StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
732  if (succeeded(parser.parseOptionalKeyword(keyword))) {
733  if (!privateArgs)
734  return failure();
735 
736  if (failed(parseClauseWithRegionArgs(
737  parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
738  &privateArgs->syms, privateArgs->mapIndices)))
739  return failure();
740  }
741  return success();
742 }
743 
744 static ParseResult parseBlockArgClause(
745  OpAsmParser &parser,
747  StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
748  if (succeeded(parser.parseOptionalKeyword(keyword))) {
749  if (!reductionArgs)
750  return failure();
751  if (failed(parseClauseWithRegionArgs(
752  parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
753  &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
754  reductionArgs->modifier)))
755  return failure();
756  }
757  return success();
758 }
759 
760 static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
761  AllRegionParseArgs args) {
763 
764  if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",
765  args.hasDeviceAddrArgs)))
766  return parser.emitError(parser.getCurrentLocation())
767  << "invalid `has_device_addr` format";
768 
769  if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
770  args.hostEvalArgs)))
771  return parser.emitError(parser.getCurrentLocation())
772  << "invalid `host_eval` format";
773 
774  if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
775  args.inReductionArgs)))
776  return parser.emitError(parser.getCurrentLocation())
777  << "invalid `in_reduction` format";
778 
779  if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
780  args.mapArgs)))
781  return parser.emitError(parser.getCurrentLocation())
782  << "invalid `map_entries` format";
783 
784  if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
785  args.privateArgs)))
786  return parser.emitError(parser.getCurrentLocation())
787  << "invalid `private` format";
788 
789  if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
790  args.reductionArgs)))
791  return parser.emitError(parser.getCurrentLocation())
792  << "invalid `reduction` format";
793 
794  if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
795  args.taskReductionArgs)))
796  return parser.emitError(parser.getCurrentLocation())
797  << "invalid `task_reduction` format";
798 
799  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
800  args.useDeviceAddrArgs)))
801  return parser.emitError(parser.getCurrentLocation())
802  << "invalid `use_device_addr` format";
803 
804  if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
805  args.useDevicePtrArgs)))
806  return parser.emitError(parser.getCurrentLocation())
807  << "invalid `use_device_addr` format";
808 
809  return parser.parseRegion(region, entryBlockArgs);
810 }
811 
812 // These parseXyz functions correspond to the custom<Xyz> definitions
813 // in the .td file(s).
814 static ParseResult parseTargetOpRegion(
815  OpAsmParser &parser, Region &region,
817  SmallVectorImpl<Type> &hasDeviceAddrTypes,
819  SmallVectorImpl<Type> &hostEvalTypes,
821  SmallVectorImpl<Type> &inReductionTypes,
822  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
824  SmallVectorImpl<Type> &mapTypes,
826  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
827  DenseI64ArrayAttr &privateMaps) {
828  AllRegionParseArgs args;
829  args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
830  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
831  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
832  inReductionByref, inReductionSyms);
833  args.mapArgs.emplace(mapVars, mapTypes);
834  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
835  &privateMaps);
836  return parseBlockArgRegion(parser, region, args);
837 }
838 
839 static ParseResult parseInReductionPrivateRegion(
840  OpAsmParser &parser, Region &region,
842  SmallVectorImpl<Type> &inReductionTypes,
843  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
845  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
846  AllRegionParseArgs args;
847  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
848  inReductionByref, inReductionSyms);
849  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
850  return parseBlockArgRegion(parser, region, args);
851 }
852 
854  OpAsmParser &parser, Region &region,
856  SmallVectorImpl<Type> &inReductionTypes,
857  DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
859  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
860  ReductionModifierAttr &reductionMod,
862  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
863  ArrayAttr &reductionSyms) {
864  AllRegionParseArgs args;
865  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
866  inReductionByref, inReductionSyms);
867  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
868  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
869  reductionSyms, &reductionMod);
870  return parseBlockArgRegion(parser, region, args);
871 }
872 
873 static ParseResult parsePrivateRegion(
874  OpAsmParser &parser, Region &region,
876  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
877  AllRegionParseArgs args;
878  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
879  return parseBlockArgRegion(parser, region, args);
880 }
881 
882 static ParseResult parsePrivateReductionRegion(
883  OpAsmParser &parser, Region &region,
885  llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
886  ReductionModifierAttr &reductionMod,
888  SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
889  ArrayAttr &reductionSyms) {
890  AllRegionParseArgs args;
891  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
892  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
893  reductionSyms, &reductionMod);
894  return parseBlockArgRegion(parser, region, args);
895 }
896 
897 static ParseResult parseTaskReductionRegion(
898  OpAsmParser &parser, Region &region,
900  SmallVectorImpl<Type> &taskReductionTypes,
901  DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
902  AllRegionParseArgs args;
903  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
904  taskReductionByref, taskReductionSyms);
905  return parseBlockArgRegion(parser, region, args);
906 }
907 
909  OpAsmParser &parser, Region &region,
911  SmallVectorImpl<Type> &useDeviceAddrTypes,
913  SmallVectorImpl<Type> &useDevicePtrTypes) {
914  AllRegionParseArgs args;
915  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
916  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
917  return parseBlockArgRegion(parser, region, args);
918 }
919 
920 //===----------------------------------------------------------------------===//
921 // Printers for operations including clauses that define entry block arguments.
922 //===----------------------------------------------------------------------===//
923 
924 namespace {
925 struct MapPrintArgs {
926  ValueRange vars;
927  TypeRange types;
928  MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
929 };
930 struct PrivatePrintArgs {
931  ValueRange vars;
932  TypeRange types;
933  ArrayAttr syms;
934  DenseI64ArrayAttr mapIndices;
935  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
936  DenseI64ArrayAttr mapIndices)
937  : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
938 };
939 struct ReductionPrintArgs {
940  ValueRange vars;
941  TypeRange types;
942  DenseBoolArrayAttr byref;
943  ArrayAttr syms;
944  ReductionModifierAttr modifier;
945  ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
946  ArrayAttr syms, ReductionModifierAttr mod = nullptr)
947  : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
948 };
949 struct AllRegionPrintArgs {
950  std::optional<MapPrintArgs> hasDeviceAddrArgs;
951  std::optional<MapPrintArgs> hostEvalArgs;
952  std::optional<ReductionPrintArgs> inReductionArgs;
953  std::optional<MapPrintArgs> mapArgs;
954  std::optional<PrivatePrintArgs> privateArgs;
955  std::optional<ReductionPrintArgs> reductionArgs;
956  std::optional<ReductionPrintArgs> taskReductionArgs;
957  std::optional<MapPrintArgs> useDeviceAddrArgs;
958  std::optional<MapPrintArgs> useDevicePtrArgs;
959 };
960 } // namespace
961 
963  OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
964  ValueRange argsSubrange, ValueRange operands, TypeRange types,
965  ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
966  DenseBoolArrayAttr byref = nullptr,
967  ReductionModifierAttr modifier = nullptr) {
968  if (argsSubrange.empty())
969  return;
970 
971  p << clauseName << "(";
972 
973  if (modifier)
974  p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
975 
976  if (!symbols) {
977  llvm::SmallVector<Attribute> values(operands.size(), nullptr);
978  symbols = ArrayAttr::get(ctx, values);
979  }
980 
981  if (!mapIndices) {
982  llvm::SmallVector<int64_t> values(operands.size(), -1);
983  mapIndices = DenseI64ArrayAttr::get(ctx, values);
984  }
985 
986  if (!byref) {
987  mlir::SmallVector<bool> values(operands.size(), false);
988  byref = DenseBoolArrayAttr::get(ctx, values);
989  }
990 
991  llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
992  mapIndices.asArrayRef(),
993  byref.asArrayRef()),
994  p, [&p](auto t) {
995  auto [op, arg, sym, map, isByRef] = t;
996  if (isByRef)
997  p << "byref ";
998  if (sym)
999  p << sym << " ";
1000 
1001  p << op << " -> " << arg;
1002 
1003  if (map != -1)
1004  p << " [map_idx=" << map << "]";
1005  });
1006  p << " : ";
1007  llvm::interleaveComma(types, p);
1008  p << ") ";
1009 }
1010 
1012  StringRef clauseName, ValueRange argsSubrange,
1013  std::optional<MapPrintArgs> mapArgs) {
1014  if (mapArgs)
1015  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
1016  mapArgs->types);
1017 }
1018 
1020  StringRef clauseName, ValueRange argsSubrange,
1021  std::optional<PrivatePrintArgs> privateArgs) {
1022  if (privateArgs)
1023  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1024  privateArgs->vars, privateArgs->types,
1025  privateArgs->syms, privateArgs->mapIndices);
1026 }
1027 
1028 static void
1029 printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1030  ValueRange argsSubrange,
1031  std::optional<ReductionPrintArgs> reductionArgs) {
1032  if (reductionArgs)
1033  printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1034  reductionArgs->vars, reductionArgs->types,
1035  reductionArgs->syms, /*mapIndices=*/nullptr,
1036  reductionArgs->byref, reductionArgs->modifier);
1037 }
1038 
1039 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
1040  const AllRegionPrintArgs &args) {
1041  auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1042  MLIRContext *ctx = op->getContext();
1043 
1044  printBlockArgClause(p, ctx, "has_device_addr",
1045  iface.getHasDeviceAddrBlockArgs(),
1046  args.hasDeviceAddrArgs);
1047  printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
1048  args.hostEvalArgs);
1049  printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
1050  args.inReductionArgs);
1051  printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
1052  args.mapArgs);
1053  printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
1054  args.privateArgs);
1055  printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
1056  args.reductionArgs);
1057  printBlockArgClause(p, ctx, "task_reduction",
1058  iface.getTaskReductionBlockArgs(),
1059  args.taskReductionArgs);
1060  printBlockArgClause(p, ctx, "use_device_addr",
1061  iface.getUseDeviceAddrBlockArgs(),
1062  args.useDeviceAddrArgs);
1063  printBlockArgClause(p, ctx, "use_device_ptr",
1064  iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1065 
1066  p.printRegion(region, /*printEntryBlockArgs=*/false);
1067 }
1068 
1069 // These parseXyz functions correspond to the custom<Xyz> definitions
1070 // in the .td file(s).
1071 static void
1073  ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
1074  ValueRange hostEvalVars, TypeRange hostEvalTypes,
1075  ValueRange inReductionVars, TypeRange inReductionTypes,
1076  DenseBoolArrayAttr inReductionByref,
1077  ArrayAttr inReductionSyms, ValueRange mapVars,
1078  TypeRange mapTypes, ValueRange privateVars,
1079  TypeRange privateTypes, ArrayAttr privateSyms,
1080  DenseI64ArrayAttr privateMaps) {
1081  AllRegionPrintArgs args;
1082  args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1083  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1084  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1085  inReductionByref, inReductionSyms);
1086  args.mapArgs.emplace(mapVars, mapTypes);
1087  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
1088  printBlockArgRegion(p, op, region, args);
1089 }
1090 
1092  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1093  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1094  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1095  ArrayAttr privateSyms) {
1096  AllRegionPrintArgs args;
1097  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1098  inReductionByref, inReductionSyms);
1099  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1100  /*mapIndices=*/nullptr);
1101  printBlockArgRegion(p, op, region, args);
1102 }
1103 
1105  OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1106  TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1107  ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1108  ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
1109  ValueRange reductionVars, TypeRange reductionTypes,
1110  DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
1111  AllRegionPrintArgs args;
1112  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1113  inReductionByref, inReductionSyms);
1114  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1115  /*mapIndices=*/nullptr);
1116  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1117  reductionSyms, reductionMod);
1118  printBlockArgRegion(p, op, region, args);
1119 }
1120 
1121 static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
1122  ValueRange privateVars, TypeRange privateTypes,
1123  ArrayAttr privateSyms) {
1124  AllRegionPrintArgs args;
1125  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1126  /*mapIndices=*/nullptr);
1127  printBlockArgRegion(p, op, region, args);
1128 }
1129 
1131  OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1132  TypeRange privateTypes, ArrayAttr privateSyms,
1133  ReductionModifierAttr reductionMod, ValueRange reductionVars,
1134  TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1135  ArrayAttr reductionSyms) {
1136  AllRegionPrintArgs args;
1137  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1138  /*mapIndices=*/nullptr);
1139  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1140  reductionSyms, reductionMod);
1141  printBlockArgRegion(p, op, region, args);
1142 }
1143 
1145  Region &region,
1146  ValueRange taskReductionVars,
1147  TypeRange taskReductionTypes,
1148  DenseBoolArrayAttr taskReductionByref,
1149  ArrayAttr taskReductionSyms) {
1150  AllRegionPrintArgs args;
1151  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1152  taskReductionByref, taskReductionSyms);
1153  printBlockArgRegion(p, op, region, args);
1154 }
1155 
1157  Region &region,
1158  ValueRange useDeviceAddrVars,
1159  TypeRange useDeviceAddrTypes,
1160  ValueRange useDevicePtrVars,
1161  TypeRange useDevicePtrTypes) {
1162  AllRegionPrintArgs args;
1163  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1164  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1165  printBlockArgRegion(p, op, region, args);
1166 }
1167 
1168 /// Verifies Reduction Clause
1169 static LogicalResult
1170 verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1171  OperandRange reductionVars,
1172  std::optional<ArrayRef<bool>> reductionByref) {
1173  if (!reductionVars.empty()) {
1174  if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1175  return op->emitOpError()
1176  << "expected as many reduction symbol references "
1177  "as reduction variables";
1178  if (reductionByref && reductionByref->size() != reductionVars.size())
1179  return op->emitError() << "expected as many reduction variable by "
1180  "reference attributes as reduction variables";
1181  } else {
1182  if (reductionSyms)
1183  return op->emitOpError() << "unexpected reduction symbol references";
1184  return success();
1185  }
1186 
1187  // TODO: The followings should be done in
1188  // SymbolUserOpInterface::verifySymbolUses.
1189  DenseSet<Value> accumulators;
1190  for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1191  Value accum = std::get<0>(args);
1192 
1193  if (!accumulators.insert(accum).second)
1194  return op->emitOpError() << "accumulator variable used more than once";
1195 
1196  Type varType = accum.getType();
1197  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1198  auto decl =
1199  SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1200  if (!decl)
1201  return op->emitOpError() << "expected symbol reference " << symbolRef
1202  << " to point to a reduction declaration";
1203 
1204  if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1205  return op->emitOpError()
1206  << "expected accumulator (" << varType
1207  << ") to be the same type as reduction declaration ("
1208  << decl.getAccumulatorType() << ")";
1209  }
1210 
1211  return success();
1212 }
1213 
1214 //===----------------------------------------------------------------------===//
1215 // Parser, printer and verifier for Copyprivate
1216 //===----------------------------------------------------------------------===//
1217 
1218 /// copyprivate-entry-list ::= copyprivate-entry
1219 /// | copyprivate-entry-list `,` copyprivate-entry
1220 /// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1221 static ParseResult parseCopyprivate(
1222  OpAsmParser &parser,
1224  SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1226  if (failed(parser.parseCommaSeparatedList([&]() {
1227  if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1228  parser.parseArrow() ||
1229  parser.parseAttribute(symsVec.emplace_back()) ||
1230  parser.parseColonType(copyprivateTypes.emplace_back()))
1231  return failure();
1232  return success();
1233  })))
1234  return failure();
1235  SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1236  copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1237  return success();
1238 }
1239 
1240 /// Print Copyprivate clause
1242  OperandRange copyprivateVars,
1243  TypeRange copyprivateTypes,
1244  std::optional<ArrayAttr> copyprivateSyms) {
1245  if (!copyprivateSyms.has_value())
1246  return;
1247  llvm::interleaveComma(
1248  llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1249  [&](const auto &args) {
1250  p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1251  << std::get<2>(args);
1252  });
1253 }
1254 
1255 /// Verifies CopyPrivate Clause
1256 static LogicalResult
1258  std::optional<ArrayAttr> copyprivateSyms) {
1259  size_t copyprivateSymsSize =
1260  copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1261  if (copyprivateSymsSize != copyprivateVars.size())
1262  return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1263  << copyprivateVars.size()
1264  << ") and functions (= " << copyprivateSymsSize
1265  << "), both must be equal";
1266  if (!copyprivateSyms.has_value())
1267  return success();
1268 
1269  for (auto copyprivateVarAndSym :
1270  llvm::zip(copyprivateVars, *copyprivateSyms)) {
1271  auto symbolRef =
1272  llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1273  std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1274  funcOp;
1275  if (mlir::func::FuncOp mlirFuncOp =
1276  SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1277  symbolRef))
1278  funcOp = mlirFuncOp;
1279  else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1280  SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1281  op, symbolRef))
1282  funcOp = llvmFuncOp;
1283 
1284  auto getNumArguments = [&] {
1285  return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1286  };
1287 
1288  auto getArgumentType = [&](unsigned i) {
1289  return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1290  *funcOp);
1291  };
1292 
1293  if (!funcOp)
1294  return op->emitOpError() << "expected symbol reference " << symbolRef
1295  << " to point to a copy function";
1296 
1297  if (getNumArguments() != 2)
1298  return op->emitOpError()
1299  << "expected copy function " << symbolRef << " to have 2 operands";
1300 
1301  Type argTy = getArgumentType(0);
1302  if (argTy != getArgumentType(1))
1303  return op->emitOpError() << "expected copy function " << symbolRef
1304  << " arguments to have the same type";
1305 
1306  Type varType = std::get<0>(copyprivateVarAndSym).getType();
1307  if (argTy != varType)
1308  return op->emitOpError()
1309  << "expected copy function arguments' type (" << argTy
1310  << ") to be the same as copyprivate variable's type (" << varType
1311  << ")";
1312  }
1313 
1314  return success();
1315 }
1316 
1317 //===----------------------------------------------------------------------===//
1318 // Parser, printer and verifier for DependVarList
1319 //===----------------------------------------------------------------------===//
1320 
1321 /// depend-entry-list ::= depend-entry
1322 /// | depend-entry-list `,` depend-entry
1323 /// depend-entry ::= depend-kind `->` ssa-id `:` type
1324 static ParseResult
1327  SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
1329  if (failed(parser.parseCommaSeparatedList([&]() {
1330  StringRef keyword;
1331  if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1332  parser.parseOperand(dependVars.emplace_back()) ||
1333  parser.parseColonType(dependTypes.emplace_back()))
1334  return failure();
1335  if (std::optional<ClauseTaskDepend> keywordDepend =
1336  (symbolizeClauseTaskDepend(keyword)))
1337  kindsVec.emplace_back(
1338  ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1339  else
1340  return failure();
1341  return success();
1342  })))
1343  return failure();
1344  SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1345  dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1346  return success();
1347 }
1348 
1349 /// Print Depend clause
1351  OperandRange dependVars, TypeRange dependTypes,
1352  std::optional<ArrayAttr> dependKinds) {
1353 
1354  for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1355  if (i != 0)
1356  p << ", ";
1357  p << stringifyClauseTaskDepend(
1358  llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1359  .getValue())
1360  << " -> " << dependVars[i] << " : " << dependTypes[i];
1361  }
1362 }
1363 
1364 /// Verifies Depend clause
1365 static LogicalResult verifyDependVarList(Operation *op,
1366  std::optional<ArrayAttr> dependKinds,
1367  OperandRange dependVars) {
1368  if (!dependVars.empty()) {
1369  if (!dependKinds || dependKinds->size() != dependVars.size())
1370  return op->emitOpError() << "expected as many depend values"
1371  " as depend variables";
1372  } else {
1373  if (dependKinds && !dependKinds->empty())
1374  return op->emitOpError() << "unexpected depend values";
1375  return success();
1376  }
1377 
1378  return success();
1379 }
1380 
1381 //===----------------------------------------------------------------------===//
1382 // Parser, printer and verifier for Synchronization Hint (2.17.12)
1383 //===----------------------------------------------------------------------===//
1384 
1385 /// Parses a Synchronization Hint clause. The value of hint is an integer
1386 /// which is a combination of different hints from `omp_sync_hint_t`.
1387 ///
1388 /// hint-clause = `hint` `(` hint-value `)`
1389 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1390  IntegerAttr &hintAttr) {
1391  StringRef hintKeyword;
1392  int64_t hint = 0;
1393  if (succeeded(parser.parseOptionalKeyword("none"))) {
1394  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1395  return success();
1396  }
1397  auto parseKeyword = [&]() -> ParseResult {
1398  if (failed(parser.parseKeyword(&hintKeyword)))
1399  return failure();
1400  if (hintKeyword == "uncontended")
1401  hint |= 1;
1402  else if (hintKeyword == "contended")
1403  hint |= 2;
1404  else if (hintKeyword == "nonspeculative")
1405  hint |= 4;
1406  else if (hintKeyword == "speculative")
1407  hint |= 8;
1408  else
1409  return parser.emitError(parser.getCurrentLocation())
1410  << hintKeyword << " is not a valid hint";
1411  return success();
1412  };
1413  if (parser.parseCommaSeparatedList(parseKeyword))
1414  return failure();
1415  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1416  return success();
1417 }
1418 
1419 /// Prints a Synchronization Hint clause
1421  IntegerAttr hintAttr) {
1422  int64_t hint = hintAttr.getInt();
1423 
1424  if (hint == 0) {
1425  p << "none";
1426  return;
1427  }
1428 
1429  // Helper function to get n-th bit from the right end of `value`
1430  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1431 
1432  bool uncontended = bitn(hint, 0);
1433  bool contended = bitn(hint, 1);
1434  bool nonspeculative = bitn(hint, 2);
1435  bool speculative = bitn(hint, 3);
1436 
1437  SmallVector<StringRef> hints;
1438  if (uncontended)
1439  hints.push_back("uncontended");
1440  if (contended)
1441  hints.push_back("contended");
1442  if (nonspeculative)
1443  hints.push_back("nonspeculative");
1444  if (speculative)
1445  hints.push_back("speculative");
1446 
1447  llvm::interleaveComma(hints, p);
1448 }
1449 
1450 /// Verifies a synchronization hint clause
1451 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1452 
1453  // Helper function to get n-th bit from the right end of `value`
1454  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1455 
1456  bool uncontended = bitn(hint, 0);
1457  bool contended = bitn(hint, 1);
1458  bool nonspeculative = bitn(hint, 2);
1459  bool speculative = bitn(hint, 3);
1460 
1461  if (uncontended && contended)
1462  return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1463  "omp_sync_hint_contended cannot be combined";
1464  if (nonspeculative && speculative)
1465  return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1466  "omp_sync_hint_speculative cannot be combined.";
1467  return success();
1468 }
1469 
1470 //===----------------------------------------------------------------------===//
1471 // Parser, printer and verifier for Target
1472 //===----------------------------------------------------------------------===//
1473 
1474 // Helper function to get bitwise AND of `value` and 'flag'
1475 uint64_t mapTypeToBitFlag(uint64_t value,
1476  llvm::omp::OpenMPOffloadMappingFlags flag) {
1477  return value & llvm::to_underlying(flag);
1478 }
1479 
1480 /// Parses a map_entries map type from a string format back into its numeric
1481 /// value.
1482 ///
1483 /// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
1484 /// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
1485 static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
1486  llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1487  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1488 
1489  // This simply verifies the correct keyword is read in, the
1490  // keyword itself is stored inside of the operation
1491  auto parseTypeAndMod = [&]() -> ParseResult {
1492  StringRef mapTypeMod;
1493  if (parser.parseKeyword(&mapTypeMod))
1494  return failure();
1495 
1496  if (mapTypeMod == "always")
1497  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1498 
1499  if (mapTypeMod == "implicit")
1500  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1501 
1502  if (mapTypeMod == "ompx_hold")
1503  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
1504 
1505  if (mapTypeMod == "close")
1506  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1507 
1508  if (mapTypeMod == "present")
1509  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1510 
1511  if (mapTypeMod == "to")
1512  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1513 
1514  if (mapTypeMod == "from")
1515  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1516 
1517  if (mapTypeMod == "tofrom")
1518  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1519  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1520 
1521  if (mapTypeMod == "delete")
1522  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1523 
1524  if (mapTypeMod == "return_param")
1525  mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
1526 
1527  return success();
1528  };
1529 
1530  if (parser.parseCommaSeparatedList(parseTypeAndMod))
1531  return failure();
1532 
1533  mapType = parser.getBuilder().getIntegerAttr(
1534  parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
1535  llvm::to_underlying(mapTypeBits));
1536 
1537  return success();
1538 }
1539 
1540 /// Prints a map_entries map type from its numeric value out into its string
1541 /// format.
1543  IntegerAttr mapType) {
1544  uint64_t mapTypeBits = mapType.getUInt();
1545 
1546  bool emitAllocRelease = true;
1548 
1549  // handling of always, close, present placed at the beginning of the string
1550  // to aid readability
1551  if (mapTypeToBitFlag(mapTypeBits,
1552  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1553  mapTypeStrs.push_back("always");
1554  if (mapTypeToBitFlag(mapTypeBits,
1555  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1556  mapTypeStrs.push_back("implicit");
1557  if (mapTypeToBitFlag(mapTypeBits,
1558  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD))
1559  mapTypeStrs.push_back("ompx_hold");
1560  if (mapTypeToBitFlag(mapTypeBits,
1561  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1562  mapTypeStrs.push_back("close");
1563  if (mapTypeToBitFlag(mapTypeBits,
1564  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1565  mapTypeStrs.push_back("present");
1566 
1567  // special handling of to/from/tofrom/delete and release/alloc, release +
1568  // alloc are the abscense of one of the other flags, whereas tofrom requires
1569  // both the to and from flag to be set.
1570  bool to = mapTypeToBitFlag(mapTypeBits,
1571  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1572  bool from = mapTypeToBitFlag(
1573  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1574  if (to && from) {
1575  emitAllocRelease = false;
1576  mapTypeStrs.push_back("tofrom");
1577  } else if (from) {
1578  emitAllocRelease = false;
1579  mapTypeStrs.push_back("from");
1580  } else if (to) {
1581  emitAllocRelease = false;
1582  mapTypeStrs.push_back("to");
1583  }
1584  if (mapTypeToBitFlag(mapTypeBits,
1585  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1586  emitAllocRelease = false;
1587  mapTypeStrs.push_back("delete");
1588  }
1589  if (mapTypeToBitFlag(
1590  mapTypeBits,
1591  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
1592  emitAllocRelease = false;
1593  mapTypeStrs.push_back("return_param");
1594  }
1595  if (emitAllocRelease)
1596  mapTypeStrs.push_back("exit_release_or_enter_alloc");
1597 
1598  for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1599  p << mapTypeStrs[i];
1600  if (i + 1 < mapTypeStrs.size()) {
1601  p << ", ";
1602  }
1603  }
1604 }
1605 
1606 static ParseResult parseMembersIndex(OpAsmParser &parser,
1607  ArrayAttr &membersIdx) {
1608  SmallVector<Attribute> values, memberIdxs;
1609 
1610  auto parseIndices = [&]() -> ParseResult {
1611  int64_t value;
1612  if (parser.parseInteger(value))
1613  return failure();
1614  values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1615  APInt(64, value, /*isSigned=*/false)));
1616  return success();
1617  };
1618 
1619  do {
1620  if (failed(parser.parseLSquare()))
1621  return failure();
1622 
1623  if (parser.parseCommaSeparatedList(parseIndices))
1624  return failure();
1625 
1626  if (failed(parser.parseRSquare()))
1627  return failure();
1628 
1629  memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1630  values.clear();
1631  } while (succeeded(parser.parseOptionalComma()));
1632 
1633  if (!memberIdxs.empty())
1634  membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
1635 
1636  return success();
1637 }
1638 
1639 static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1640  ArrayAttr membersIdx) {
1641  if (!membersIdx)
1642  return;
1643 
1644  llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
1645  p << "[";
1646  auto memberIdx = cast<ArrayAttr>(v);
1647  llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
1648  p << cast<IntegerAttr>(v2).getInt();
1649  });
1650  p << "]";
1651  });
1652 }
1653 
1655  VariableCaptureKindAttr mapCaptureType) {
1656  std::string typeCapStr;
1657  llvm::raw_string_ostream typeCap(typeCapStr);
1658  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1659  typeCap << "ByRef";
1660  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1661  typeCap << "ByCopy";
1662  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1663  typeCap << "VLAType";
1664  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1665  typeCap << "This";
1666  p << typeCapStr;
1667 }
1668 
1669 static ParseResult parseCaptureType(OpAsmParser &parser,
1670  VariableCaptureKindAttr &mapCaptureType) {
1671  StringRef mapCaptureKey;
1672  if (parser.parseKeyword(&mapCaptureKey))
1673  return failure();
1674 
1675  if (mapCaptureKey == "This")
1676  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1677  parser.getContext(), mlir::omp::VariableCaptureKind::This);
1678  if (mapCaptureKey == "ByRef")
1679  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1680  parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1681  if (mapCaptureKey == "ByCopy")
1682  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1683  parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1684  if (mapCaptureKey == "VLAType")
1685  mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1686  parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1687 
1688  return success();
1689 }
1690 
1691 static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
1694 
1695  for (auto mapOp : mapVars) {
1696  if (!mapOp.getDefiningOp())
1697  return emitError(op->getLoc(), "missing map operation");
1698 
1699  if (auto mapInfoOp =
1700  mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
1701  uint64_t mapTypeBits = mapInfoOp.getMapType();
1702 
1703  bool to = mapTypeToBitFlag(
1704  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1705  bool from = mapTypeToBitFlag(
1706  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1707  bool del = mapTypeToBitFlag(
1708  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1709 
1710  bool always = mapTypeToBitFlag(
1711  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1712  bool close = mapTypeToBitFlag(
1713  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1714  bool implicit = mapTypeToBitFlag(
1715  mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1716 
1717  if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1718  return emitError(op->getLoc(),
1719  "to, from, tofrom and alloc map types are permitted");
1720 
1721  if (isa<TargetEnterDataOp>(op) && (from || del))
1722  return emitError(op->getLoc(), "to and alloc map types are permitted");
1723 
1724  if (isa<TargetExitDataOp>(op) && to)
1725  return emitError(op->getLoc(),
1726  "from, release and delete map types are permitted");
1727 
1728  if (isa<TargetUpdateOp>(op)) {
1729  if (del) {
1730  return emitError(op->getLoc(),
1731  "at least one of to or from map types must be "
1732  "specified, other map types are not permitted");
1733  }
1734 
1735  if (!to && !from) {
1736  return emitError(op->getLoc(),
1737  "at least one of to or from map types must be "
1738  "specified, other map types are not permitted");
1739  }
1740 
1741  auto updateVar = mapInfoOp.getVarPtr();
1742 
1743  if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1744  (from && updateToVars.contains(updateVar))) {
1745  return emitError(
1746  op->getLoc(),
1747  "either to or from map types can be specified, not both");
1748  }
1749 
1750  if (always || close || implicit) {
1751  return emitError(
1752  op->getLoc(),
1753  "present, mapper and iterator map type modifiers are permitted");
1754  }
1755 
1756  to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1757  }
1758  } else if (!isa<DeclareMapperInfoOp>(op)) {
1759  return emitError(op->getLoc(),
1760  "map argument is not a map entry operation");
1761  }
1762  }
1763 
1764  return success();
1765 }
1766 
1767 static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
1768  std::optional<DenseI64ArrayAttr> privateMapIndices =
1769  targetOp.getPrivateMapsAttr();
1770 
1771  // None of the private operands are mapped.
1772  if (!privateMapIndices.has_value() || !privateMapIndices.value())
1773  return success();
1774 
1775  OperandRange privateVars = targetOp.getPrivateVars();
1776 
1777  if (privateMapIndices.value().size() !=
1778  static_cast<int64_t>(privateVars.size()))
1779  return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
1780  "`private_maps` attribute mismatch");
1781 
1782  return success();
1783 }
1784 
1785 //===----------------------------------------------------------------------===//
1786 // MapInfoOp
1787 //===----------------------------------------------------------------------===//
1788 
1789 static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
1790  StringRef clauseName,
1791  OperandRange vars) {
1792  for (Value var : vars)
1793  if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
1794  return op->emitOpError()
1795  << "'" << clauseName
1796  << "' arguments must be defined by 'omp.map.info' ops";
1797  return success();
1798 }
1799 
1800 LogicalResult MapInfoOp::verify() {
1801  if (getMapperId() &&
1802  !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
1803  *this, getMapperIdAttr())) {
1804  return emitError("invalid mapper id");
1805  }
1806 
1807  if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
1808  return failure();
1809 
1810  return success();
1811 }
1812 
1813 //===----------------------------------------------------------------------===//
1814 // TargetDataOp
1815 //===----------------------------------------------------------------------===//
1816 
1817 void TargetDataOp::build(OpBuilder &builder, OperationState &state,
1818  const TargetDataOperands &clauses) {
1819  TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
1820  clauses.mapVars, clauses.useDeviceAddrVars,
1821  clauses.useDevicePtrVars);
1822 }
1823 
1824 LogicalResult TargetDataOp::verify() {
1825  if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
1826  getUseDeviceAddrVars().empty()) {
1827  return ::emitError(this->getLoc(),
1828  "At least one of map, use_device_ptr_vars, or "
1829  "use_device_addr_vars operand must be present");
1830  }
1831 
1832  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
1833  getUseDevicePtrVars())))
1834  return failure();
1835 
1836  if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
1837  getUseDeviceAddrVars())))
1838  return failure();
1839 
1840  return verifyMapClause(*this, getMapVars());
1841 }
1842 
1843 //===----------------------------------------------------------------------===//
1844 // TargetEnterDataOp
1845 //===----------------------------------------------------------------------===//
1846 
1847 void TargetEnterDataOp::build(
1848  OpBuilder &builder, OperationState &state,
1849  const TargetEnterExitUpdateDataOperands &clauses) {
1850  MLIRContext *ctx = builder.getContext();
1851  TargetEnterDataOp::build(builder, state,
1852  makeArrayAttr(ctx, clauses.dependKinds),
1853  clauses.dependVars, clauses.device, clauses.ifExpr,
1854  clauses.mapVars, clauses.nowait);
1855 }
1856 
1857 LogicalResult TargetEnterDataOp::verify() {
1858  LogicalResult verifyDependVars =
1859  verifyDependVarList(*this, getDependKinds(), getDependVars());
1860  return failed(verifyDependVars) ? verifyDependVars
1861  : verifyMapClause(*this, getMapVars());
1862 }
1863 
1864 //===----------------------------------------------------------------------===//
1865 // TargetExitDataOp
1866 //===----------------------------------------------------------------------===//
1867 
1868 void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
1869  const TargetEnterExitUpdateDataOperands &clauses) {
1870  MLIRContext *ctx = builder.getContext();
1871  TargetExitDataOp::build(builder, state,
1872  makeArrayAttr(ctx, clauses.dependKinds),
1873  clauses.dependVars, clauses.device, clauses.ifExpr,
1874  clauses.mapVars, clauses.nowait);
1875 }
1876 
1877 LogicalResult TargetExitDataOp::verify() {
1878  LogicalResult verifyDependVars =
1879  verifyDependVarList(*this, getDependKinds(), getDependVars());
1880  return failed(verifyDependVars) ? verifyDependVars
1881  : verifyMapClause(*this, getMapVars());
1882 }
1883 
1884 //===----------------------------------------------------------------------===//
1885 // TargetUpdateOp
1886 //===----------------------------------------------------------------------===//
1887 
1888 void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
1889  const TargetEnterExitUpdateDataOperands &clauses) {
1890  MLIRContext *ctx = builder.getContext();
1891  TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
1892  clauses.dependVars, clauses.device, clauses.ifExpr,
1893  clauses.mapVars, clauses.nowait);
1894 }
1895 
1896 LogicalResult TargetUpdateOp::verify() {
1897  LogicalResult verifyDependVars =
1898  verifyDependVarList(*this, getDependKinds(), getDependVars());
1899  return failed(verifyDependVars) ? verifyDependVars
1900  : verifyMapClause(*this, getMapVars());
1901 }
1902 
1903 //===----------------------------------------------------------------------===//
1904 // TargetOp
1905 //===----------------------------------------------------------------------===//
1906 
1907 void TargetOp::build(OpBuilder &builder, OperationState &state,
1908  const TargetOperands &clauses) {
1909  MLIRContext *ctx = builder.getContext();
1910  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
1911  // inReductionByref, inReductionSyms.
1912  TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
1913  clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
1914  clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
1915  clauses.hostEvalVars, clauses.ifExpr,
1916  /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
1917  /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
1918  clauses.mapVars, clauses.nowait, clauses.privateVars,
1919  makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
1920  /*private_maps=*/nullptr);
1921 }
1922 
1923 LogicalResult TargetOp::verify() {
1924  if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
1925  return failure();
1926 
1927  if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
1928  getHasDeviceAddrVars())))
1929  return failure();
1930 
1931  if (failed(verifyMapClause(*this, getMapVars())))
1932  return failure();
1933 
1934  return verifyPrivateVarsMapping(*this);
1935 }
1936 
1937 LogicalResult TargetOp::verifyRegions() {
1938  auto teamsOps = getOps<TeamsOp>();
1939  if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
1940  return emitError("target containing multiple 'omp.teams' nested ops");
1941 
1942  // Check that host_eval values are only used in legal ways.
1943  Operation *capturedOp = getInnermostCapturedOmpOp();
1944  TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
1945  for (Value hostEvalArg :
1946  cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
1947  for (Operation *user : hostEvalArg.getUsers()) {
1948  if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
1949  if (llvm::is_contained({teamsOp.getNumTeamsLower(),
1950  teamsOp.getNumTeamsUpper(),
1951  teamsOp.getThreadLimit()},
1952  hostEvalArg))
1953  continue;
1954 
1955  return emitOpError() << "host_eval argument only legal as 'num_teams' "
1956  "and 'thread_limit' in 'omp.teams'";
1957  }
1958  if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1959  if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
1960  parallelOp->isAncestor(capturedOp) &&
1961  hostEvalArg == parallelOp.getNumThreads())
1962  continue;
1963 
1964  return emitOpError()
1965  << "host_eval argument only legal as 'num_threads' in "
1966  "'omp.parallel' when representing target SPMD";
1967  }
1968  if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1969  if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
1970  loopNestOp.getOperation() == capturedOp &&
1971  (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
1972  llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
1973  llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
1974  continue;
1975 
1976  return emitOpError() << "host_eval argument only legal as loop bounds "
1977  "and steps in 'omp.loop_nest' when trip count "
1978  "must be evaluated in the host";
1979  }
1980 
1981  return emitOpError() << "host_eval argument illegal use in '"
1982  << user->getName() << "' operation";
1983  }
1984  }
1985  return success();
1986 }
1987 
1988 static Operation *
1989 findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
1990  llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
1991  assert(rootOp && "expected valid operation");
1992 
1993  Dialect *ompDialect = rootOp->getDialect();
1994  Operation *capturedOp = nullptr;
1995  DominanceInfo domInfo;
1996 
1997  // Process in pre-order to check operations from outermost to innermost,
1998  // ensuring we only enter the region of an operation if it meets the criteria
1999  // for being captured. We stop the exploration of nested operations as soon as
2000  // we process a region holding no operations to be captured.
2001  rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
2002  if (op == rootOp)
2003  return WalkResult::advance();
2004 
2005  // Ignore operations of other dialects or omp operations with no regions,
2006  // because these will only be checked if they are siblings of an omp
2007  // operation that can potentially be captured.
2008  bool isOmpDialect = op->getDialect() == ompDialect;
2009  bool hasRegions = op->getNumRegions() > 0;
2010  if (!isOmpDialect || !hasRegions)
2011  return WalkResult::skip();
2012 
2013  // This operation cannot be captured if it can be executed more than once
2014  // (i.e. its block's successors can reach it) or if it's not guaranteed to
2015  // be executed before all exits of the region (i.e. it doesn't dominate all
2016  // blocks with no successors reachable from the entry block).
2017  if (checkSingleMandatoryExec) {
2018  Region *parentRegion = op->getParentRegion();
2019  Block *parentBlock = op->getBlock();
2020 
2021  for (Block *successor : parentBlock->getSuccessors())
2022  if (successor->isReachable(parentBlock))
2023  return WalkResult::interrupt();
2024 
2025  for (Block &block : *parentRegion)
2026  if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2027  !domInfo.dominates(parentBlock, &block))
2028  return WalkResult::interrupt();
2029  }
2030 
2031  // Don't capture this op if it has a not-allowed sibling, and stop recursing
2032  // into nested operations.
2033  for (Operation &sibling : op->getParentRegion()->getOps())
2034  if (&sibling != op && !siblingAllowedFn(&sibling))
2035  return WalkResult::interrupt();
2036 
2037  // Don't continue capturing nested operations if we reach an omp.loop_nest.
2038  // Otherwise, process the contents of this operation.
2039  capturedOp = op;
2040  return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2041  : WalkResult::advance();
2042  });
2043 
2044  return capturedOp;
2045 }
2046 
2047 Operation *TargetOp::getInnermostCapturedOmpOp() {
2048  auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
2049 
2050  // Only allow OpenMP terminators and non-OpenMP ops that have known memory
2051  // effects, but don't include a memory write effect.
2052  return findCapturedOmpOp(
2053  *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2054  if (!sibling)
2055  return false;
2056 
2057  if (ompDialect == sibling->getDialect())
2058  return sibling->hasTrait<OpTrait::IsTerminator>();
2059 
2060  if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2062  effects;
2063  memOp.getEffects(effects);
2064  return !llvm::any_of(
2065  effects, [&](MemoryEffects::EffectInstance &effect) {
2066  return isa<MemoryEffects::Write>(effect.getEffect()) &&
2067  isa<SideEffects::AutomaticAllocationScopeResource>(
2068  effect.getResource());
2069  });
2070  }
2071  return true;
2072  });
2073 }
2074 
2075 TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2076  // A non-null captured op is only valid if it resides inside of a TargetOp
2077  // and is the result of calling getInnermostCapturedOmpOp() on it.
2078  TargetOp targetOp =
2079  capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
2080  assert((!capturedOp ||
2081  (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2082  "unexpected captured op");
2083 
2084  // If it's not capturing a loop, it's a default target region.
2085  if (!isa_and_present<LoopNestOp>(capturedOp))
2086  return TargetRegionFlags::generic;
2087 
2088  // Get the innermost non-simd loop wrapper.
2089  SmallVector<LoopWrapperInterface> loopWrappers;
2090  cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2091  assert(!loopWrappers.empty());
2092 
2093  LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2094  if (isa<SimdOp>(innermostWrapper))
2095  innermostWrapper = std::next(innermostWrapper);
2096 
2097  auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2098  if (numWrappers != 1 && numWrappers != 2)
2099  return TargetRegionFlags::generic;
2100 
2101  // Detect target-teams-distribute-parallel-wsloop[-simd].
2102  if (numWrappers == 2) {
2103  if (!isa<WsloopOp>(innermostWrapper))
2104  return TargetRegionFlags::generic;
2105 
2106  innermostWrapper = std::next(innermostWrapper);
2107  if (!isa<DistributeOp>(innermostWrapper))
2108  return TargetRegionFlags::generic;
2109 
2110  Operation *parallelOp = (*innermostWrapper)->getParentOp();
2111  if (!isa_and_present<ParallelOp>(parallelOp))
2112  return TargetRegionFlags::generic;
2113 
2114  Operation *teamsOp = parallelOp->getParentOp();
2115  if (!isa_and_present<TeamsOp>(teamsOp))
2116  return TargetRegionFlags::generic;
2117 
2118  if (teamsOp->getParentOp() == targetOp.getOperation())
2119  return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2120  }
2121  // Detect target-teams-distribute[-simd] and target-teams-loop.
2122  else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2123  Operation *teamsOp = (*innermostWrapper)->getParentOp();
2124  if (!isa_and_present<TeamsOp>(teamsOp))
2125  return TargetRegionFlags::generic;
2126 
2127  if (teamsOp->getParentOp() != targetOp.getOperation())
2128  return TargetRegionFlags::generic;
2129 
2130  if (isa<LoopOp>(innermostWrapper))
2131  return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2132 
2133  // Find single immediately nested captured omp.parallel and add spmd flag
2134  // (generic-spmd case).
2135  //
2136  // TODO: This shouldn't have to be done here, as it is too easy to break.
2137  // The openmp-opt pass should be updated to be able to promote kernels like
2138  // this from "Generic" to "Generic-SPMD". However, the use of the
2139  // `kmpc_distribute_static_loop` family of functions produced by the
2140  // OMPIRBuilder for these kernels prevents that from working.
2141  Dialect *ompDialect = targetOp->getDialect();
2142  Operation *nestedCapture = findCapturedOmpOp(
2143  capturedOp, /*checkSingleMandatoryExec=*/false,
2144  [&](Operation *sibling) {
2145  return sibling && (ompDialect != sibling->getDialect() ||
2146  sibling->hasTrait<OpTrait::IsTerminator>());
2147  });
2148 
2149  TargetRegionFlags result =
2150  TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2151 
2152  if (!nestedCapture)
2153  return result;
2154 
2155  while (nestedCapture->getParentOp() != capturedOp)
2156  nestedCapture = nestedCapture->getParentOp();
2157 
2158  return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2159  : result;
2160  }
2161  // Detect target-parallel-wsloop[-simd].
2162  else if (isa<WsloopOp>(innermostWrapper)) {
2163  Operation *parallelOp = (*innermostWrapper)->getParentOp();
2164  if (!isa_and_present<ParallelOp>(parallelOp))
2165  return TargetRegionFlags::generic;
2166 
2167  if (parallelOp->getParentOp() == targetOp.getOperation())
2168  return TargetRegionFlags::spmd;
2169  }
2170 
2171  return TargetRegionFlags::generic;
2172 }
2173 
2174 //===----------------------------------------------------------------------===//
2175 // ParallelOp
2176 //===----------------------------------------------------------------------===//
2177 
2178 void ParallelOp::build(OpBuilder &builder, OperationState &state,
2179  ArrayRef<NamedAttribute> attributes) {
2180  ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
2181  /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2182  /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
2183  /*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr,
2184  /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2185  /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2186  state.addAttributes(attributes);
2187 }
2188 
2189 void ParallelOp::build(OpBuilder &builder, OperationState &state,
2190  const ParallelOperands &clauses) {
2191  MLIRContext *ctx = builder.getContext();
2192  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2193  clauses.ifExpr, clauses.numThreads, clauses.privateVars,
2194  makeArrayAttr(ctx, clauses.privateSyms),
2195  clauses.procBindKind, clauses.reductionMod,
2196  clauses.reductionVars,
2197  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2198  makeArrayAttr(ctx, clauses.reductionSyms));
2199 }
2200 
2201 template <typename OpType>
2202 static LogicalResult verifyPrivateVarList(OpType &op) {
2203  auto privateVars = op.getPrivateVars();
2204  auto privateSyms = op.getPrivateSymsAttr();
2205 
2206  if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
2207  return success();
2208 
2209  auto numPrivateVars = privateVars.size();
2210  auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
2211 
2212  if (numPrivateVars != numPrivateSyms)
2213  return op.emitError() << "inconsistent number of private variables and "
2214  "privatizer op symbols, private vars: "
2215  << numPrivateVars
2216  << " vs. privatizer op symbols: " << numPrivateSyms;
2217 
2218  for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2219  Type varType = std::get<0>(privateVarInfo).getType();
2220  SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2221  PrivateClauseOp privatizerOp =
2222  SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
2223 
2224  if (privatizerOp == nullptr)
2225  return op.emitError() << "failed to lookup privatizer op with symbol: '"
2226  << privateSym << "'";
2227 
2228  Type privatizerType = privatizerOp.getArgType();
2229 
2230  if (privatizerType && (varType != privatizerType))
2231  return op.emitError()
2232  << "type mismatch between a "
2233  << (privatizerOp.getDataSharingType() ==
2234  DataSharingClauseType::Private
2235  ? "private"
2236  : "firstprivate")
2237  << " variable and its privatizer op, var type: " << varType
2238  << " vs. privatizer op type: " << privatizerType;
2239  }
2240 
2241  return success();
2242 }
2243 
2244 LogicalResult ParallelOp::verify() {
2245  if (getAllocateVars().size() != getAllocatorVars().size())
2246  return emitError(
2247  "expected equal sizes for allocate and allocator variables");
2248 
2249  if (failed(verifyPrivateVarList(*this)))
2250  return failure();
2251 
2252  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2253  getReductionByref());
2254 }
2255 
2256 LogicalResult ParallelOp::verifyRegions() {
2257  auto distChildOps = getOps<DistributeOp>();
2258  int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2259  if (numDistChildOps > 1)
2260  return emitError()
2261  << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2262 
2263  if (numDistChildOps == 1) {
2264  if (!isComposite())
2265  return emitError()
2266  << "'omp.composite' attribute missing from composite operation";
2267 
2268  auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2269  Operation &distributeOp = **distChildOps.begin();
2270  for (Operation &childOp : getOps()) {
2271  if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2272  continue;
2273 
2274  if (!childOp.hasTrait<OpTrait::IsTerminator>())
2275  return emitError() << "unexpected OpenMP operation inside of composite "
2276  "'omp.parallel': "
2277  << childOp.getName();
2278  }
2279  } else if (isComposite()) {
2280  return emitError()
2281  << "'omp.composite' attribute present in non-composite operation";
2282  }
2283  return success();
2284 }
2285 
2286 //===----------------------------------------------------------------------===//
2287 // TeamsOp
2288 //===----------------------------------------------------------------------===//
2289 
2291  while ((op = op->getParentOp()))
2292  if (isa<OpenMPDialect>(op->getDialect()))
2293  return false;
2294  return true;
2295 }
2296 
2297 void TeamsOp::build(OpBuilder &builder, OperationState &state,
2298  const TeamsOperands &clauses) {
2299  MLIRContext *ctx = builder.getContext();
2300  // TODO Store clauses in op: privateVars, privateSyms.
2301  TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2302  clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2303  /*private_vars=*/{}, /*private_syms=*/nullptr,
2304  clauses.reductionMod, clauses.reductionVars,
2305  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2306  makeArrayAttr(ctx, clauses.reductionSyms),
2307  clauses.threadLimit);
2308 }
2309 
2310 LogicalResult TeamsOp::verify() {
2311  // Check parent region
2312  // TODO If nested inside of a target region, also check that it does not
2313  // contain any statements, declarations or directives other than this
2314  // omp.teams construct. The issue is how to support the initialization of
2315  // this operation's own arguments (allow SSA values across omp.target?).
2316  Operation *op = getOperation();
2317  if (!isa<TargetOp>(op->getParentOp()) &&
2319  return emitError("expected to be nested inside of omp.target or not nested "
2320  "in any OpenMP dialect operations");
2321 
2322  // Check for num_teams clause restrictions
2323  if (auto numTeamsLowerBound = getNumTeamsLower()) {
2324  auto numTeamsUpperBound = getNumTeamsUpper();
2325  if (!numTeamsUpperBound)
2326  return emitError("expected num_teams upper bound to be defined if the "
2327  "lower bound is defined");
2328  if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2329  return emitError(
2330  "expected num_teams upper bound and lower bound to be the same type");
2331  }
2332 
2333  // Check for allocate clause restrictions
2334  if (getAllocateVars().size() != getAllocatorVars().size())
2335  return emitError(
2336  "expected equal sizes for allocate and allocator variables");
2337 
2338  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2339  getReductionByref());
2340 }
2341 
2342 //===----------------------------------------------------------------------===//
2343 // SectionOp
2344 //===----------------------------------------------------------------------===//
2345 
2346 OperandRange SectionOp::getPrivateVars() {
2347  return getParentOp().getPrivateVars();
2348 }
2349 
2350 OperandRange SectionOp::getReductionVars() {
2351  return getParentOp().getReductionVars();
2352 }
2353 
2354 //===----------------------------------------------------------------------===//
2355 // SectionsOp
2356 //===----------------------------------------------------------------------===//
2357 
2358 void SectionsOp::build(OpBuilder &builder, OperationState &state,
2359  const SectionsOperands &clauses) {
2360  MLIRContext *ctx = builder.getContext();
2361  // TODO Store clauses in op: privateVars, privateSyms.
2362  SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2363  clauses.nowait, /*private_vars=*/{},
2364  /*private_syms=*/nullptr, clauses.reductionMod,
2365  clauses.reductionVars,
2366  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2367  makeArrayAttr(ctx, clauses.reductionSyms));
2368 }
2369 
2370 LogicalResult SectionsOp::verify() {
2371  if (getAllocateVars().size() != getAllocatorVars().size())
2372  return emitError(
2373  "expected equal sizes for allocate and allocator variables");
2374 
2375  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2376  getReductionByref());
2377 }
2378 
2379 LogicalResult SectionsOp::verifyRegions() {
2380  for (auto &inst : *getRegion().begin()) {
2381  if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2382  return emitOpError()
2383  << "expected omp.section op or terminator op inside region";
2384  }
2385  }
2386 
2387  return success();
2388 }
2389 
2390 //===----------------------------------------------------------------------===//
2391 // SingleOp
2392 //===----------------------------------------------------------------------===//
2393 
2394 void SingleOp::build(OpBuilder &builder, OperationState &state,
2395  const SingleOperands &clauses) {
2396  MLIRContext *ctx = builder.getContext();
2397  // TODO Store clauses in op: privateVars, privateSyms.
2398  SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2399  clauses.copyprivateVars,
2400  makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2401  /*private_vars=*/{}, /*private_syms=*/nullptr);
2402 }
2403 
2404 LogicalResult SingleOp::verify() {
2405  // Check for allocate clause restrictions
2406  if (getAllocateVars().size() != getAllocatorVars().size())
2407  return emitError(
2408  "expected equal sizes for allocate and allocator variables");
2409 
2410  return verifyCopyprivateVarList(*this, getCopyprivateVars(),
2411  getCopyprivateSyms());
2412 }
2413 
2414 //===----------------------------------------------------------------------===//
2415 // WorkshareOp
2416 //===----------------------------------------------------------------------===//
2417 
2418 void WorkshareOp::build(OpBuilder &builder, OperationState &state,
2419  const WorkshareOperands &clauses) {
2420  WorkshareOp::build(builder, state, clauses.nowait);
2421 }
2422 
2423 //===----------------------------------------------------------------------===//
2424 // WorkshareLoopWrapperOp
2425 //===----------------------------------------------------------------------===//
2426 
2427 LogicalResult WorkshareLoopWrapperOp::verify() {
2428  if (!(*this)->getParentOfType<WorkshareOp>())
2429  return emitOpError() << "must be nested in an omp.workshare";
2430  return success();
2431 }
2432 
2433 LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2434  if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2435  getNestedWrapper())
2436  return emitOpError() << "expected to be a standalone loop wrapper";
2437 
2438  return success();
2439 }
2440 
2441 //===----------------------------------------------------------------------===//
2442 // LoopWrapperInterface
2443 //===----------------------------------------------------------------------===//
2444 
2445 LogicalResult LoopWrapperInterface::verifyImpl() {
2446  Operation *op = this->getOperation();
2447  if (!op->hasTrait<OpTrait::NoTerminator>() ||
2449  return emitOpError() << "loop wrapper must also have the `NoTerminator` "
2450  "and `SingleBlock` traits";
2451 
2452  if (op->getNumRegions() != 1)
2453  return emitOpError() << "loop wrapper does not contain exactly one region";
2454 
2455  Region &region = op->getRegion(0);
2456  if (range_size(region.getOps()) != 1)
2457  return emitOpError()
2458  << "loop wrapper does not contain exactly one nested op";
2459 
2460  Operation &firstOp = *region.op_begin();
2461  if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2462  return emitOpError() << "nested in loop wrapper is not another loop "
2463  "wrapper or `omp.loop_nest`";
2464 
2465  return success();
2466 }
2467 
2468 //===----------------------------------------------------------------------===//
2469 // LoopOp
2470 //===----------------------------------------------------------------------===//
2471 
2472 void LoopOp::build(OpBuilder &builder, OperationState &state,
2473  const LoopOperands &clauses) {
2474  MLIRContext *ctx = builder.getContext();
2475 
2476  LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2477  makeArrayAttr(ctx, clauses.privateSyms), clauses.order,
2478  clauses.orderMod, clauses.reductionMod, clauses.reductionVars,
2479  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2480  makeArrayAttr(ctx, clauses.reductionSyms));
2481 }
2482 
2483 LogicalResult LoopOp::verify() {
2484  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2485  getReductionByref());
2486 }
2487 
2488 LogicalResult LoopOp::verifyRegions() {
2489  if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2490  getNestedWrapper())
2491  return emitOpError() << "expected to be a standalone loop wrapper";
2492 
2493  return success();
2494 }
2495 
2496 //===----------------------------------------------------------------------===//
2497 // WsloopOp
2498 //===----------------------------------------------------------------------===//
2499 
2500 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2501  ArrayRef<NamedAttribute> attributes) {
2502  build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2503  /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
2504  /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
2505  /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
2506  /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
2507  /*reduction_byref=*/nullptr,
2508  /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
2509  /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
2510  /*schedule_simd=*/false);
2511  state.addAttributes(attributes);
2512 }
2513 
2514 void WsloopOp::build(OpBuilder &builder, OperationState &state,
2515  const WsloopOperands &clauses) {
2516  MLIRContext *ctx = builder.getContext();
2517  // TODO: Store clauses in op: allocateVars, allocatorVars, privateVars,
2518  // privateSyms.
2519  WsloopOp::build(builder, state,
2520  /*allocate_vars=*/{}, /*allocator_vars=*/{},
2521  clauses.linearVars, clauses.linearStepVars, clauses.nowait,
2522  clauses.order, clauses.orderMod, clauses.ordered,
2523  clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
2524  clauses.reductionMod, clauses.reductionVars,
2525  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2526  makeArrayAttr(ctx, clauses.reductionSyms),
2527  clauses.scheduleKind, clauses.scheduleChunk,
2528  clauses.scheduleMod, clauses.scheduleSimd);
2529 }
2530 
2531 LogicalResult WsloopOp::verify() {
2532  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2533  getReductionByref());
2534 }
2535 
2536 LogicalResult WsloopOp::verifyRegions() {
2537  bool isCompositeChildLeaf =
2538  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2539 
2540  if (LoopWrapperInterface nested = getNestedWrapper()) {
2541  if (!isComposite())
2542  return emitError()
2543  << "'omp.composite' attribute missing from composite wrapper";
2544 
2545  // Check for the allowed leaf constructs that may appear in a composite
2546  // construct directly after DO/FOR.
2547  if (!isa<SimdOp>(nested))
2548  return emitError() << "only supported nested wrapper is 'omp.simd'";
2549 
2550  } else if (isComposite() && !isCompositeChildLeaf) {
2551  return emitError()
2552  << "'omp.composite' attribute present in non-composite wrapper";
2553  } else if (!isComposite() && isCompositeChildLeaf) {
2554  return emitError()
2555  << "'omp.composite' attribute missing from composite wrapper";
2556  }
2557 
2558  return success();
2559 }
2560 
2561 //===----------------------------------------------------------------------===//
2562 // Simd construct [2.9.3.1]
2563 //===----------------------------------------------------------------------===//
2564 
2565 void SimdOp::build(OpBuilder &builder, OperationState &state,
2566  const SimdOperands &clauses) {
2567  MLIRContext *ctx = builder.getContext();
2568  // TODO Store clauses in op: linearVars, linearStepVars, privateVars,
2569  // privateSyms.
2570  SimdOp::build(builder, state, clauses.alignedVars,
2571  makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
2572  /*linear_vars=*/{}, /*linear_step_vars=*/{},
2573  clauses.nontemporalVars, clauses.order, clauses.orderMod,
2574  clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
2575  clauses.reductionMod, clauses.reductionVars,
2576  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2577  makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
2578  clauses.simdlen);
2579 }
2580 
2581 LogicalResult SimdOp::verify() {
2582  if (getSimdlen().has_value() && getSafelen().has_value() &&
2583  getSimdlen().value() > getSafelen().value())
2584  return emitOpError()
2585  << "simdlen clause and safelen clause are both present, but the "
2586  "simdlen value is not less than or equal to safelen value";
2587 
2588  if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
2589  return failure();
2590 
2591  if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
2592  return failure();
2593 
2594  bool isCompositeChildLeaf =
2595  llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2596 
2597  if (!isComposite() && isCompositeChildLeaf)
2598  return emitError()
2599  << "'omp.composite' attribute missing from composite wrapper";
2600 
2601  if (isComposite() && !isCompositeChildLeaf)
2602  return emitError()
2603  << "'omp.composite' attribute present in non-composite wrapper";
2604 
2605  return success();
2606 }
2607 
2608 LogicalResult SimdOp::verifyRegions() {
2609  if (getNestedWrapper())
2610  return emitOpError() << "must wrap an 'omp.loop_nest' directly";
2611 
2612  return success();
2613 }
2614 
2615 //===----------------------------------------------------------------------===//
2616 // Distribute construct [2.9.4.1]
2617 //===----------------------------------------------------------------------===//
2618 
2619 void DistributeOp::build(OpBuilder &builder, OperationState &state,
2620  const DistributeOperands &clauses) {
2621  DistributeOp::build(builder, state, clauses.allocateVars,
2622  clauses.allocatorVars, clauses.distScheduleStatic,
2623  clauses.distScheduleChunkSize, clauses.order,
2624  clauses.orderMod, clauses.privateVars,
2625  makeArrayAttr(builder.getContext(), clauses.privateSyms));
2626 }
2627 
2628 LogicalResult DistributeOp::verify() {
2629  if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2630  return emitOpError() << "chunk size set without "
2631  "dist_schedule_static being present";
2632 
2633  if (getAllocateVars().size() != getAllocatorVars().size())
2634  return emitError(
2635  "expected equal sizes for allocate and allocator variables");
2636 
2637  return success();
2638 }
2639 
2640 LogicalResult DistributeOp::verifyRegions() {
2641  if (LoopWrapperInterface nested = getNestedWrapper()) {
2642  if (!isComposite())
2643  return emitError()
2644  << "'omp.composite' attribute missing from composite wrapper";
2645  // Check for the allowed leaf constructs that may appear in a composite
2646  // construct directly after DISTRIBUTE.
2647  if (isa<WsloopOp>(nested)) {
2648  Operation *parentOp = (*this)->getParentOp();
2649  if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
2650  !cast<ComposableOpInterface>(parentOp).isComposite()) {
2651  return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
2652  "when a composite 'omp.parallel' is the direct "
2653  "parent";
2654  }
2655  } else if (!isa<SimdOp>(nested))
2656  return emitError() << "only supported nested wrappers are 'omp.simd' and "
2657  "'omp.wsloop'";
2658  } else if (isComposite()) {
2659  return emitError()
2660  << "'omp.composite' attribute present in non-composite wrapper";
2661  }
2662 
2663  return success();
2664 }
2665 
2666 //===----------------------------------------------------------------------===//
2667 // DeclareMapperOp / DeclareMapperInfoOp
2668 //===----------------------------------------------------------------------===//
2669 
2670 LogicalResult DeclareMapperInfoOp::verify() {
2671  return verifyMapClause(*this, getMapVars());
2672 }
2673 
2674 LogicalResult DeclareMapperOp::verifyRegions() {
2675  if (!llvm::isa_and_present<DeclareMapperInfoOp>(
2676  getRegion().getBlocks().front().getTerminator()))
2677  return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
2678 
2679  return success();
2680 }
2681 
2682 //===----------------------------------------------------------------------===//
2683 // DeclareReductionOp
2684 //===----------------------------------------------------------------------===//
2685 
2686 LogicalResult DeclareReductionOp::verifyRegions() {
2687  if (!getAllocRegion().empty()) {
2688  for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
2689  if (yieldOp.getResults().size() != 1 ||
2690  yieldOp.getResults().getTypes()[0] != getType())
2691  return emitOpError() << "expects alloc region to yield a value "
2692  "of the reduction type";
2693  }
2694  }
2695 
2696  if (getInitializerRegion().empty())
2697  return emitOpError() << "expects non-empty initializer region";
2698  Block &initializerEntryBlock = getInitializerRegion().front();
2699 
2700  if (initializerEntryBlock.getNumArguments() == 1) {
2701  if (!getAllocRegion().empty())
2702  return emitOpError() << "expects two arguments to the initializer region "
2703  "when an allocation region is used";
2704  } else if (initializerEntryBlock.getNumArguments() == 2) {
2705  if (getAllocRegion().empty())
2706  return emitOpError() << "expects one argument to the initializer region "
2707  "when no allocation region is used";
2708  } else {
2709  return emitOpError()
2710  << "expects one or two arguments to the initializer region";
2711  }
2712 
2713  for (mlir::Value arg : initializerEntryBlock.getArguments())
2714  if (arg.getType() != getType())
2715  return emitOpError() << "expects initializer region argument to match "
2716  "the reduction type";
2717 
2718  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
2719  if (yieldOp.getResults().size() != 1 ||
2720  yieldOp.getResults().getTypes()[0] != getType())
2721  return emitOpError() << "expects initializer region to yield a value "
2722  "of the reduction type";
2723  }
2724 
2725  if (getReductionRegion().empty())
2726  return emitOpError() << "expects non-empty reduction region";
2727  Block &reductionEntryBlock = getReductionRegion().front();
2728  if (reductionEntryBlock.getNumArguments() != 2 ||
2729  reductionEntryBlock.getArgumentTypes()[0] !=
2730  reductionEntryBlock.getArgumentTypes()[1] ||
2731  reductionEntryBlock.getArgumentTypes()[0] != getType())
2732  return emitOpError() << "expects reduction region with two arguments of "
2733  "the reduction type";
2734  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
2735  if (yieldOp.getResults().size() != 1 ||
2736  yieldOp.getResults().getTypes()[0] != getType())
2737  return emitOpError() << "expects reduction region to yield a value "
2738  "of the reduction type";
2739  }
2740 
2741  if (!getAtomicReductionRegion().empty()) {
2742  Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
2743  if (atomicReductionEntryBlock.getNumArguments() != 2 ||
2744  atomicReductionEntryBlock.getArgumentTypes()[0] !=
2745  atomicReductionEntryBlock.getArgumentTypes()[1])
2746  return emitOpError() << "expects atomic reduction region with two "
2747  "arguments of the same type";
2748  auto ptrType = llvm::dyn_cast<PointerLikeType>(
2749  atomicReductionEntryBlock.getArgumentTypes()[0]);
2750  if (!ptrType ||
2751  (ptrType.getElementType() && ptrType.getElementType() != getType()))
2752  return emitOpError() << "expects atomic reduction region arguments to "
2753  "be accumulators containing the reduction type";
2754  }
2755 
2756  if (getCleanupRegion().empty())
2757  return success();
2758  Block &cleanupEntryBlock = getCleanupRegion().front();
2759  if (cleanupEntryBlock.getNumArguments() != 1 ||
2760  cleanupEntryBlock.getArgument(0).getType() != getType())
2761  return emitOpError() << "expects cleanup region with one argument "
2762  "of the reduction type";
2763 
2764  return success();
2765 }
2766 
2767 //===----------------------------------------------------------------------===//
2768 // TaskOp
2769 //===----------------------------------------------------------------------===//
2770 
2771 void TaskOp::build(OpBuilder &builder, OperationState &state,
2772  const TaskOperands &clauses) {
2773  MLIRContext *ctx = builder.getContext();
2774  TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2775  makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2776  clauses.final, clauses.ifExpr, clauses.inReductionVars,
2777  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2778  makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2779  clauses.priority, /*private_vars=*/clauses.privateVars,
2780  /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
2781  clauses.untied, clauses.eventHandle);
2782 }
2783 
2784 LogicalResult TaskOp::verify() {
2785  LogicalResult verifyDependVars =
2786  verifyDependVarList(*this, getDependKinds(), getDependVars());
2787  return failed(verifyDependVars)
2788  ? verifyDependVars
2789  : verifyReductionVarList(*this, getInReductionSyms(),
2790  getInReductionVars(),
2791  getInReductionByref());
2792 }
2793 
2794 //===----------------------------------------------------------------------===//
2795 // TaskgroupOp
2796 //===----------------------------------------------------------------------===//
2797 
2798 void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
2799  const TaskgroupOperands &clauses) {
2800  MLIRContext *ctx = builder.getContext();
2801  TaskgroupOp::build(builder, state, clauses.allocateVars,
2802  clauses.allocatorVars, clauses.taskReductionVars,
2803  makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
2804  makeArrayAttr(ctx, clauses.taskReductionSyms));
2805 }
2806 
2807 LogicalResult TaskgroupOp::verify() {
2808  return verifyReductionVarList(*this, getTaskReductionSyms(),
2809  getTaskReductionVars(),
2810  getTaskReductionByref());
2811 }
2812 
2813 //===----------------------------------------------------------------------===//
2814 // TaskloopOp
2815 //===----------------------------------------------------------------------===//
2816 
2817 void TaskloopOp::build(OpBuilder &builder, OperationState &state,
2818  const TaskloopOperands &clauses) {
2819  MLIRContext *ctx = builder.getContext();
2820  TaskloopOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2821  clauses.final, clauses.grainsizeMod, clauses.grainsize,
2822  clauses.ifExpr, clauses.inReductionVars,
2823  makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2824  makeArrayAttr(ctx, clauses.inReductionSyms),
2825  clauses.mergeable, clauses.nogroup, clauses.numTasksMod,
2826  clauses.numTasks, clauses.priority,
2827  /*private_vars=*/clauses.privateVars,
2828  /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
2829  clauses.reductionMod, clauses.reductionVars,
2830  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2831  makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
2832 }
2833 
2834 LogicalResult TaskloopOp::verify() {
2835  if (getAllocateVars().size() != getAllocatorVars().size())
2836  return emitError(
2837  "expected equal sizes for allocate and allocator variables");
2838  if (failed(verifyReductionVarList(*this, getReductionSyms(),
2839  getReductionVars(), getReductionByref())) ||
2840  failed(verifyReductionVarList(*this, getInReductionSyms(),
2841  getInReductionVars(),
2842  getInReductionByref())))
2843  return failure();
2844 
2845  if (!getReductionVars().empty() && getNogroup())
2846  return emitError("if a reduction clause is present on the taskloop "
2847  "directive, the nogroup clause must not be specified");
2848  for (auto var : getReductionVars()) {
2849  if (llvm::is_contained(getInReductionVars(), var))
2850  return emitError("the same list item cannot appear in both a reduction "
2851  "and an in_reduction clause");
2852  }
2853 
2854  if (getGrainsize() && getNumTasks()) {
2855  return emitError(
2856  "the grainsize clause and num_tasks clause are mutually exclusive and "
2857  "may not appear on the same taskloop directive");
2858  }
2859 
2860  return success();
2861 }
2862 
2863 LogicalResult TaskloopOp::verifyRegions() {
2864  if (LoopWrapperInterface nested = getNestedWrapper()) {
2865  if (!isComposite())
2866  return emitError()
2867  << "'omp.composite' attribute missing from composite wrapper";
2868 
2869  // Check for the allowed leaf constructs that may appear in a composite
2870  // construct directly after TASKLOOP.
2871  if (!isa<SimdOp>(nested))
2872  return emitError() << "only supported nested wrapper is 'omp.simd'";
2873  } else if (isComposite()) {
2874  return emitError()
2875  << "'omp.composite' attribute present in non-composite wrapper";
2876  }
2877 
2878  return success();
2879 }
2880 
2881 //===----------------------------------------------------------------------===//
2882 // LoopNestOp
2883 //===----------------------------------------------------------------------===//
2884 
2885 ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
2886  // Parse an opening `(` followed by induction variables followed by `)`
2889  Type loopVarType;
2890  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
2891  parser.parseColonType(loopVarType) ||
2892  // Parse loop bounds.
2893  parser.parseEqual() ||
2894  parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
2895  parser.parseKeyword("to") ||
2896  parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
2897  return failure();
2898 
2899  for (auto &iv : ivs)
2900  iv.type = loopVarType;
2901 
2902  // Parse "inclusive" flag.
2903  if (succeeded(parser.parseOptionalKeyword("inclusive")))
2904  result.addAttribute("loop_inclusive",
2905  UnitAttr::get(parser.getBuilder().getContext()));
2906 
2907  // Parse step values.
2909  if (parser.parseKeyword("step") ||
2910  parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
2911  return failure();
2912 
2913  // Parse the body.
2914  Region *region = result.addRegion();
2915  if (parser.parseRegion(*region, ivs))
2916  return failure();
2917 
2918  // Resolve operands.
2919  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
2920  parser.resolveOperands(ubs, loopVarType, result.operands) ||
2921  parser.resolveOperands(steps, loopVarType, result.operands))
2922  return failure();
2923 
2924  // Parse the optional attribute list.
2925  return parser.parseOptionalAttrDict(result.attributes);
2926 }
2927 
2929  Region &region = getRegion();
2930  auto args = region.getArguments();
2931  p << " (" << args << ") : " << args[0].getType() << " = ("
2932  << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
2933  if (getLoopInclusive())
2934  p << "inclusive ";
2935  p << "step (" << getLoopSteps() << ") ";
2936  p.printRegion(region, /*printEntryBlockArgs=*/false);
2937 }
2938 
2939 void LoopNestOp::build(OpBuilder &builder, OperationState &state,
2940  const LoopNestOperands &clauses) {
2941  LoopNestOp::build(builder, state, clauses.loopLowerBounds,
2942  clauses.loopUpperBounds, clauses.loopSteps,
2943  clauses.loopInclusive);
2944 }
2945 
2946 LogicalResult LoopNestOp::verify() {
2947  if (getLoopLowerBounds().empty())
2948  return emitOpError() << "must represent at least one loop";
2949 
2950  if (getLoopLowerBounds().size() != getIVs().size())
2951  return emitOpError() << "number of range arguments and IVs do not match";
2952 
2953  for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
2954  if (lb.getType() != iv.getType())
2955  return emitOpError()
2956  << "range argument type does not match corresponding IV type";
2957  }
2958 
2959  if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
2960  return emitOpError() << "expects parent op to be a loop wrapper";
2961 
2962  return success();
2963 }
2964 
2965 void LoopNestOp::gatherWrappers(
2967  Operation *parent = (*this)->getParentOp();
2968  while (auto wrapper =
2969  llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
2970  wrappers.push_back(wrapper);
2971  parent = parent->getParentOp();
2972  }
2973 }
2974 
2975 //===----------------------------------------------------------------------===//
2976 // Critical construct (2.17.1)
2977 //===----------------------------------------------------------------------===//
2978 
2979 void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
2980  const CriticalDeclareOperands &clauses) {
2981  CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
2982 }
2983 
2984 LogicalResult CriticalDeclareOp::verify() {
2985  return verifySynchronizationHint(*this, getHint());
2986 }
2987 
2988 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2989  if (getNameAttr()) {
2990  SymbolRefAttr symbolRef = getNameAttr();
2991  auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
2992  *this, symbolRef);
2993  if (!decl) {
2994  return emitOpError() << "expected symbol reference " << symbolRef
2995  << " to point to a critical declaration";
2996  }
2997  }
2998 
2999  return success();
3000 }
3001 
3002 //===----------------------------------------------------------------------===//
3003 // Ordered construct
3004 //===----------------------------------------------------------------------===//
3005 
3006 static LogicalResult verifyOrderedParent(Operation &op) {
3007  bool hasRegion = op.getNumRegions() > 0;
3008  auto loopOp = op.getParentOfType<LoopNestOp>();
3009  if (!loopOp) {
3010  if (hasRegion)
3011  return success();
3012 
3013  // TODO: Consider if this needs to be the case only for the standalone
3014  // variant of the ordered construct.
3015  return op.emitOpError() << "must be nested inside of a loop";
3016  }
3017 
3018  Operation *wrapper = loopOp->getParentOp();
3019  if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
3020  IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
3021  if (!orderedAttr)
3022  return op.emitOpError() << "the enclosing worksharing-loop region must "
3023  "have an ordered clause";
3024 
3025  if (hasRegion && orderedAttr.getInt() != 0)
3026  return op.emitOpError() << "the enclosing loop's ordered clause must not "
3027  "have a parameter present";
3028 
3029  if (!hasRegion && orderedAttr.getInt() == 0)
3030  return op.emitOpError() << "the enclosing loop's ordered clause must "
3031  "have a parameter present";
3032  } else if (!isa<SimdOp>(wrapper)) {
3033  return op.emitOpError() << "must be nested inside of a worksharing, simd "
3034  "or worksharing simd loop";
3035  }
3036  return success();
3037 }
3038 
3039 void OrderedOp::build(OpBuilder &builder, OperationState &state,
3040  const OrderedOperands &clauses) {
3041  OrderedOp::build(builder, state, clauses.doacrossDependType,
3042  clauses.doacrossNumLoops, clauses.doacrossDependVars);
3043 }
3044 
3045 LogicalResult OrderedOp::verify() {
3046  if (failed(verifyOrderedParent(**this)))
3047  return failure();
3048 
3049  auto wrapper = (*this)->getParentOfType<WsloopOp>();
3050  if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
3051  return emitOpError() << "number of variables in depend clause does not "
3052  << "match number of iteration variables in the "
3053  << "doacross loop";
3054 
3055  return success();
3056 }
3057 
3058 void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
3059  const OrderedRegionOperands &clauses) {
3060  OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
3061 }
3062 
3063 LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
3064 
3065 //===----------------------------------------------------------------------===//
3066 // TaskwaitOp
3067 //===----------------------------------------------------------------------===//
3068 
3069 void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
3070  const TaskwaitOperands &clauses) {
3071  // TODO Store clauses in op: dependKinds, dependVars, nowait.
3072  TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
3073  /*depend_vars=*/{}, /*nowait=*/nullptr);
3074 }
3075 
3076 //===----------------------------------------------------------------------===//
3077 // Verifier for AtomicReadOp
3078 //===----------------------------------------------------------------------===//
3079 
3080 LogicalResult AtomicReadOp::verify() {
3081  if (verifyCommon().failed())
3082  return mlir::failure();
3083 
3084  if (auto mo = getMemoryOrder()) {
3085  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3086  *mo == ClauseMemoryOrderKind::Release) {
3087  return emitError(
3088  "memory-order must not be acq_rel or release for atomic reads");
3089  }
3090  }
3091  return verifySynchronizationHint(*this, getHint());
3092 }
3093 
3094 //===----------------------------------------------------------------------===//
3095 // Verifier for AtomicWriteOp
3096 //===----------------------------------------------------------------------===//
3097 
3098 LogicalResult AtomicWriteOp::verify() {
3099  if (verifyCommon().failed())
3100  return mlir::failure();
3101 
3102  if (auto mo = getMemoryOrder()) {
3103  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3104  *mo == ClauseMemoryOrderKind::Acquire) {
3105  return emitError(
3106  "memory-order must not be acq_rel or acquire for atomic writes");
3107  }
3108  }
3109  return verifySynchronizationHint(*this, getHint());
3110 }
3111 
3112 //===----------------------------------------------------------------------===//
3113 // Verifier for AtomicUpdateOp
3114 //===----------------------------------------------------------------------===//
3115 
3116 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3117  PatternRewriter &rewriter) {
3118  if (op.isNoOp()) {
3119  rewriter.eraseOp(op);
3120  return success();
3121  }
3122  if (Value writeVal = op.getWriteOpVal()) {
3123  rewriter.replaceOpWithNewOp<AtomicWriteOp>(
3124  op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
3125  return success();
3126  }
3127  return failure();
3128 }
3129 
3130 LogicalResult AtomicUpdateOp::verify() {
3131  if (verifyCommon().failed())
3132  return mlir::failure();
3133 
3134  if (auto mo = getMemoryOrder()) {
3135  if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3136  *mo == ClauseMemoryOrderKind::Acquire) {
3137  return emitError(
3138  "memory-order must not be acq_rel or acquire for atomic updates");
3139  }
3140  }
3141 
3142  return verifySynchronizationHint(*this, getHint());
3143 }
3144 
3145 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3146 
3147 //===----------------------------------------------------------------------===//
3148 // Verifier for AtomicCaptureOp
3149 //===----------------------------------------------------------------------===//
3150 
3151 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3152  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3153  return op;
3154  return dyn_cast<AtomicReadOp>(getSecondOp());
3155 }
3156 
3157 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3158  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3159  return op;
3160  return dyn_cast<AtomicWriteOp>(getSecondOp());
3161 }
3162 
3163 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3164  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3165  return op;
3166  return dyn_cast<AtomicUpdateOp>(getSecondOp());
3167 }
3168 
3169 LogicalResult AtomicCaptureOp::verify() {
3170  return verifySynchronizationHint(*this, getHint());
3171 }
3172 
3173 LogicalResult AtomicCaptureOp::verifyRegions() {
3174  if (verifyRegionsCommon().failed())
3175  return mlir::failure();
3176 
3177  if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
3178  return emitOpError(
3179  "operations inside capture region must not have hint clause");
3180 
3181  if (getFirstOp()->getAttr("memory_order") ||
3182  getSecondOp()->getAttr("memory_order"))
3183  return emitOpError(
3184  "operations inside capture region must not have memory_order clause");
3185  return success();
3186 }
3187 
3188 //===----------------------------------------------------------------------===//
3189 // CancelOp
3190 //===----------------------------------------------------------------------===//
3191 
3192 void CancelOp::build(OpBuilder &builder, OperationState &state,
3193  const CancelOperands &clauses) {
3194  CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
3195 }
3196 
3198  Operation *parent = thisOp->getParentOp();
3199  while (parent) {
3200  if (parent->getDialect() == thisOp->getDialect())
3201  return parent;
3202  parent = parent->getParentOp();
3203  }
3204  return nullptr;
3205 }
3206 
3207 LogicalResult CancelOp::verify() {
3208  ClauseCancellationConstructType cct = getCancelDirective();
3209  // The next OpenMP operation in the chain of parents
3210  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
3211  if (!structuralParent)
3212  return emitOpError() << "Orphaned cancel construct";
3213 
3214  if ((cct == ClauseCancellationConstructType::Parallel) &&
3215  !mlir::isa<ParallelOp>(structuralParent)) {
3216  return emitOpError() << "cancel parallel must appear "
3217  << "inside a parallel region";
3218  }
3219  if (cct == ClauseCancellationConstructType::Loop) {
3220  // structural parent will be omp.loop_nest, directly nested inside
3221  // omp.wsloop
3222  auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
3223 
3224  if (!wsloopOp) {
3225  return emitOpError()
3226  << "cancel loop must appear inside a worksharing-loop region";
3227  }
3228  if (wsloopOp.getNowaitAttr()) {
3229  return emitError() << "A worksharing construct that is canceled "
3230  << "must not have a nowait clause";
3231  }
3232  if (wsloopOp.getOrderedAttr()) {
3233  return emitError() << "A worksharing construct that is canceled "
3234  << "must not have an ordered clause";
3235  }
3236 
3237  } else if (cct == ClauseCancellationConstructType::Sections) {
3238  // structural parent will be an omp.section, directly nested inside
3239  // omp.sections
3240  auto sectionsOp =
3241  mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
3242  if (!sectionsOp) {
3243  return emitOpError() << "cancel sections must appear "
3244  << "inside a sections region";
3245  }
3246  if (sectionsOp.getNowait()) {
3247  return emitError() << "A sections construct that is canceled "
3248  << "must not have a nowait clause";
3249  }
3250  }
3251  if ((cct == ClauseCancellationConstructType::Taskgroup) &&
3252  (!mlir::isa<omp::TaskOp>(structuralParent) &&
3253  !mlir::isa<omp::TaskloopOp>(structuralParent->getParentOp()))) {
3254  return emitOpError() << "cancel taskgroup must appear "
3255  << "inside a task region";
3256  }
3257  return success();
3258 }
3259 
3260 //===----------------------------------------------------------------------===//
3261 // CancellationPointOp
3262 //===----------------------------------------------------------------------===//
3263 
3264 void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
3265  const CancellationPointOperands &clauses) {
3266  CancellationPointOp::build(builder, state, clauses.cancelDirective);
3267 }
3268 
3269 LogicalResult CancellationPointOp::verify() {
3270  ClauseCancellationConstructType cct = getCancelDirective();
3271  // The next OpenMP operation in the chain of parents
3272  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
3273  if (!structuralParent)
3274  return emitOpError() << "Orphaned cancellation point";
3275 
3276  if ((cct == ClauseCancellationConstructType::Parallel) &&
3277  !mlir::isa<ParallelOp>(structuralParent)) {
3278  return emitOpError() << "cancellation point parallel must appear "
3279  << "inside a parallel region";
3280  }
3281  // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
3282  // find the wsloop
3283  if ((cct == ClauseCancellationConstructType::Loop) &&
3284  !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
3285  return emitOpError() << "cancellation point loop must appear "
3286  << "inside a worksharing-loop region";
3287  }
3288  if ((cct == ClauseCancellationConstructType::Sections) &&
3289  !mlir::isa<omp::SectionOp>(structuralParent)) {
3290  return emitOpError() << "cancellation point sections must appear "
3291  << "inside a sections region";
3292  }
3293  if ((cct == ClauseCancellationConstructType::Taskgroup) &&
3294  !mlir::isa<omp::TaskOp>(structuralParent)) {
3295  return emitOpError() << "cancellation point taskgroup must appear "
3296  << "inside a task region";
3297  }
3298  return success();
3299 }
3300 
3301 //===----------------------------------------------------------------------===//
3302 // MapBoundsOp
3303 //===----------------------------------------------------------------------===//
3304 
3305 LogicalResult MapBoundsOp::verify() {
3306  auto extent = getExtent();
3307  auto upperbound = getUpperBound();
3308  if (!extent && !upperbound)
3309  return emitError("expected extent or upperbound.");
3310  return success();
3311 }
3312 
3313 void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3314  TypeRange /*result_types*/, StringAttr symName,
3315  TypeAttr type) {
3316  PrivateClauseOp::build(
3317  odsBuilder, odsState, symName, type,
3319  DataSharingClauseType::Private));
3320 }
3321 
3322 LogicalResult PrivateClauseOp::verifyRegions() {
3323  Type argType = getArgType();
3324  auto verifyTerminator = [&](Operation *terminator,
3325  bool yieldsValue) -> LogicalResult {
3326  if (!terminator->getBlock()->getSuccessors().empty())
3327  return success();
3328 
3329  if (!llvm::isa<YieldOp>(terminator))
3330  return mlir::emitError(terminator->getLoc())
3331  << "expected exit block terminator to be an `omp.yield` op.";
3332 
3333  YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
3334  TypeRange yieldedTypes = yieldOp.getResults().getTypes();
3335 
3336  if (!yieldsValue) {
3337  if (yieldedTypes.empty())
3338  return success();
3339 
3340  return mlir::emitError(terminator->getLoc())
3341  << "Did not expect any values to be yielded.";
3342  }
3343 
3344  if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
3345  return success();
3346 
3347  auto error = mlir::emitError(yieldOp.getLoc())
3348  << "Invalid yielded value. Expected type: " << argType
3349  << ", got: ";
3350 
3351  if (yieldedTypes.empty())
3352  error << "None";
3353  else
3354  error << yieldedTypes;
3355 
3356  return error;
3357  };
3358 
3359  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
3360  StringRef regionName,
3361  bool yieldsValue) -> LogicalResult {
3362  assert(!region.empty());
3363 
3364  if (region.getNumArguments() != expectedNumArgs)
3365  return mlir::emitError(region.getLoc())
3366  << "`" << regionName << "`: "
3367  << "expected " << expectedNumArgs
3368  << " region arguments, got: " << region.getNumArguments();
3369 
3370  for (Block &block : region) {
3371  // MLIR will verify the absence of the terminator for us.
3372  if (!block.mightHaveTerminator())
3373  continue;
3374 
3375  if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
3376  return failure();
3377  }
3378 
3379  return success();
3380  };
3381 
3382  // Ensure all of the region arguments have the same type
3383  for (Region *region : getRegions())
3384  for (Type ty : region->getArgumentTypes())
3385  if (ty != argType)
3386  return emitError() << "Region argument type mismatch: got " << ty
3387  << " expected " << argType << ".";
3388 
3389  mlir::Region &initRegion = getInitRegion();
3390  if (!initRegion.empty() &&
3391  failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
3392  /*yieldsValue=*/true)))
3393  return failure();
3394 
3395  DataSharingClauseType dsType = getDataSharingType();
3396 
3397  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
3398  return emitError("`private` clauses do not require a `copy` region.");
3399 
3400  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
3401  return emitError(
3402  "`firstprivate` clauses require at least a `copy` region.");
3403 
3404  if (dsType == DataSharingClauseType::FirstPrivate &&
3405  failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
3406  /*yieldsValue=*/true)))
3407  return failure();
3408 
3409  if (!getDeallocRegion().empty() &&
3410  failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
3411  /*yieldsValue=*/false)))
3412  return failure();
3413 
3414  return success();
3415 }
3416 
3417 //===----------------------------------------------------------------------===//
3418 // Spec 5.2: Masked construct (10.5)
3419 //===----------------------------------------------------------------------===//
3420 
3421 void MaskedOp::build(OpBuilder &builder, OperationState &state,
3422  const MaskedOperands &clauses) {
3423  MaskedOp::build(builder, state, clauses.filteredThreadId);
3424 }
3425 
3426 //===----------------------------------------------------------------------===//
3427 // Spec 5.2: Scan construct (5.6)
3428 //===----------------------------------------------------------------------===//
3429 
3430 void ScanOp::build(OpBuilder &builder, OperationState &state,
3431  const ScanOperands &clauses) {
3432  ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
3433 }
3434 
3435 LogicalResult ScanOp::verify() {
3436  if (hasExclusiveVars() == hasInclusiveVars())
3437  return emitError(
3438  "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
3439  if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
3440  if (parentWsLoopOp.getReductionModAttr() &&
3441  parentWsLoopOp.getReductionModAttr().getValue() ==
3442  ReductionModifier::inscan)
3443  return success();
3444  }
3445  if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
3446  if (parentSimdOp.getReductionModAttr() &&
3447  parentSimdOp.getReductionModAttr().getValue() ==
3448  ReductionModifier::inscan)
3449  return success();
3450  }
3451  return emitError("SCAN directive needs to be enclosed within a parent "
3452  "worksharing loop construct or SIMD construct with INSCAN "
3453  "reduction modifier");
3454 }
3455 
3456 #define GET_ATTRDEF_CLASSES
3457 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
3458 
3459 #define GET_OP_CLASSES
3460 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
3461 
3462 #define GET_TYPEDEF_CLASSES
3463 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:752
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1287
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:63
static MLIRContext * getContext(OpFoldResult val)
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, DenseI64ArrayAttr privateMaps)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr)
uint64_t mapTypeToBitFlag(uint64_t value, llvm::omp::OpenMPOffloadMappingFlags flag)
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType)
Parses a map_entries map type from a string format back into its numeric value.
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 >> &modifiers)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static Operation * getParentInSameDialect(Operation *thisOp)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, DenseI64ArrayAttr &privateMaps)
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool >> reductionByref)
Verifies Reduction Clause.
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr)
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static void printMapClause(OpAsmPrinter &p, Operation *op, IntegerAttr mapType)
Prints a map_entries map type from its numeric value out into its string format.
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseMembersIndex(OpAsmParser &parser, ArrayAttr &membersIdx)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
SuccessorRange getSuccessors()
Definition: Block.h:267
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
MLIRContext * getContext() const
Definition: Builders.h:55
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:204
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:772
This class indicates that the regions associated with this op don't have terminators.
Definition: OpDefinition.h:768
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
BlockArgListType getArguments()
Definition: Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class represents a specific instance of an effect.
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Definition: Dominance.cpp:307
Runtime
Potential runtimes for AMD GPU kernels.
Definition: Runtimes.h:15
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
Definition: OpDefinition.h:880
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.