MLIR  22.0.0git
OpenACC.cpp
Go to the documentation of this file.
1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
14 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/SymbolTable.h"
21 #include "mlir/Support/LLVM.h"
23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/LogicalResult.h"
26 #include <variant>
27 
28 using namespace mlir;
29 using namespace acc;
30 
31 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
32 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
33 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
34 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
35 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
36 
37 namespace {
38 
39 static bool isScalarLikeType(Type type) {
40  return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
41 }
42 
43 struct MemRefPointerLikeModel
44  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
45  MemRefType> {
46  Type getElementType(Type pointer) const {
47  return cast<MemRefType>(pointer).getElementType();
48  }
49 
50  mlir::acc::VariableTypeCategory
51  getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
52  Type varType) const {
53  if (auto mappableTy = dyn_cast<MappableType>(varType)) {
54  return mappableTy.getTypeCategory(varPtr);
55  }
56  auto memrefTy = cast<MemRefType>(pointer);
57  if (!memrefTy.hasRank()) {
58  // This memref is unranked - aka it could have any rank, including a
59  // rank of 0 which could mean scalar. For now, return uncategorized.
60  return mlir::acc::VariableTypeCategory::uncategorized;
61  }
62 
63  if (memrefTy.getRank() == 0) {
64  if (isScalarLikeType(memrefTy.getElementType())) {
65  return mlir::acc::VariableTypeCategory::scalar;
66  }
67  // Zero-rank non-scalar - need further analysis to determine the type
68  // category. For now, return uncategorized.
69  return mlir::acc::VariableTypeCategory::uncategorized;
70  }
71 
72  // It has a rank - must be an array.
73  assert(memrefTy.getRank() > 0 && "rank expected to be positive");
74  return mlir::acc::VariableTypeCategory::array;
75  }
76 
77  mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
78  StringRef varName, Type varType, Value originalVar,
79  bool &needsFree) const {
80  auto memrefTy = cast<MemRefType>(pointer);
81 
82  // Check if this is a static memref (all dimensions are known) - if yes
83  // then we can generate an alloca operation.
84  if (memrefTy.hasStaticShape()) {
85  needsFree = false; // alloca doesn't need deallocation
86  return memref::AllocaOp::create(builder, loc, memrefTy).getResult();
87  }
88 
89  // For dynamic memrefs, extract sizes from the original variable if
90  // provided. Otherwise they cannot be handled.
91  if (originalVar && originalVar.getType() == memrefTy &&
92  memrefTy.hasRank()) {
93  SmallVector<Value> dynamicSizes;
94  for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
95  if (memrefTy.isDynamicDim(i)) {
96  // Extract the size of dimension i from the original variable
97  auto indexValue = arith::ConstantIndexOp::create(builder, loc, i);
98  auto dimSize =
99  memref::DimOp::create(builder, loc, originalVar, indexValue);
100  dynamicSizes.push_back(dimSize);
101  }
102  // Note: We only add dynamic sizes to the dynamicSizes array
103  // Static dimensions are handled automatically by AllocOp
104  }
105  needsFree = true; // alloc needs deallocation
106  return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes)
107  .getResult();
108  }
109 
110  // TODO: Unranked not yet supported.
111  return {};
112  }
113 
114  bool genFree(Type pointer, OpBuilder &builder, Location loc,
115  TypedValue<PointerLikeType> varToFree, Value allocRes,
116  Type varType) const {
117  if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
118  // Use allocRes if provided to determine the allocation type
119  Value valueToInspect = allocRes ? allocRes : memrefValue;
120 
121  // Walk through casts to find the original allocation
122  Value currentValue = valueToInspect;
123  Operation *originalAlloc = nullptr;
124 
125  // Follow the chain of operations to find the original allocation
126  // even if a casted result is provided.
127  while (currentValue) {
128  if (auto *definingOp = currentValue.getDefiningOp()) {
129  // Check if this is an allocation operation
130  if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
131  originalAlloc = definingOp;
132  break;
133  }
134 
135  // Check if this is a cast operation we can look through
136  if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
137  currentValue = castOp.getSource();
138  continue;
139  }
140 
141  // Check for other cast-like operations
142  if (auto reinterpretCastOp =
143  dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
144  currentValue = reinterpretCastOp.getSource();
145  continue;
146  }
147 
148  // If we can't look through this operation, stop
149  break;
150  }
151  // This is a block argument or similar - can't trace further.
152  break;
153  }
154 
155  if (originalAlloc) {
156  if (isa<memref::AllocaOp>(originalAlloc)) {
157  // This is an alloca - no dealloc needed, but return true (success)
158  return true;
159  }
160  if (isa<memref::AllocOp>(originalAlloc)) {
161  // This is an alloc - generate dealloc on varToFree
162  memref::DeallocOp::create(builder, loc, memrefValue);
163  return true;
164  }
165  }
166  }
167 
168  return false;
169  }
170 
171  bool genCopy(Type pointer, OpBuilder &builder, Location loc,
172  TypedValue<PointerLikeType> destination,
173  TypedValue<PointerLikeType> source, Type varType) const {
174  // Generate a copy operation between two memrefs
175  auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
176  auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
177 
178  // As per memref documentation, source and destination must have same
179  // element type and shape in order to be compatible. We do not want to fail
180  // with an IR verification error - thus check that before generating the
181  // copy operation.
182  if (destMemref && srcMemref &&
183  destMemref.getType().getElementType() ==
184  srcMemref.getType().getElementType() &&
185  destMemref.getType().getShape() == srcMemref.getType().getShape()) {
186  memref::CopyOp::create(builder, loc, srcMemref, destMemref);
187  return true;
188  }
189 
190  return false;
191  }
192 };
193 
194 struct LLVMPointerPointerLikeModel
195  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
196  LLVM::LLVMPointerType> {
197  Type getElementType(Type pointer) const { return Type(); }
198 };
199 
200 /// Helper function for any of the times we need to modify an ArrayAttr based on
201 /// a device type list. Returns a new ArrayAttr with all of the
202 /// existingDeviceTypes, plus the effective new ones(or an added none if hte new
203 /// list is empty).
204 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
205  MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
206  llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
208  if (existingDeviceTypes)
209  llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
210 
211  if (newDeviceTypes.empty())
212  deviceTypes.push_back(
214 
215  for (DeviceType dt : newDeviceTypes)
216  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
217 
218  return mlir::ArrayAttr::get(context, deviceTypes);
219 }
220 
221 /// Helper function for any of the times we need to add operands that are
222 /// affected by a device type list. Returns a new ArrayAttr with all of the
223 /// existingDeviceTypes, plus the effective new ones (or an added none, if the
224 /// new list is empty). Additionally, adds the arguments to the argCollection
225 /// the correct number of times. This will also update a 'segments' array, even
226 /// if it won't be used.
227 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
228  MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
229  llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
230  mlir::MutableOperandRange argCollection,
231  llvm::SmallVector<int32_t> &segments) {
233  if (existingDeviceTypes)
234  llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
235 
236  if (newDeviceTypes.empty()) {
237  argCollection.append(arguments);
238  segments.push_back(arguments.size());
239  deviceTypes.push_back(
241  }
242 
243  for (DeviceType dt : newDeviceTypes) {
244  argCollection.append(arguments);
245  segments.push_back(arguments.size());
246  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
247  }
248 
249  return mlir::ArrayAttr::get(context, deviceTypes);
250 }
251 
252 /// Overload for when the 'segments' aren't needed.
253 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
254  MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
255  llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
256  mlir::MutableOperandRange argCollection) {
258  return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
259  newDeviceTypes, arguments,
260  argCollection, segments);
261 }
262 } // namespace
263 
264 //===----------------------------------------------------------------------===//
265 // OpenACC operations
266 //===----------------------------------------------------------------------===//
267 
268 void OpenACCDialect::initialize() {
269  addOperations<
270 #define GET_OP_LIST
271 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
272  >();
273  addAttributes<
274 #define GET_ATTRDEF_LIST
275 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
276  >();
277  addTypes<
278 #define GET_TYPEDEF_LIST
279 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
280  >();
281 
282  // By attaching interfaces here, we make the OpenACC dialect dependent on
283  // the other dialects. This is probably better than having dialects like LLVM
284  // and memref be dependent on OpenACC.
285  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
286  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
287  *getContext());
288 }
289 
290 //===----------------------------------------------------------------------===//
291 // device_type support helpers
292 //===----------------------------------------------------------------------===//
293 
294 static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
295  return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
296 }
297 
298 static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
299  mlir::acc::DeviceType deviceType) {
300  if (!hasDeviceTypeValues(arrayAttr))
301  return false;
302 
303  for (auto attr : *arrayAttr) {
304  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
305  if (deviceTypeAttr.getValue() == deviceType)
306  return true;
307  }
308 
309  return false;
310 }
311 
313  std::optional<mlir::ArrayAttr> deviceTypes) {
314  if (!hasDeviceTypeValues(deviceTypes))
315  return;
316 
317  p << "[";
318  llvm::interleaveComma(*deviceTypes, p,
319  [&](mlir::Attribute attr) { p << attr; });
320  p << "]";
321 }
322 
323 static std::optional<unsigned> findSegment(ArrayAttr segments,
324  mlir::acc::DeviceType deviceType) {
325  unsigned segmentIdx = 0;
326  for (auto attr : segments) {
327  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
328  if (deviceTypeAttr.getValue() == deviceType)
329  return std::make_optional(segmentIdx);
330  ++segmentIdx;
331  }
332  return std::nullopt;
333 }
334 
336 getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
338  std::optional<llvm::ArrayRef<int32_t>> segments,
339  mlir::acc::DeviceType deviceType) {
340  if (!arrayAttr)
341  return range.take_front(0);
342  if (auto pos = findSegment(*arrayAttr, deviceType)) {
343  int32_t nbOperandsBefore = 0;
344  for (unsigned i = 0; i < *pos; ++i)
345  nbOperandsBefore += (*segments)[i];
346  return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
347  }
348  return range.take_front(0);
349 }
350 
351 static mlir::Value
352 getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
354  std::optional<llvm::ArrayRef<int32_t>> segments,
355  std::optional<mlir::ArrayAttr> hasWaitDevnum,
356  mlir::acc::DeviceType deviceType) {
357  if (!hasDeviceTypeValues(deviceTypeAttr))
358  return {};
359  if (auto pos = findSegment(*deviceTypeAttr, deviceType))
360  if (hasWaitDevnum->getValue()[*pos])
361  return getValuesFromSegments(deviceTypeAttr, operands, segments,
362  deviceType)
363  .front();
364  return {};
365 }
366 
368 getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
370  std::optional<llvm::ArrayRef<int32_t>> segments,
371  std::optional<mlir::ArrayAttr> hasWaitDevnum,
372  mlir::acc::DeviceType deviceType) {
373  auto range =
374  getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
375  if (range.empty())
376  return range;
377  if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
378  if (hasWaitDevnum && *hasWaitDevnum) {
379  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
380  if (boolAttr.getValue())
381  return range.drop_front(1); // first value is devnum
382  }
383  }
384  return range;
385 }
386 
387 template <typename Op>
388 static LogicalResult checkWaitAndAsyncConflict(Op op) {
389  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
390  ++dtypeInt) {
391  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
392 
393  // The asyncOnly attribute represent the async clause without value.
394  // Therefore the attribute and operand cannot appear at the same time.
395  if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
396  op.hasAsyncOnly(dtype))
397  return op.emitError(
398  "asyncOnly attribute cannot appear with asyncOperand");
399 
400  // The wait attribute represent the wait clause without values. Therefore
401  // the attribute and operands cannot appear at the same time.
402  if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
403  op.hasWaitOnly(dtype))
404  return op.emitError("wait attribute cannot appear with waitOperands");
405  }
406  return success();
407 }
408 
409 template <typename Op>
410 static LogicalResult checkVarAndVarType(Op op) {
411  if (!op.getVar())
412  return op.emitError("must have var operand");
413 
414  // A variable must have a type that is either pointer-like or mappable.
415  if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
416  !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
417  return op.emitError("var must be mappable or pointer-like");
418 
419  // When it is a pointer-like type, the varType must capture the target type.
420  if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
421  op.getVarType() == op.getVar().getType())
422  return op.emitError("varType must capture the element type of var");
423 
424  return success();
425 }
426 
427 template <typename Op>
428 static LogicalResult checkVarAndAccVar(Op op) {
429  if (op.getVar().getType() != op.getAccVar().getType())
430  return op.emitError("input and output types must match");
431 
432  return success();
433 }
434 
435 template <typename Op>
436 static LogicalResult checkNoModifier(Op op) {
437  if (op.getModifiers() != acc::DataClauseModifier::none)
438  return op.emitError("no data clause modifiers are allowed");
439  return success();
440 }
441 
442 template <typename Op>
443 static LogicalResult
444 checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
445  if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
446  return op.emitError(
447  "invalid data clause modifiers: " +
448  acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
449 
450  return success();
451 }
452 
453 static ParseResult parseVar(mlir::OpAsmParser &parser,
455  // Either `var` or `varPtr` keyword is required.
456  if (failed(parser.parseOptionalKeyword("varPtr"))) {
457  if (failed(parser.parseKeyword("var")))
458  return failure();
459  }
460  if (failed(parser.parseLParen()))
461  return failure();
462  if (failed(parser.parseOperand(var)))
463  return failure();
464 
465  return success();
466 }
467 
469  mlir::Value var) {
470  if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
471  p << "varPtr(";
472  else
473  p << "var(";
474  p.printOperand(var);
475 }
476 
477 static ParseResult parseAccVar(mlir::OpAsmParser &parser,
479  mlir::Type &accVarType) {
480  // Either `accVar` or `accPtr` keyword is required.
481  if (failed(parser.parseOptionalKeyword("accPtr"))) {
482  if (failed(parser.parseKeyword("accVar")))
483  return failure();
484  }
485  if (failed(parser.parseLParen()))
486  return failure();
487  if (failed(parser.parseOperand(var)))
488  return failure();
489  if (failed(parser.parseColon()))
490  return failure();
491  if (failed(parser.parseType(accVarType)))
492  return failure();
493  if (failed(parser.parseRParen()))
494  return failure();
495 
496  return success();
497 }
498 
500  mlir::Value accVar, mlir::Type accVarType) {
501  if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
502  p << "accPtr(";
503  else
504  p << "accVar(";
505  p.printOperand(accVar);
506  p << " : ";
507  p.printType(accVarType);
508  p << ")";
509 }
510 
511 static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
512  mlir::Type &varPtrType,
513  mlir::TypeAttr &varTypeAttr) {
514  if (failed(parser.parseType(varPtrType)))
515  return failure();
516  if (failed(parser.parseRParen()))
517  return failure();
518 
519  if (succeeded(parser.parseOptionalKeyword("varType"))) {
520  if (failed(parser.parseLParen()))
521  return failure();
522  mlir::Type varType;
523  if (failed(parser.parseType(varType)))
524  return failure();
525  varTypeAttr = mlir::TypeAttr::get(varType);
526  if (failed(parser.parseRParen()))
527  return failure();
528  } else {
529  // Set `varType` from the element type of the type of `varPtr`.
530  if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
531  varTypeAttr = mlir::TypeAttr::get(
532  mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
533  else
534  varTypeAttr = mlir::TypeAttr::get(varPtrType);
535  }
536 
537  return success();
538 }
539 
541  mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
542  p.printType(varPtrType);
543  p << ")";
544 
545  // Print the `varType` only if it differs from the element type of
546  // `varPtr`'s type.
547  mlir::Type varType = varTypeAttr.getValue();
548  mlir::Type typeToCheckAgainst =
549  mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
550  ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
551  : varPtrType;
552  if (typeToCheckAgainst != varType) {
553  p << " varType(";
554  p.printType(varType);
555  p << ")";
556  }
557 }
558 
559 //===----------------------------------------------------------------------===//
560 // DataBoundsOp
561 //===----------------------------------------------------------------------===//
562 LogicalResult acc::DataBoundsOp::verify() {
563  auto extent = getExtent();
564  auto upperbound = getUpperbound();
565  if (!extent && !upperbound)
566  return emitError("expected extent or upperbound.");
567  return success();
568 }
569 
570 //===----------------------------------------------------------------------===//
571 // PrivateOp
572 //===----------------------------------------------------------------------===//
573 LogicalResult acc::PrivateOp::verify() {
574  if (getDataClause() != acc::DataClause::acc_private)
575  return emitError(
576  "data clause associated with private operation must match its intent");
577  if (failed(checkVarAndVarType(*this)))
578  return failure();
579  if (failed(checkNoModifier(*this)))
580  return failure();
581  return success();
582 }
583 
584 //===----------------------------------------------------------------------===//
585 // FirstprivateOp
586 //===----------------------------------------------------------------------===//
587 LogicalResult acc::FirstprivateOp::verify() {
588  if (getDataClause() != acc::DataClause::acc_firstprivate)
589  return emitError("data clause associated with firstprivate operation must "
590  "match its intent");
591  if (failed(checkVarAndVarType(*this)))
592  return failure();
593  if (failed(checkNoModifier(*this)))
594  return failure();
595  return success();
596 }
597 
598 //===----------------------------------------------------------------------===//
599 // ReductionOp
600 //===----------------------------------------------------------------------===//
601 LogicalResult acc::ReductionOp::verify() {
602  if (getDataClause() != acc::DataClause::acc_reduction)
603  return emitError("data clause associated with reduction operation must "
604  "match its intent");
605  if (failed(checkVarAndVarType(*this)))
606  return failure();
607  if (failed(checkNoModifier(*this)))
608  return failure();
609  return success();
610 }
611 
612 //===----------------------------------------------------------------------===//
613 // DevicePtrOp
614 //===----------------------------------------------------------------------===//
615 LogicalResult acc::DevicePtrOp::verify() {
616  if (getDataClause() != acc::DataClause::acc_deviceptr)
617  return emitError("data clause associated with deviceptr operation must "
618  "match its intent");
619  if (failed(checkVarAndVarType(*this)))
620  return failure();
621  if (failed(checkVarAndAccVar(*this)))
622  return failure();
623  if (failed(checkNoModifier(*this)))
624  return failure();
625  return success();
626 }
627 
628 //===----------------------------------------------------------------------===//
629 // PresentOp
630 //===----------------------------------------------------------------------===//
631 LogicalResult acc::PresentOp::verify() {
632  if (getDataClause() != acc::DataClause::acc_present)
633  return emitError(
634  "data clause associated with present operation must match its intent");
635  if (failed(checkVarAndVarType(*this)))
636  return failure();
637  if (failed(checkVarAndAccVar(*this)))
638  return failure();
639  if (failed(checkNoModifier(*this)))
640  return failure();
641  return success();
642 }
643 
644 //===----------------------------------------------------------------------===//
645 // CopyinOp
646 //===----------------------------------------------------------------------===//
647 LogicalResult acc::CopyinOp::verify() {
648  // Test for all clauses this operation can be decomposed from:
649  if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
650  getDataClause() != acc::DataClause::acc_copyin_readonly &&
651  getDataClause() != acc::DataClause::acc_copy &&
652  getDataClause() != acc::DataClause::acc_reduction)
653  return emitError(
654  "data clause associated with copyin operation must match its intent"
655  " or specify original clause this operation was decomposed from");
656  if (failed(checkVarAndVarType(*this)))
657  return failure();
658  if (failed(checkVarAndAccVar(*this)))
659  return failure();
660  if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly |
661  acc::DataClauseModifier::always |
662  acc::DataClauseModifier::capture)))
663  return failure();
664  return success();
665 }
666 
667 bool acc::CopyinOp::isCopyinReadonly() {
668  return getDataClause() == acc::DataClause::acc_copyin_readonly ||
669  acc::bitEnumContainsAny(getModifiers(),
670  acc::DataClauseModifier::readonly);
671 }
672 
673 //===----------------------------------------------------------------------===//
674 // CreateOp
675 //===----------------------------------------------------------------------===//
676 LogicalResult acc::CreateOp::verify() {
677  // Test for all clauses this operation can be decomposed from:
678  if (getDataClause() != acc::DataClause::acc_create &&
679  getDataClause() != acc::DataClause::acc_create_zero &&
680  getDataClause() != acc::DataClause::acc_copyout &&
681  getDataClause() != acc::DataClause::acc_copyout_zero)
682  return emitError(
683  "data clause associated with create operation must match its intent"
684  " or specify original clause this operation was decomposed from");
685  if (failed(checkVarAndVarType(*this)))
686  return failure();
687  if (failed(checkVarAndAccVar(*this)))
688  return failure();
689  // this op is the entry part of copyout, so it also needs to allow all
690  // modifiers allowed on copyout.
691  if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
692  acc::DataClauseModifier::always |
693  acc::DataClauseModifier::capture)))
694  return failure();
695  return success();
696 }
697 
698 bool acc::CreateOp::isCreateZero() {
699  // The zero modifier is encoded in the data clause.
700  return getDataClause() == acc::DataClause::acc_create_zero ||
701  getDataClause() == acc::DataClause::acc_copyout_zero ||
702  acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
703 }
704 
705 //===----------------------------------------------------------------------===//
706 // NoCreateOp
707 //===----------------------------------------------------------------------===//
708 LogicalResult acc::NoCreateOp::verify() {
709  if (getDataClause() != acc::DataClause::acc_no_create)
710  return emitError("data clause associated with no_create operation must "
711  "match its intent");
712  if (failed(checkVarAndVarType(*this)))
713  return failure();
714  if (failed(checkVarAndAccVar(*this)))
715  return failure();
716  if (failed(checkNoModifier(*this)))
717  return failure();
718  return success();
719 }
720 
721 //===----------------------------------------------------------------------===//
722 // AttachOp
723 //===----------------------------------------------------------------------===//
724 LogicalResult acc::AttachOp::verify() {
725  if (getDataClause() != acc::DataClause::acc_attach)
726  return emitError(
727  "data clause associated with attach operation must match its intent");
728  if (failed(checkVarAndVarType(*this)))
729  return failure();
730  if (failed(checkVarAndAccVar(*this)))
731  return failure();
732  if (failed(checkNoModifier(*this)))
733  return failure();
734  return success();
735 }
736 
737 //===----------------------------------------------------------------------===//
738 // DeclareDeviceResidentOp
739 //===----------------------------------------------------------------------===//
740 
741 LogicalResult acc::DeclareDeviceResidentOp::verify() {
742  if (getDataClause() != acc::DataClause::acc_declare_device_resident)
743  return emitError("data clause associated with device_resident operation "
744  "must match its intent");
745  if (failed(checkVarAndVarType(*this)))
746  return failure();
747  if (failed(checkVarAndAccVar(*this)))
748  return failure();
749  if (failed(checkNoModifier(*this)))
750  return failure();
751  return success();
752 }
753 
754 //===----------------------------------------------------------------------===//
755 // DeclareLinkOp
756 //===----------------------------------------------------------------------===//
757 
758 LogicalResult acc::DeclareLinkOp::verify() {
759  if (getDataClause() != acc::DataClause::acc_declare_link)
760  return emitError(
761  "data clause associated with link operation must match its intent");
762  if (failed(checkVarAndVarType(*this)))
763  return failure();
764  if (failed(checkVarAndAccVar(*this)))
765  return failure();
766  if (failed(checkNoModifier(*this)))
767  return failure();
768  return success();
769 }
770 
771 //===----------------------------------------------------------------------===//
772 // CopyoutOp
773 //===----------------------------------------------------------------------===//
774 LogicalResult acc::CopyoutOp::verify() {
775  // Test for all clauses this operation can be decomposed from:
776  if (getDataClause() != acc::DataClause::acc_copyout &&
777  getDataClause() != acc::DataClause::acc_copyout_zero &&
778  getDataClause() != acc::DataClause::acc_copy &&
779  getDataClause() != acc::DataClause::acc_reduction)
780  return emitError(
781  "data clause associated with copyout operation must match its intent"
782  " or specify original clause this operation was decomposed from");
783  if (!getVar() || !getAccVar())
784  return emitError("must have both host and device pointers");
785  if (failed(checkVarAndVarType(*this)))
786  return failure();
787  if (failed(checkVarAndAccVar(*this)))
788  return failure();
789  if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
790  acc::DataClauseModifier::always |
791  acc::DataClauseModifier::capture)))
792  return failure();
793  return success();
794 }
795 
796 bool acc::CopyoutOp::isCopyoutZero() {
797  return getDataClause() == acc::DataClause::acc_copyout_zero ||
798  acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
799 }
800 
801 //===----------------------------------------------------------------------===//
802 // DeleteOp
803 //===----------------------------------------------------------------------===//
804 LogicalResult acc::DeleteOp::verify() {
805  // Test for all clauses this operation can be decomposed from:
806  if (getDataClause() != acc::DataClause::acc_delete &&
807  getDataClause() != acc::DataClause::acc_create &&
808  getDataClause() != acc::DataClause::acc_create_zero &&
809  getDataClause() != acc::DataClause::acc_copyin &&
810  getDataClause() != acc::DataClause::acc_copyin_readonly &&
811  getDataClause() != acc::DataClause::acc_present &&
812  getDataClause() != acc::DataClause::acc_no_create &&
813  getDataClause() != acc::DataClause::acc_declare_device_resident &&
814  getDataClause() != acc::DataClause::acc_declare_link)
815  return emitError(
816  "data clause associated with delete operation must match its intent"
817  " or specify original clause this operation was decomposed from");
818  if (!getAccVar())
819  return emitError("must have device pointer");
820  // This op is the exit part of copyin and create - thus allow all modifiers
821  // allowed on either case.
822  if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
823  acc::DataClauseModifier::readonly |
824  acc::DataClauseModifier::always |
825  acc::DataClauseModifier::capture)))
826  return failure();
827  return success();
828 }
829 
830 //===----------------------------------------------------------------------===//
831 // DetachOp
832 //===----------------------------------------------------------------------===//
833 LogicalResult acc::DetachOp::verify() {
834  // Test for all clauses this operation can be decomposed from:
835  if (getDataClause() != acc::DataClause::acc_detach &&
836  getDataClause() != acc::DataClause::acc_attach)
837  return emitError(
838  "data clause associated with detach operation must match its intent"
839  " or specify original clause this operation was decomposed from");
840  if (!getAccVar())
841  return emitError("must have device pointer");
842  if (failed(checkNoModifier(*this)))
843  return failure();
844  return success();
845 }
846 
847 //===----------------------------------------------------------------------===//
848 // HostOp
849 //===----------------------------------------------------------------------===//
850 LogicalResult acc::UpdateHostOp::verify() {
851  // Test for all clauses this operation can be decomposed from:
852  if (getDataClause() != acc::DataClause::acc_update_host &&
853  getDataClause() != acc::DataClause::acc_update_self)
854  return emitError(
855  "data clause associated with host operation must match its intent"
856  " or specify original clause this operation was decomposed from");
857  if (!getVar() || !getAccVar())
858  return emitError("must have both host and device pointers");
859  if (failed(checkVarAndVarType(*this)))
860  return failure();
861  if (failed(checkVarAndAccVar(*this)))
862  return failure();
863  if (failed(checkNoModifier(*this)))
864  return failure();
865  return success();
866 }
867 
868 //===----------------------------------------------------------------------===//
869 // UpdateDeviceOp
870 //===----------------------------------------------------------------------===//
871 LogicalResult acc::UpdateDeviceOp::verify() {
872  // Test for all clauses this operation can be decomposed from:
873  if (getDataClause() != acc::DataClause::acc_update_device)
874  return emitError(
875  "data clause associated with device operation must match its intent"
876  " or specify original clause this operation was decomposed from");
877  if (failed(checkVarAndVarType(*this)))
878  return failure();
879  if (failed(checkVarAndAccVar(*this)))
880  return failure();
881  if (failed(checkNoModifier(*this)))
882  return failure();
883  return success();
884 }
885 
886 //===----------------------------------------------------------------------===//
887 // UseDeviceOp
888 //===----------------------------------------------------------------------===//
889 LogicalResult acc::UseDeviceOp::verify() {
890  // Test for all clauses this operation can be decomposed from:
891  if (getDataClause() != acc::DataClause::acc_use_device)
892  return emitError(
893  "data clause associated with use_device operation must match its intent"
894  " or specify original clause this operation was decomposed from");
895  if (failed(checkVarAndVarType(*this)))
896  return failure();
897  if (failed(checkVarAndAccVar(*this)))
898  return failure();
899  if (failed(checkNoModifier(*this)))
900  return failure();
901  return success();
902 }
903 
904 //===----------------------------------------------------------------------===//
905 // CacheOp
906 //===----------------------------------------------------------------------===//
907 LogicalResult acc::CacheOp::verify() {
908  // Test for all clauses this operation can be decomposed from:
909  if (getDataClause() != acc::DataClause::acc_cache &&
910  getDataClause() != acc::DataClause::acc_cache_readonly)
911  return emitError(
912  "data clause associated with cache operation must match its intent"
913  " or specify original clause this operation was decomposed from");
914  if (failed(checkVarAndVarType(*this)))
915  return failure();
916  if (failed(checkVarAndAccVar(*this)))
917  return failure();
918  if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly)))
919  return failure();
920  return success();
921 }
922 
923 bool acc::CacheOp::isCacheReadonly() {
924  return getDataClause() == acc::DataClause::acc_cache_readonly ||
925  acc::bitEnumContainsAny(getModifiers(),
926  acc::DataClauseModifier::readonly);
927 }
928 
929 template <typename StructureOp>
930 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
931  unsigned nRegions = 1) {
932 
933  SmallVector<Region *, 2> regions;
934  for (unsigned i = 0; i < nRegions; ++i)
935  regions.push_back(state.addRegion());
936 
937  for (Region *region : regions)
938  if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
939  return failure();
940 
941  return success();
942 }
943 
944 static bool isComputeOperation(Operation *op) {
945  return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
946 }
947 
948 namespace {
949 /// Pattern to remove operation without region that have constant false `ifCond`
950 /// and remove the condition from the operation if the `ifCond` is a true
951 /// constant.
952 template <typename OpTy>
953 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
955 
956  LogicalResult matchAndRewrite(OpTy op,
957  PatternRewriter &rewriter) const override {
958  // Early return if there is no condition.
959  Value ifCond = op.getIfCond();
960  if (!ifCond)
961  return failure();
962 
963  IntegerAttr constAttr;
964  if (!matchPattern(ifCond, m_Constant(&constAttr)))
965  return failure();
966  if (constAttr.getInt())
967  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
968  else
969  rewriter.eraseOp(op);
970 
971  return success();
972  }
973 };
974 
975 /// Replaces the given op with the contents of the given single-block region,
976 /// using the operands of the block terminator to replace operation results.
977 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
978  Region &region, ValueRange blockArgs = {}) {
979  assert(region.hasOneBlock() && "expected single-block region");
980  Block *block = &region.front();
981  Operation *terminator = block->getTerminator();
982  ValueRange results = terminator->getOperands();
983  rewriter.inlineBlockBefore(block, op, blockArgs);
984  rewriter.replaceOp(op, results);
985  rewriter.eraseOp(terminator);
986 }
987 
988 /// Pattern to remove operation with region that have constant false `ifCond`
989 /// and remove the condition from the operation if the `ifCond` is constant
990 /// true.
991 template <typename OpTy>
992 struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
994 
995  LogicalResult matchAndRewrite(OpTy op,
996  PatternRewriter &rewriter) const override {
997  // Early return if there is no condition.
998  Value ifCond = op.getIfCond();
999  if (!ifCond)
1000  return failure();
1001 
1002  IntegerAttr constAttr;
1003  if (!matchPattern(ifCond, m_Constant(&constAttr)))
1004  return failure();
1005  if (constAttr.getInt())
1006  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1007  else
1008  replaceOpWithRegion(rewriter, op, op.getRegion());
1009 
1010  return success();
1011  }
1012 };
1013 
1014 //===----------------------------------------------------------------------===//
1015 // Recipe Region Helpers
1016 //===----------------------------------------------------------------------===//
1017 
1018 /// Create and populate an init region for privatization recipes.
1019 /// Returns the init block on success, or nullptr on failure.
1020 /// Sets needsFree to indicate if the allocated memory requires deallocation.
1021 static std::unique_ptr<Block> createInitRegion(OpBuilder &builder, Location loc,
1022  Type varType, StringRef varName,
1023  ValueRange bounds,
1024  bool &needsFree) {
1025  // Create init block with arguments: original value + bounds
1026  SmallVector<Type> argTypes{varType};
1027  SmallVector<Location> argLocs{loc};
1028  for (Value bound : bounds) {
1029  argTypes.push_back(bound.getType());
1030  argLocs.push_back(loc);
1031  }
1032 
1033  auto initBlock = std::make_unique<Block>();
1034  initBlock->addArguments(argTypes, argLocs);
1035  builder.setInsertionPointToStart(initBlock.get());
1036 
1037  Value privatizedValue;
1038 
1039  // Get the block argument that represents the original variable
1040  Value blockArgVar = initBlock->getArgument(0);
1041 
1042  // Generate init region body based on variable type
1043  if (isa<MappableType>(varType)) {
1044  auto mappableTy = cast<MappableType>(varType);
1045  auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1046  privatizedValue = mappableTy.generatePrivateInit(
1047  builder, loc, typedVar, varName, bounds, {}, needsFree);
1048  if (!privatizedValue)
1049  return nullptr;
1050  } else {
1051  assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1052  auto pointerLikeTy = cast<PointerLikeType>(varType);
1053  // Use PointerLikeType's allocation API with the block argument
1054  privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1055  blockArgVar, needsFree);
1056  if (!privatizedValue)
1057  return nullptr;
1058  }
1059 
1060  // Add yield operation to init block
1061  acc::YieldOp::create(builder, loc, privatizedValue);
1062 
1063  return initBlock;
1064 }
1065 
1066 /// Create and populate a copy region for firstprivate recipes.
1067 /// Returns the copy block on success, or nullptr on failure.
1068 /// TODO: Handle MappableType - it does not yet have a copy API.
1069 static std::unique_ptr<Block> createCopyRegion(OpBuilder &builder, Location loc,
1070  Type varType,
1071  ValueRange bounds) {
1072  // Create copy block with arguments: original value + privatized value +
1073  // bounds
1074  SmallVector<Type> copyArgTypes{varType, varType};
1075  SmallVector<Location> copyArgLocs{loc, loc};
1076  for (Value bound : bounds) {
1077  copyArgTypes.push_back(bound.getType());
1078  copyArgLocs.push_back(loc);
1079  }
1080 
1081  auto copyBlock = std::make_unique<Block>();
1082  copyBlock->addArguments(copyArgTypes, copyArgLocs);
1083  builder.setInsertionPointToStart(copyBlock.get());
1084 
1085  bool isMappable = isa<MappableType>(varType);
1086  bool isPointerLike = isa<PointerLikeType>(varType);
1087  // TODO: Handle MappableType - it does not yet have a copy API.
1088  // Otherwise, for now just fallback to pointer-like behavior.
1089  if (isMappable && !isPointerLike)
1090  return nullptr;
1091 
1092  // Generate copy region body based on variable type
1093  if (isPointerLike) {
1094  auto pointerLikeTy = cast<PointerLikeType>(varType);
1095  Value originalArg = copyBlock->getArgument(0);
1096  Value privatizedArg = copyBlock->getArgument(1);
1097 
1098  // Generate copy operation using PointerLikeType interface
1099  if (!pointerLikeTy.genCopy(
1100  builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
1101  cast<TypedValue<PointerLikeType>>(originalArg), varType))
1102  return nullptr;
1103  }
1104 
1105  // Add terminator to copy block
1106  acc::TerminatorOp::create(builder, loc);
1107 
1108  return copyBlock;
1109 }
1110 
1111 /// Create and populate a destroy region for privatization recipes.
1112 /// Returns the destroy block on success, or nullptr if not needed.
1113 static std::unique_ptr<Block> createDestroyRegion(OpBuilder &builder,
1114  Location loc, Type varType,
1115  Value allocRes,
1116  ValueRange bounds) {
1117  // Create destroy block with arguments: original value + privatized value +
1118  // bounds
1119  SmallVector<Type> destroyArgTypes{varType, varType};
1120  SmallVector<Location> destroyArgLocs{loc, loc};
1121  for (Value bound : bounds) {
1122  destroyArgTypes.push_back(bound.getType());
1123  destroyArgLocs.push_back(loc);
1124  }
1125 
1126  auto destroyBlock = std::make_unique<Block>();
1127  destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
1128  builder.setInsertionPointToStart(destroyBlock.get());
1129 
1130  bool isMappable = isa<MappableType>(varType);
1131  bool isPointerLike = isa<PointerLikeType>(varType);
1132  // TODO: Handle MappableType - it does not yet have a deallocation API.
1133  // Otherwise, for now just fallback to pointer-like behavior.
1134  if (isMappable && !isPointerLike)
1135  return nullptr;
1136 
1137  assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1138  auto pointerLikeTy = cast<PointerLikeType>(varType);
1139  auto privatizedArg =
1140  cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
1141  // Pass allocRes to help determine the allocation type
1142  if (!pointerLikeTy.genFree(builder, loc, privatizedArg, allocRes, varType))
1143  return nullptr;
1144 
1145  acc::TerminatorOp::create(builder, loc);
1146 
1147  return destroyBlock;
1148 }
1149 
1150 } // namespace
1151 
1152 //===----------------------------------------------------------------------===//
1153 // PrivateRecipeOp
1154 //===----------------------------------------------------------------------===//
1155 
1156 static LogicalResult verifyInitLikeSingleArgRegion(
1157  Operation *op, Region &region, StringRef regionType, StringRef regionName,
1158  Type type, bool verifyYield, bool optional = false) {
1159  if (optional && region.empty())
1160  return success();
1161 
1162  if (region.empty())
1163  return op->emitOpError() << "expects non-empty " << regionName << " region";
1164  Block &firstBlock = region.front();
1165  if (firstBlock.getNumArguments() < 1 ||
1166  firstBlock.getArgument(0).getType() != type)
1167  return op->emitOpError() << "expects " << regionName
1168  << " region first "
1169  "argument of the "
1170  << regionType << " type";
1171 
1172  if (verifyYield) {
1173  for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
1174  if (yieldOp.getOperands().size() != 1 ||
1175  yieldOp.getOperands().getTypes()[0] != type)
1176  return op->emitOpError() << "expects " << regionName
1177  << " region to "
1178  "yield a value of the "
1179  << regionType << " type";
1180  }
1181  }
1182  return success();
1183 }
1184 
1185 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1186  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1187  "privatization", "init", getType(),
1188  /*verifyYield=*/false)))
1189  return failure();
1191  *this, getDestroyRegion(), "privatization", "destroy", getType(),
1192  /*verifyYield=*/false, /*optional=*/true)))
1193  return failure();
1194  return success();
1195 }
1196 
1197 std::optional<PrivateRecipeOp>
1198 PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1199  StringRef recipeName, Type varType,
1200  StringRef varName, ValueRange bounds) {
1201  // First, validate that we can handle this variable type
1202  bool isMappable = isa<MappableType>(varType);
1203  bool isPointerLike = isa<PointerLikeType>(varType);
1204 
1205  // Unsupported type
1206  if (!isMappable && !isPointerLike)
1207  return std::nullopt;
1208 
1209  // Create init and destroy blocks using shared helpers
1210  OpBuilder::InsertionGuard guard(builder);
1211 
1212  // Save the original insertion point for creating the recipe operation later
1213  auto originalInsertionPoint = builder.saveInsertionPoint();
1214 
1215  bool needsFree = false;
1216  auto initBlock =
1217  createInitRegion(builder, loc, varType, varName, bounds, needsFree);
1218  if (!initBlock)
1219  return std::nullopt;
1220 
1221  // Only create destroy region if the allocation needs deallocation
1222  std::unique_ptr<Block> destroyBlock;
1223  if (needsFree) {
1224  // Extract the allocated value from the init block's yield operation
1225  auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator());
1226  Value allocRes = yieldOp.getOperand(0);
1227 
1228  destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds);
1229  if (!destroyBlock)
1230  return std::nullopt;
1231  }
1232 
1233  // Now create the recipe operation at the original insertion point and attach
1234  // the blocks
1235  builder.restoreInsertionPoint(originalInsertionPoint);
1236  auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1237 
1238  // Move the blocks into the recipe's regions
1239  recipe.getInitRegion().push_back(initBlock.release());
1240  if (destroyBlock)
1241  recipe.getDestroyRegion().push_back(destroyBlock.release());
1242 
1243  return recipe;
1244 }
1245 
1246 //===----------------------------------------------------------------------===//
1247 // FirstprivateRecipeOp
1248 //===----------------------------------------------------------------------===//
1249 
1250 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1251  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1252  "privatization", "init", getType(),
1253  /*verifyYield=*/false)))
1254  return failure();
1255 
1256  if (getCopyRegion().empty())
1257  return emitOpError() << "expects non-empty copy region";
1258 
1259  Block &firstBlock = getCopyRegion().front();
1260  if (firstBlock.getNumArguments() < 2 ||
1261  firstBlock.getArgument(0).getType() != getType())
1262  return emitOpError() << "expects copy region with two arguments of the "
1263  "privatization type";
1264 
1265  if (getDestroyRegion().empty())
1266  return success();
1267 
1268  if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
1269  "privatization", "destroy",
1270  getType(), /*verifyYield=*/false)))
1271  return failure();
1272 
1273  return success();
1274 }
1275 
1276 std::optional<FirstprivateRecipeOp>
1277 FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1278  StringRef recipeName, Type varType,
1279  StringRef varName, ValueRange bounds) {
1280  // First, validate that we can handle this variable type
1281  bool isMappable = isa<MappableType>(varType);
1282  bool isPointerLike = isa<PointerLikeType>(varType);
1283 
1284  // Unsupported type
1285  if (!isMappable && !isPointerLike)
1286  return std::nullopt;
1287 
1288  // Create init, copy, and destroy blocks using shared helpers
1289  OpBuilder::InsertionGuard guard(builder);
1290 
1291  // Save the original insertion point for creating the recipe operation later
1292  auto originalInsertionPoint = builder.saveInsertionPoint();
1293 
1294  bool needsFree = false;
1295  auto initBlock =
1296  createInitRegion(builder, loc, varType, varName, bounds, needsFree);
1297  if (!initBlock)
1298  return std::nullopt;
1299 
1300  auto copyBlock = createCopyRegion(builder, loc, varType, bounds);
1301  if (!copyBlock)
1302  return std::nullopt;
1303 
1304  // Only create destroy region if the allocation needs deallocation
1305  std::unique_ptr<Block> destroyBlock;
1306  if (needsFree) {
1307  // Extract the allocated value from the init block's yield operation
1308  auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator());
1309  Value allocRes = yieldOp.getOperand(0);
1310 
1311  destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds);
1312  if (!destroyBlock)
1313  return std::nullopt;
1314  }
1315 
1316  // Now create the recipe operation at the original insertion point and attach
1317  // the blocks
1318  builder.restoreInsertionPoint(originalInsertionPoint);
1319  auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1320 
1321  // Move the blocks into the recipe's regions
1322  recipe.getInitRegion().push_back(initBlock.release());
1323  recipe.getCopyRegion().push_back(copyBlock.release());
1324  if (destroyBlock)
1325  recipe.getDestroyRegion().push_back(destroyBlock.release());
1326 
1327  return recipe;
1328 }
1329 
1330 //===----------------------------------------------------------------------===//
1331 // ReductionRecipeOp
1332 //===----------------------------------------------------------------------===//
1333 
1334 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1335  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
1336  "init", getType(),
1337  /*verifyYield=*/false)))
1338  return failure();
1339 
1340  if (getCombinerRegion().empty())
1341  return emitOpError() << "expects non-empty combiner region";
1342 
1343  Block &reductionBlock = getCombinerRegion().front();
1344  if (reductionBlock.getNumArguments() < 2 ||
1345  reductionBlock.getArgument(0).getType() != getType() ||
1346  reductionBlock.getArgument(1).getType() != getType())
1347  return emitOpError() << "expects combiner region with the first two "
1348  << "arguments of the reduction type";
1349 
1350  for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1351  if (yieldOp.getOperands().size() != 1 ||
1352  yieldOp.getOperands().getTypes()[0] != getType())
1353  return emitOpError() << "expects combiner region to yield a value "
1354  "of the reduction type";
1355  }
1356 
1357  return success();
1358 }
1359 
1360 //===----------------------------------------------------------------------===//
1361 // Custom parser and printer verifier for private clause
1362 //===----------------------------------------------------------------------===//
1363 
1364 static ParseResult parseSymOperandList(
1365  mlir::OpAsmParser &parser,
1367  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
1369  if (failed(parser.parseCommaSeparatedList([&]() {
1370  if (parser.parseAttribute(attributes.emplace_back()) ||
1371  parser.parseArrow() ||
1372  parser.parseOperand(operands.emplace_back()) ||
1373  parser.parseColonType(types.emplace_back()))
1374  return failure();
1375  return success();
1376  })))
1377  return failure();
1378  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1379  attributes.end());
1380  symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
1381  return success();
1382 }
1383 
1385  mlir::OperandRange operands,
1386  mlir::TypeRange types,
1387  std::optional<mlir::ArrayAttr> attributes) {
1388  llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
1389  p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
1390  << std::get<1>(it).getType();
1391  });
1392 }
1393 
1394 //===----------------------------------------------------------------------===//
1395 // ParallelOp
1396 //===----------------------------------------------------------------------===//
1397 
1398 /// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
1399 template <typename Op>
1400 static LogicalResult checkDataOperands(Op op,
1401  const mlir::ValueRange &operands) {
1402  for (mlir::Value operand : operands)
1403  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1404  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1405  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1406  operand.getDefiningOp()))
1407  return op.emitError(
1408  "expect data entry/exit operation or acc.getdeviceptr "
1409  "as defining op");
1410  return success();
1411 }
1412 
1413 template <typename Op>
1414 static LogicalResult
1415 checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
1416  mlir::OperandRange operands, llvm::StringRef operandName,
1417  llvm::StringRef symbolName, bool checkOperandType = true) {
1418  if (!operands.empty()) {
1419  if (!attributes || attributes->size() != operands.size())
1420  return op->emitOpError()
1421  << "expected as many " << symbolName << " symbol reference as "
1422  << operandName << " operands";
1423  } else {
1424  if (attributes)
1425  return op->emitOpError()
1426  << "unexpected " << symbolName << " symbol reference";
1427  return success();
1428  }
1429 
1431  for (auto args : llvm::zip(operands, *attributes)) {
1432  mlir::Value operand = std::get<0>(args);
1433 
1434  if (!set.insert(operand).second)
1435  return op->emitOpError()
1436  << operandName << " operand appears more than once";
1437 
1438  mlir::Type varType = operand.getType();
1439  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1440  auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
1441  if (!decl)
1442  return op->emitOpError()
1443  << "expected symbol reference " << symbolRef << " to point to a "
1444  << operandName << " declaration";
1445 
1446  if (checkOperandType && decl.getType() && decl.getType() != varType)
1447  return op->emitOpError() << "expected " << operandName << " (" << varType
1448  << ") to be the same type as " << operandName
1449  << " declaration (" << decl.getType() << ")";
1450  }
1451 
1452  return success();
1453 }
1454 
1455 unsigned ParallelOp::getNumDataOperands() {
1456  return getReductionOperands().size() + getPrivateOperands().size() +
1457  getFirstprivateOperands().size() + getDataClauseOperands().size();
1458 }
1459 
1460 Value ParallelOp::getDataOperand(unsigned i) {
1461  unsigned numOptional = getAsyncOperands().size();
1462  numOptional += getNumGangs().size();
1463  numOptional += getNumWorkers().size();
1464  numOptional += getVectorLength().size();
1465  numOptional += getIfCond() ? 1 : 0;
1466  numOptional += getSelfCond() ? 1 : 0;
1467  return getOperand(getWaitOperands().size() + numOptional + i);
1468 }
1469 
1470 template <typename Op>
1471 static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
1472  ArrayAttr deviceTypes,
1473  llvm::StringRef keyword) {
1474  if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1475  return op.emitOpError() << keyword << " operands count must match "
1476  << keyword << " device_type count";
1477  return success();
1478 }
1479 
1480 template <typename Op>
1482  Op op, OperandRange operands, DenseI32ArrayAttr segments,
1483  ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1484  std::size_t numOperandsInSegments = 0;
1485  std::size_t nbOfSegments = 0;
1486 
1487  if (segments) {
1488  for (auto segCount : segments.asArrayRef()) {
1489  if (maxInSegment != 0 && segCount > maxInSegment)
1490  return op.emitOpError() << keyword << " expects a maximum of "
1491  << maxInSegment << " values per segment";
1492  numOperandsInSegments += segCount;
1493  ++nbOfSegments;
1494  }
1495  }
1496 
1497  if ((numOperandsInSegments != operands.size()) ||
1498  (!deviceTypes && !operands.empty()))
1499  return op.emitOpError()
1500  << keyword << " operand count does not match count in segments";
1501  if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1502  return op.emitOpError()
1503  << keyword << " segment count does not match device_type count";
1504  return success();
1505 }
1506 
1507 LogicalResult acc::ParallelOp::verify() {
1508  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1509  *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
1510  "privatizations", /*checkOperandType=*/false)))
1511  return failure();
1512  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1513  *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
1514  "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
1515  return failure();
1516  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1517  *this, getReductionRecipes(), getReductionOperands(), "reduction",
1518  "reductions", false)))
1519  return failure();
1520 
1522  *this, getNumGangs(), getNumGangsSegmentsAttr(),
1523  getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1524  return failure();
1525 
1527  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1528  getWaitOperandsDeviceTypeAttr(), "wait")))
1529  return failure();
1530 
1531  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1532  getNumWorkersDeviceTypeAttr(),
1533  "num_workers")))
1534  return failure();
1535 
1536  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1537  getVectorLengthDeviceTypeAttr(),
1538  "vector_length")))
1539  return failure();
1540 
1542  getAsyncOperandsDeviceTypeAttr(),
1543  "async")))
1544  return failure();
1545 
1546  if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
1547  return failure();
1548 
1549  return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
1550 }
1551 
1552 static mlir::Value
1553 getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
1555  mlir::acc::DeviceType deviceType) {
1556  if (!arrayAttr)
1557  return {};
1558  if (auto pos = findSegment(*arrayAttr, deviceType))
1559  return range[*pos];
1560  return {};
1561 }
1562 
1563 bool acc::ParallelOp::hasAsyncOnly() {
1564  return hasAsyncOnly(mlir::acc::DeviceType::None);
1565 }
1566 
1567 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1568  return hasDeviceType(getAsyncOnly(), deviceType);
1569 }
1570 
1571 mlir::Value acc::ParallelOp::getAsyncValue() {
1572  return getAsyncValue(mlir::acc::DeviceType::None);
1573 }
1574 
1575 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1577  getAsyncOperands(), deviceType);
1578 }
1579 
1580 mlir::Value acc::ParallelOp::getNumWorkersValue() {
1581  return getNumWorkersValue(mlir::acc::DeviceType::None);
1582 }
1583 
1585 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1586  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1587  deviceType);
1588 }
1589 
1590 mlir::Value acc::ParallelOp::getVectorLengthValue() {
1591  return getVectorLengthValue(mlir::acc::DeviceType::None);
1592 }
1593 
1595 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1596  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1597  getVectorLength(), deviceType);
1598 }
1599 
1600 mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
1601  return getNumGangsValues(mlir::acc::DeviceType::None);
1602 }
1603 
1605 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1606  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1607  getNumGangsSegments(), deviceType);
1608 }
1609 
1610 bool acc::ParallelOp::hasWaitOnly() {
1611  return hasWaitOnly(mlir::acc::DeviceType::None);
1612 }
1613 
1614 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1615  return hasDeviceType(getWaitOnly(), deviceType);
1616 }
1617 
1618 mlir::Operation::operand_range ParallelOp::getWaitValues() {
1619  return getWaitValues(mlir::acc::DeviceType::None);
1620 }
1621 
1623 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1625  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1626  getHasWaitDevnum(), deviceType);
1627 }
1628 
1629 mlir::Value ParallelOp::getWaitDevnum() {
1630  return getWaitDevnum(mlir::acc::DeviceType::None);
1631 }
1632 
1633 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1634  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1635  getWaitOperandsSegments(), getHasWaitDevnum(),
1636  deviceType);
1637 }
1638 
/// Convenience builder taking only operand lists and the if/self conditions.
/// It forwards to the full-argument ParallelOp::build overload and passes
/// nullptr for every attribute argument (device types, segments, "only"
/// flags, recipes, default and combined attributes). Note that the parameter
/// order here differs from the order in which the values are forwarded.
void ParallelOp::build(mlir::OpBuilder &odsBuilder,
                       mlir::OperationState &odsState,
                       mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
                       mlir::ValueRange vectorLength,
                       mlir::ValueRange asyncOperands,
                       mlir::ValueRange waitOperands, mlir::Value ifCond,
                       mlir::Value selfCond, mlir::ValueRange reductionOperands,
                       mlir::ValueRange gangPrivateOperands,
                       mlir::ValueRange gangFirstPrivateOperands,
                       mlir::ValueRange dataClauseOperands) {

  // Arguments are positional: each operand list is followed by the null
  // attributes that belong to it.
  ParallelOp::build(
      odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
      /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
      /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
      /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
      /*numGangsDeviceType=*/nullptr, numWorkers,
      /*numWorkersDeviceType=*/nullptr, vectorLength,
      /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
      /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
      gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
      /*firstprivatizations=*/nullptr, dataClauseOperands,
      /*defaultAttr=*/nullptr, /*combined=*/nullptr);
}
1663 
1664 void acc::ParallelOp::addNumWorkersOperand(
1665  MLIRContext *context, mlir::Value newValue,
1666  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1667  setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1668  context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1669  getNumWorkersMutable()));
1670 }
1671 void acc::ParallelOp::addVectorLengthOperand(
1672  MLIRContext *context, mlir::Value newValue,
1673  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1674  setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1675  context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1676  getVectorLengthMutable()));
1677 }
1678 
1679 void acc::ParallelOp::addAsyncOnly(
1680  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1681  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1682  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1683 }
1684 
1685 void acc::ParallelOp::addAsyncOperand(
1686  MLIRContext *context, mlir::Value newValue,
1687  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1688  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1689  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1690  getAsyncOperandsMutable()));
1691 }
1692 
1693 void acc::ParallelOp::addNumGangsOperands(
1694  MLIRContext *context, mlir::ValueRange newValues,
1695  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1696  llvm::SmallVector<int32_t> segments;
1697  if (getNumGangsSegments())
1698  llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1699 
1700  setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1701  context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1702  getNumGangsMutable(), segments));
1703 
1704  setNumGangsSegments(segments);
1705 }
1706 void acc::ParallelOp::addWaitOnly(
1707  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1708  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1709  effectiveDeviceTypes));
1710 }
1711 void acc::ParallelOp::addWaitOperands(
1712  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
1713  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1714 
1715  llvm::SmallVector<int32_t> segments;
1716  if (getWaitOperandsSegments())
1717  llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1718 
1719  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1720  context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1721  getWaitOperandsMutable(), segments));
1722  setWaitOperandsSegments(segments);
1723 
1725  if (getHasWaitDevnumAttr())
1726  llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1727  hasDevnums.insert(
1728  hasDevnums.end(),
1729  std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
1730  mlir::BoolAttr::get(context, hasDevnum));
1731  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
1732 }
1733 
1734 void acc::ParallelOp::addPrivatization(MLIRContext *context,
1735  mlir::acc::PrivateOp op,
1736  mlir::acc::PrivateRecipeOp recipe) {
1737  getPrivateOperandsMutable().append(op.getResult());
1738 
1740 
1741  if (getPrivatizationRecipesAttr())
1742  llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
1743 
1744  recipes.push_back(
1745  mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1746  setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1747 }
1748 
1749 void acc::ParallelOp::addFirstPrivatization(
1750  MLIRContext *context, mlir::acc::FirstprivateOp op,
1751  mlir::acc::FirstprivateRecipeOp recipe) {
1752  getFirstprivateOperandsMutable().append(op.getResult());
1753 
1755 
1756  if (getFirstprivatizationRecipesAttr())
1757  llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
1758 
1759  recipes.push_back(
1760  mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1761  setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1762 }
1763 
1764 void acc::ParallelOp::addReduction(MLIRContext *context,
1765  mlir::acc::ReductionOp op,
1766  mlir::acc::ReductionRecipeOp recipe) {
1767  getReductionOperandsMutable().append(op.getResult());
1768 
1770 
1771  if (getReductionRecipesAttr())
1772  llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
1773 
1774  recipes.push_back(
1775  mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1776  setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1777 }
1778 
1779 static ParseResult parseNumGangs(
1780  mlir::OpAsmParser &parser,
1782  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1783  mlir::DenseI32ArrayAttr &segments) {
1786 
1787  do {
1788  if (failed(parser.parseLBrace()))
1789  return failure();
1790 
1791  int32_t crtOperandsSize = operands.size();
1792  if (failed(parser.parseCommaSeparatedList(
1794  if (parser.parseOperand(operands.emplace_back()) ||
1795  parser.parseColonType(types.emplace_back()))
1796  return failure();
1797  return success();
1798  })))
1799  return failure();
1800  seg.push_back(operands.size() - crtOperandsSize);
1801 
1802  if (failed(parser.parseRBrace()))
1803  return failure();
1804 
1805  if (succeeded(parser.parseOptionalLSquare())) {
1806  if (parser.parseAttribute(attributes.emplace_back()) ||
1807  parser.parseRSquare())
1808  return failure();
1809  } else {
1810  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1812  }
1813  } while (succeeded(parser.parseOptionalComma()));
1814 
1815  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1816  attributes.end());
1817  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1818  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1819 
1820  return success();
1821 }
1822 
1824  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1825  if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
1826  p << " [" << attr << "]";
1827 }
1828 
1830  mlir::OperandRange operands, mlir::TypeRange types,
1831  std::optional<mlir::ArrayAttr> deviceTypes,
1832  std::optional<mlir::DenseI32ArrayAttr> segments) {
1833  unsigned opIdx = 0;
1834  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1835  p << "{";
1836  llvm::interleaveComma(
1837  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1838  p << operands[opIdx] << " : " << operands[opIdx].getType();
1839  ++opIdx;
1840  });
1841  p << "}";
1842  printSingleDeviceType(p, it.value());
1843  });
1844 }
1845 
1847  mlir::OpAsmParser &parser,
1849  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1850  mlir::DenseI32ArrayAttr &segments) {
1853 
1854  do {
1855  if (failed(parser.parseLBrace()))
1856  return failure();
1857 
1858  int32_t crtOperandsSize = operands.size();
1859 
1860  if (failed(parser.parseCommaSeparatedList(
1862  if (parser.parseOperand(operands.emplace_back()) ||
1863  parser.parseColonType(types.emplace_back()))
1864  return failure();
1865  return success();
1866  })))
1867  return failure();
1868 
1869  seg.push_back(operands.size() - crtOperandsSize);
1870 
1871  if (failed(parser.parseRBrace()))
1872  return failure();
1873 
1874  if (succeeded(parser.parseOptionalLSquare())) {
1875  if (parser.parseAttribute(attributes.emplace_back()) ||
1876  parser.parseRSquare())
1877  return failure();
1878  } else {
1879  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1881  }
1882  } while (succeeded(parser.parseOptionalComma()));
1883 
1884  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1885  attributes.end());
1886  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1887  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1888 
1889  return success();
1890 }
1891 
1894  mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1895  std::optional<mlir::DenseI32ArrayAttr> segments) {
1896  unsigned opIdx = 0;
1897  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1898  p << "{";
1899  llvm::interleaveComma(
1900  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1901  p << operands[opIdx] << " : " << operands[opIdx].getType();
1902  ++opIdx;
1903  });
1904  p << "}";
1905  printSingleDeviceType(p, it.value());
1906  });
1907 }
1908 
1909 static ParseResult parseWaitClause(
1910  mlir::OpAsmParser &parser,
1912  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1913  mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
1914  mlir::ArrayAttr &keywordOnly) {
1915  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
1917 
1918  bool needCommaBeforeOperands = false;
1919 
1920  // Keyword only
1921  if (failed(parser.parseOptionalLParen())) {
1922  keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1924  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1925  return success();
1926  }
1927 
1928  // Parse keyword only attributes
1929  if (succeeded(parser.parseOptionalLSquare())) {
1930  if (failed(parser.parseCommaSeparatedList([&]() {
1931  if (parser.parseAttribute(keywordAttrs.emplace_back()))
1932  return failure();
1933  return success();
1934  })))
1935  return failure();
1936  if (parser.parseRSquare())
1937  return failure();
1938  needCommaBeforeOperands = true;
1939  }
1940 
1941  if (needCommaBeforeOperands && failed(parser.parseComma()))
1942  return failure();
1943 
1944  do {
1945  if (failed(parser.parseLBrace()))
1946  return failure();
1947 
1948  int32_t crtOperandsSize = operands.size();
1949 
1950  if (succeeded(parser.parseOptionalKeyword("devnum"))) {
1951  if (failed(parser.parseColon()))
1952  return failure();
1953  devnum.push_back(BoolAttr::get(parser.getContext(), true));
1954  } else {
1955  devnum.push_back(BoolAttr::get(parser.getContext(), false));
1956  }
1957 
1958  if (failed(parser.parseCommaSeparatedList(
1960  if (parser.parseOperand(operands.emplace_back()) ||
1961  parser.parseColonType(types.emplace_back()))
1962  return failure();
1963  return success();
1964  })))
1965  return failure();
1966 
1967  seg.push_back(operands.size() - crtOperandsSize);
1968 
1969  if (failed(parser.parseRBrace()))
1970  return failure();
1971 
1972  if (succeeded(parser.parseOptionalLSquare())) {
1973  if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
1974  parser.parseRSquare())
1975  return failure();
1976  } else {
1977  deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1979  }
1980  } while (succeeded(parser.parseOptionalComma()));
1981 
1982  if (failed(parser.parseRParen()))
1983  return failure();
1984 
1985  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
1986  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1987  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1988  hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
1989 
1990  return success();
1991 }
1992 
1993 static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
1994  if (!hasDeviceTypeValues(attrs))
1995  return false;
1996  if (attrs->size() != 1)
1997  return false;
1998  if (auto deviceTypeAttr =
1999  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2000  return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2001  return false;
2002 }
2003 
2005  mlir::OperandRange operands, mlir::TypeRange types,
2006  std::optional<mlir::ArrayAttr> deviceTypes,
2007  std::optional<mlir::DenseI32ArrayAttr> segments,
2008  std::optional<mlir::ArrayAttr> hasDevNum,
2009  std::optional<mlir::ArrayAttr> keywordOnly) {
2010 
2011  if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
2012  return;
2013 
2014  p << "(";
2015 
2016  printDeviceTypes(p, keywordOnly);
2017  if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
2018  p << ", ";
2019 
2020  if (hasDeviceTypeValues(deviceTypes)) {
2021  unsigned opIdx = 0;
2022  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2023  p << "{";
2024  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2025  if (boolAttr && boolAttr.getValue())
2026  p << "devnum: ";
2027  llvm::interleaveComma(
2028  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2029  p << operands[opIdx] << " : " << operands[opIdx].getType();
2030  ++opIdx;
2031  });
2032  p << "}";
2033  printSingleDeviceType(p, it.value());
2034  });
2035  }
2036 
2037  p << ")";
2038 }
2039 
2040 static ParseResult parseDeviceTypeOperands(
2041  mlir::OpAsmParser &parser,
2043  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
2045  if (failed(parser.parseCommaSeparatedList([&]() {
2046  if (parser.parseOperand(operands.emplace_back()) ||
2047  parser.parseColonType(types.emplace_back()))
2048  return failure();
2049  if (succeeded(parser.parseOptionalLSquare())) {
2050  if (parser.parseAttribute(attributes.emplace_back()) ||
2051  parser.parseRSquare())
2052  return failure();
2053  } else {
2054  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2055  parser.getContext(), mlir::acc::DeviceType::None));
2056  }
2057  return success();
2058  })))
2059  return failure();
2060  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2061  attributes.end());
2062  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2063  return success();
2064 }
2065 
2066 static void
2068  mlir::OperandRange operands, mlir::TypeRange types,
2069  std::optional<mlir::ArrayAttr> deviceTypes) {
2070  if (!hasDeviceTypeValues(deviceTypes))
2071  return;
2072  llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
2073  p << std::get<1>(it) << " : " << std::get<1>(it).getType();
2074  printSingleDeviceType(p, std::get<0>(it));
2075  });
2076 }
2077 
2079  mlir::OpAsmParser &parser,
2081  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2082  mlir::ArrayAttr &keywordOnlyDeviceType) {
2083 
2084  llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
2085  bool needCommaBeforeOperands = false;
2086 
2087  if (failed(parser.parseOptionalLParen())) {
2088  // Keyword only
2089  keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2091  keywordOnlyDeviceType =
2092  ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
2093  return success();
2094  }
2095 
2096  // Parse keyword only attributes
2097  if (succeeded(parser.parseOptionalLSquare())) {
2098  // Parse keyword only attributes
2099  if (failed(parser.parseCommaSeparatedList([&]() {
2100  if (parser.parseAttribute(
2101  keywordOnlyDeviceTypeAttributes.emplace_back()))
2102  return failure();
2103  return success();
2104  })))
2105  return failure();
2106  if (parser.parseRSquare())
2107  return failure();
2108  needCommaBeforeOperands = true;
2109  }
2110 
2111  if (needCommaBeforeOperands && failed(parser.parseComma()))
2112  return failure();
2113 
2115  if (failed(parser.parseCommaSeparatedList([&]() {
2116  if (parser.parseOperand(operands.emplace_back()) ||
2117  parser.parseColonType(types.emplace_back()))
2118  return failure();
2119  if (succeeded(parser.parseOptionalLSquare())) {
2120  if (parser.parseAttribute(attributes.emplace_back()) ||
2121  parser.parseRSquare())
2122  return failure();
2123  } else {
2124  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2125  parser.getContext(), mlir::acc::DeviceType::None));
2126  }
2127  return success();
2128  })))
2129  return failure();
2130 
2131  if (failed(parser.parseRParen()))
2132  return failure();
2133 
2134  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2135  attributes.end());
2136  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2137  return success();
2138 }
2139 
2142  mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2143  std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2144 
2145  if (operands.begin() == operands.end() &&
2146  hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
2147  return;
2148  }
2149 
2150  p << "(";
2151  printDeviceTypes(p, keywordOnlyDeviceTypes);
2152  if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
2153  hasDeviceTypeValues(deviceTypes))
2154  p << ", ";
2155  printDeviceTypeOperands(p, op, operands, types, deviceTypes);
2156  p << ")";
2157 }
2158 
2159 static ParseResult parseOperandWithKeywordOnly(
2160  mlir::OpAsmParser &parser,
2161  std::optional<OpAsmParser::UnresolvedOperand> &operand,
2162  mlir::Type &operandType, mlir::UnitAttr &attr) {
2163  // Keyword only
2164  if (failed(parser.parseOptionalLParen())) {
2165  attr = mlir::UnitAttr::get(parser.getContext());
2166  return success();
2167  }
2168 
2170  if (failed(parser.parseOperand(op)))
2171  return failure();
2172  operand = op;
2173  if (failed(parser.parseColon()))
2174  return failure();
2175  if (failed(parser.parseType(operandType)))
2176  return failure();
2177  if (failed(parser.parseRParen()))
2178  return failure();
2179 
2180  return success();
2181 }
2182 
2184  mlir::Operation *op,
2185  std::optional<mlir::Value> operand,
2186  mlir::Type operandType,
2187  mlir::UnitAttr attr) {
2188  if (attr)
2189  return;
2190 
2191  p << "(";
2192  p.printOperand(*operand);
2193  p << " : ";
2194  p.printType(operandType);
2195  p << ")";
2196 }
2197 
2198 static ParseResult parseOperandsWithKeywordOnly(
2199  mlir::OpAsmParser &parser,
2201  llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
2202  // Keyword only
2203  if (failed(parser.parseOptionalLParen())) {
2204  attr = mlir::UnitAttr::get(parser.getContext());
2205  return success();
2206  }
2207 
2208  if (failed(parser.parseCommaSeparatedList([&]() {
2209  if (parser.parseOperand(operands.emplace_back()))
2210  return failure();
2211  return success();
2212  })))
2213  return failure();
2214  if (failed(parser.parseColon()))
2215  return failure();
2216  if (failed(parser.parseCommaSeparatedList([&]() {
2217  if (parser.parseType(types.emplace_back()))
2218  return failure();
2219  return success();
2220  })))
2221  return failure();
2222  if (failed(parser.parseRParen()))
2223  return failure();
2224 
2225  return success();
2226 }
2227 
2229  mlir::Operation *op,
2230  mlir::OperandRange operands,
2231  mlir::TypeRange types,
2232  mlir::UnitAttr attr) {
2233  if (attr)
2234  return;
2235 
2236  p << "(";
2237  llvm::interleaveComma(operands, p, [&](auto it) { p << it; });
2238  p << " : ";
2239  llvm::interleaveComma(types, p, [&](auto it) { p << it; });
2240  p << ")";
2241 }
2242 
2243 static ParseResult
2245  mlir::acc::CombinedConstructsTypeAttr &attr) {
2246  if (succeeded(parser.parseOptionalKeyword("kernels"))) {
2248  parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2249  } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
2251  parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2252  } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
2254  parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2255  } else {
2256  parser.emitError(parser.getCurrentLocation(),
2257  "expected compute construct name");
2258  return failure();
2259  }
2260  return success();
2261 }
2262 
2263 static void
2265  mlir::acc::CombinedConstructsTypeAttr attr) {
2266  if (attr) {
2267  switch (attr.getValue()) {
2268  case mlir::acc::CombinedConstructsType::KernelsLoop:
2269  p << "kernels";
2270  break;
2271  case mlir::acc::CombinedConstructsType::ParallelLoop:
2272  p << "parallel";
2273  break;
2274  case mlir::acc::CombinedConstructsType::SerialLoop:
2275  p << "serial";
2276  break;
2277  };
2278  }
2279 }
2280 
2281 //===----------------------------------------------------------------------===//
2282 // SerialOp
2283 //===----------------------------------------------------------------------===//
2284 
2285 unsigned SerialOp::getNumDataOperands() {
2286  return getReductionOperands().size() + getPrivateOperands().size() +
2287  getFirstprivateOperands().size() + getDataClauseOperands().size();
2288 }
2289 
2290 Value SerialOp::getDataOperand(unsigned i) {
2291  unsigned numOptional = getAsyncOperands().size();
2292  numOptional += getIfCond() ? 1 : 0;
2293  numOptional += getSelfCond() ? 1 : 0;
2294  return getOperand(getWaitOperands().size() + numOptional + i);
2295 }
2296 
2297 bool acc::SerialOp::hasAsyncOnly() {
2298  return hasAsyncOnly(mlir::acc::DeviceType::None);
2299 }
2300 
2301 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2302  return hasDeviceType(getAsyncOnly(), deviceType);
2303 }
2304 
2305 mlir::Value acc::SerialOp::getAsyncValue() {
2306  return getAsyncValue(mlir::acc::DeviceType::None);
2307 }
2308 
2309 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2311  getAsyncOperands(), deviceType);
2312 }
2313 
2314 bool acc::SerialOp::hasWaitOnly() {
2315  return hasWaitOnly(mlir::acc::DeviceType::None);
2316 }
2317 
2318 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2319  return hasDeviceType(getWaitOnly(), deviceType);
2320 }
2321 
2322 mlir::Operation::operand_range SerialOp::getWaitValues() {
2323  return getWaitValues(mlir::acc::DeviceType::None);
2324 }
2325 
2327 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2329  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2330  getHasWaitDevnum(), deviceType);
2331 }
2332 
2333 mlir::Value SerialOp::getWaitDevnum() {
2334  return getWaitDevnum(mlir::acc::DeviceType::None);
2335 }
2336 
2337 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2338  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2339  getWaitOperandsSegments(), getHasWaitDevnum(),
2340  deviceType);
2341 }
2342 
2343 LogicalResult acc::SerialOp::verify() {
2344  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
2345  *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
2346  "privatizations", /*checkOperandType=*/false)))
2347  return failure();
2348  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
2349  *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
2350  "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
2351  return failure();
2352  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
2353  *this, getReductionRecipes(), getReductionOperands(), "reduction",
2354  "reductions", false)))
2355  return failure();
2356 
2358  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2359  getWaitOperandsDeviceTypeAttr(), "wait")))
2360  return failure();
2361 
2363  getAsyncOperandsDeviceTypeAttr(),
2364  "async")))
2365  return failure();
2366 
2367  if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
2368  return failure();
2369 
2370  return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
2371 }
2372 
2373 void acc::SerialOp::addAsyncOnly(
2374  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2375  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2376  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2377 }
2378 
2379 void acc::SerialOp::addAsyncOperand(
2380  MLIRContext *context, mlir::Value newValue,
2381  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2382  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2383  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2384  getAsyncOperandsMutable()));
2385 }
2386 
2387 void acc::SerialOp::addWaitOnly(
2388  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2389  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2390  effectiveDeviceTypes));
2391 }
2392 void acc::SerialOp::addWaitOperands(
2393  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2394  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2395 
2396  llvm::SmallVector<int32_t> segments;
2397  if (getWaitOperandsSegments())
2398  llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2399 
2400  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2401  context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2402  getWaitOperandsMutable(), segments));
2403  setWaitOperandsSegments(segments);
2404 
2406  if (getHasWaitDevnumAttr())
2407  llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2408  hasDevnums.insert(
2409  hasDevnums.end(),
2410  std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2411  mlir::BoolAttr::get(context, hasDevnum));
2412  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2413 }
2414 
2415 void acc::SerialOp::addPrivatization(MLIRContext *context,
2416  mlir::acc::PrivateOp op,
2417  mlir::acc::PrivateRecipeOp recipe) {
2418  getPrivateOperandsMutable().append(op.getResult());
2419 
2421 
2422  if (getPrivatizationRecipesAttr())
2423  llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
2424 
2425  recipes.push_back(
2426  mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2427  setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2428 }
2429 
2430 void acc::SerialOp::addFirstPrivatization(
2431  MLIRContext *context, mlir::acc::FirstprivateOp op,
2432  mlir::acc::FirstprivateRecipeOp recipe) {
2433  getFirstprivateOperandsMutable().append(op.getResult());
2434 
2436 
2437  if (getFirstprivatizationRecipesAttr())
2438  llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
2439 
2440  recipes.push_back(
2441  mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2442  setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2443 }
2444 
2445 void acc::SerialOp::addReduction(MLIRContext *context,
2446  mlir::acc::ReductionOp op,
2447  mlir::acc::ReductionRecipeOp recipe) {
2448  getReductionOperandsMutable().append(op.getResult());
2449 
2451 
2452  if (getReductionRecipesAttr())
2453  llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
2454 
2455  recipes.push_back(
2456  mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2457  setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2458 }
2459 
2460 //===----------------------------------------------------------------------===//
2461 // KernelsOp
2462 //===----------------------------------------------------------------------===//
2463 
2464 unsigned KernelsOp::getNumDataOperands() {
2465  return getDataClauseOperands().size();
2466 }
2467 
2468 Value KernelsOp::getDataOperand(unsigned i) {
2469  unsigned numOptional = getAsyncOperands().size();
2470  numOptional += getWaitOperands().size();
2471  numOptional += getNumGangs().size();
2472  numOptional += getNumWorkers().size();
2473  numOptional += getVectorLength().size();
2474  numOptional += getIfCond() ? 1 : 0;
2475  numOptional += getSelfCond() ? 1 : 0;
2476  return getOperand(numOptional + i);
2477 }
2478 
2479 bool acc::KernelsOp::hasAsyncOnly() {
2480  return hasAsyncOnly(mlir::acc::DeviceType::None);
2481 }
2482 
2483 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2484  return hasDeviceType(getAsyncOnly(), deviceType);
2485 }
2486 
2487 mlir::Value acc::KernelsOp::getAsyncValue() {
2488  return getAsyncValue(mlir::acc::DeviceType::None);
2489 }
2490 
2491 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2493  getAsyncOperands(), deviceType);
2494 }
2495 
2496 mlir::Value acc::KernelsOp::getNumWorkersValue() {
2497  return getNumWorkersValue(mlir::acc::DeviceType::None);
2498 }
2499 
2501 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2502  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
2503  deviceType);
2504 }
2505 
2506 mlir::Value acc::KernelsOp::getVectorLengthValue() {
2507  return getVectorLengthValue(mlir::acc::DeviceType::None);
2508 }
2509 
2511 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2512  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
2513  getVectorLength(), deviceType);
2514 }
2515 
2516 mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
2517  return getNumGangsValues(mlir::acc::DeviceType::None);
2518 }
2519 
2521 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2522  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2523  getNumGangsSegments(), deviceType);
2524 }
2525 
2526 bool acc::KernelsOp::hasWaitOnly() {
2527  return hasWaitOnly(mlir::acc::DeviceType::None);
2528 }
2529 
2530 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2531  return hasDeviceType(getWaitOnly(), deviceType);
2532 }
2533 
2534 mlir::Operation::operand_range KernelsOp::getWaitValues() {
2535  return getWaitValues(mlir::acc::DeviceType::None);
2536 }
2537 
2539 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2541  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2542  getHasWaitDevnum(), deviceType);
2543 }
2544 
2545 mlir::Value KernelsOp::getWaitDevnum() {
2546  return getWaitDevnum(mlir::acc::DeviceType::None);
2547 }
2548 
2549 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2550  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2551  getWaitOperandsSegments(), getHasWaitDevnum(),
2552  deviceType);
2553 }
2554 
2555 LogicalResult acc::KernelsOp::verify() {
2557  *this, getNumGangs(), getNumGangsSegmentsAttr(),
2558  getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
2559  return failure();
2560 
2562  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2563  getWaitOperandsDeviceTypeAttr(), "wait")))
2564  return failure();
2565 
2566  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
2567  getNumWorkersDeviceTypeAttr(),
2568  "num_workers")))
2569  return failure();
2570 
2571  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
2572  getVectorLengthDeviceTypeAttr(),
2573  "vector_length")))
2574  return failure();
2575 
2577  getAsyncOperandsDeviceTypeAttr(),
2578  "async")))
2579  return failure();
2580 
2581  if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
2582  return failure();
2583 
2584  return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
2585 }
2586 
2587 void acc::KernelsOp::addNumWorkersOperand(
2588  MLIRContext *context, mlir::Value newValue,
2589  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2590  setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2591  context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2592  getNumWorkersMutable()));
2593 }
2594 
2595 void acc::KernelsOp::addVectorLengthOperand(
2596  MLIRContext *context, mlir::Value newValue,
2597  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2598  setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2599  context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2600  getVectorLengthMutable()));
2601 }
2602 void acc::KernelsOp::addAsyncOnly(
2603  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2604  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2605  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2606 }
2607 
2608 void acc::KernelsOp::addAsyncOperand(
2609  MLIRContext *context, mlir::Value newValue,
2610  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2611  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2612  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2613  getAsyncOperandsMutable()));
2614 }
2615 
2616 void acc::KernelsOp::addNumGangsOperands(
2617  MLIRContext *context, mlir::ValueRange newValues,
2618  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2619  llvm::SmallVector<int32_t> segments;
2620  if (getNumGangsSegmentsAttr())
2621  llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2622 
2623  setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2624  context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2625  getNumGangsMutable(), segments));
2626 
2627  setNumGangsSegments(segments);
2628 }
2629 
2630 void acc::KernelsOp::addWaitOnly(
2631  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2632  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2633  effectiveDeviceTypes));
2634 }
2635 void acc::KernelsOp::addWaitOperands(
2636  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2637  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2638 
2639  llvm::SmallVector<int32_t> segments;
2640  if (getWaitOperandsSegments())
2641  llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2642 
2643  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2644  context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2645  getWaitOperandsMutable(), segments));
2646  setWaitOperandsSegments(segments);
2647 
2649  if (getHasWaitDevnumAttr())
2650  llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2651  hasDevnums.insert(
2652  hasDevnums.end(),
2653  std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2654  mlir::BoolAttr::get(context, hasDevnum));
2655  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2656 }
2657 
2658 //===----------------------------------------------------------------------===//
2659 // HostDataOp
2660 //===----------------------------------------------------------------------===//
2661 
2662 LogicalResult acc::HostDataOp::verify() {
2663  if (getDataClauseOperands().empty())
2664  return emitError("at least one operand must appear on the host_data "
2665  "operation");
2666 
2667  for (mlir::Value operand : getDataClauseOperands())
2668  if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2669  return emitError("expect data entry operation as defining op");
2670  return success();
2671 }
2672 
2673 void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2674  MLIRContext *context) {
2675  results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2676 }
2677 
2678 //===----------------------------------------------------------------------===//
2679 // LoopOp
2680 //===----------------------------------------------------------------------===//
2681 
2682 static ParseResult parseGangValue(
2683  OpAsmParser &parser, llvm::StringRef keyword,
2686  llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
2687  bool &needCommaBetweenValues, bool &newValue) {
2688  if (succeeded(parser.parseOptionalKeyword(keyword))) {
2689  if (parser.parseEqual())
2690  return failure();
2691  if (parser.parseOperand(operands.emplace_back()) ||
2692  parser.parseColonType(types.emplace_back()))
2693  return failure();
2694  attributes.push_back(gangArgType);
2695  needCommaBetweenValues = true;
2696  newValue = true;
2697  }
2698  return success();
2699 }
2700 
2701 static ParseResult parseGangClause(
2702  OpAsmParser &parser,
2704  llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
2705  mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
2706  mlir::ArrayAttr &gangOnlyDeviceType) {
2707  llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
2708  llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
2709  llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
2711  bool needCommaBetweenValues = false;
2712  bool needCommaBeforeOperands = false;
2713 
2714  if (failed(parser.parseOptionalLParen())) {
2715  // Gang only keyword
2716  gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2718  gangOnlyDeviceType =
2719  ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
2720  return success();
2721  }
2722 
2723  // Parse gang only attributes
2724  if (succeeded(parser.parseOptionalLSquare())) {
2725  // Parse gang only attributes
2726  if (failed(parser.parseCommaSeparatedList([&]() {
2727  if (parser.parseAttribute(
2728  gangOnlyDeviceTypeAttributes.emplace_back()))
2729  return failure();
2730  return success();
2731  })))
2732  return failure();
2733  if (parser.parseRSquare())
2734  return failure();
2735  needCommaBeforeOperands = true;
2736  }
2737 
2738  auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2739  mlir::acc::GangArgType::Num);
2740  auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2741  mlir::acc::GangArgType::Dim);
2742  auto argStatic = mlir::acc::GangArgTypeAttr::get(
2743  parser.getContext(), mlir::acc::GangArgType::Static);
2744 
2745  do {
2746  if (needCommaBeforeOperands) {
2747  needCommaBeforeOperands = false;
2748  continue;
2749  }
2750 
2751  if (failed(parser.parseLBrace()))
2752  return failure();
2753 
2754  int32_t crtOperandsSize = gangOperands.size();
2755  while (true) {
2756  bool newValue = false;
2757  bool needValue = false;
2758  if (needCommaBetweenValues) {
2759  if (succeeded(parser.parseOptionalComma()))
2760  needValue = true; // expect a new value after comma.
2761  else
2762  break;
2763  }
2764 
2765  if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
2766  gangOperands, gangOperandsType,
2767  gangArgTypeAttributes, argNum,
2768  needCommaBetweenValues, newValue)))
2769  return failure();
2770  if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
2771  gangOperands, gangOperandsType,
2772  gangArgTypeAttributes, argDim,
2773  needCommaBetweenValues, newValue)))
2774  return failure();
2775  if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
2776  gangOperands, gangOperandsType,
2777  gangArgTypeAttributes, argStatic,
2778  needCommaBetweenValues, newValue)))
2779  return failure();
2780 
2781  if (!newValue && needValue) {
2782  parser.emitError(parser.getCurrentLocation(),
2783  "new value expected after comma");
2784  return failure();
2785  }
2786 
2787  if (!newValue)
2788  break;
2789  }
2790 
2791  if (gangOperands.empty())
2792  return parser.emitError(
2793  parser.getCurrentLocation(),
2794  "expect at least one of num, dim or static values");
2795 
2796  if (failed(parser.parseRBrace()))
2797  return failure();
2798 
2799  if (succeeded(parser.parseOptionalLSquare())) {
2800  if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
2801  parser.parseRSquare())
2802  return failure();
2803  } else {
2804  deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2806  }
2807 
2808  seg.push_back(gangOperands.size() - crtOperandsSize);
2809 
2810  } while (succeeded(parser.parseOptionalComma()));
2811 
2812  if (failed(parser.parseRParen()))
2813  return failure();
2814 
2815  llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
2816  gangArgTypeAttributes.end());
2817  gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
2818  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
2819 
2821  gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
2822  gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
2823 
2824  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2825  return success();
2826 }
2827 
2829  mlir::OperandRange operands, mlir::TypeRange types,
2830  std::optional<mlir::ArrayAttr> gangArgTypes,
2831  std::optional<mlir::ArrayAttr> deviceTypes,
2832  std::optional<mlir::DenseI32ArrayAttr> segments,
2833  std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
2834 
2835  if (operands.begin() == operands.end() &&
2836  hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
2837  return;
2838  }
2839 
2840  p << "(";
2841 
2842  printDeviceTypes(p, gangOnlyDeviceTypes);
2843 
2844  if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
2845  hasDeviceTypeValues(deviceTypes))
2846  p << ", ";
2847 
2848  if (hasDeviceTypeValues(deviceTypes)) {
2849  unsigned opIdx = 0;
2850  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2851  p << "{";
2852  llvm::interleaveComma(
2853  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2854  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2855  (*gangArgTypes)[opIdx]);
2856  if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
2857  p << LoopOp::getGangNumKeyword();
2858  else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
2859  p << LoopOp::getGangDimKeyword();
2860  else if (gangArgTypeAttr.getValue() ==
2861  mlir::acc::GangArgType::Static)
2862  p << LoopOp::getGangStaticKeyword();
2863  p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
2864  ++opIdx;
2865  });
2866  p << "}";
2867  printSingleDeviceType(p, it.value());
2868  });
2869  }
2870  p << ")";
2871 }
2872 
2874  std::optional<mlir::ArrayAttr> segments,
2875  llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2876  if (!segments)
2877  return false;
2878  for (auto attr : *segments) {
2879  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2880  if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2881  return true;
2882  }
2883  return false;
2884 }
2885 
2886 /// Check for duplicates in the DeviceType array attribute.
2887 LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
2888  llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2889  if (!deviceTypes)
2890  return success();
2891  for (auto attr : deviceTypes) {
2892  auto deviceTypeAttr =
2893  mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
2894  if (!deviceTypeAttr)
2895  return failure();
2896  if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
2897  return failure();
2898  }
2899  return success();
2900 }
2901 
2902 LogicalResult acc::LoopOp::verify() {
2903  if (getUpperbound().size() != getStep().size())
2904  return emitError() << "number of upperbounds expected to be the same as "
2905  "number of steps";
2906 
2907  if (getUpperbound().size() != getLowerbound().size())
2908  return emitError() << "number of upperbounds expected to be the same as "
2909  "number of lowerbounds";
2910 
2911  if (!getUpperbound().empty() && getInclusiveUpperbound() &&
2912  (getUpperbound().size() != getInclusiveUpperbound()->size()))
2913  return emitError() << "inclusiveUpperbound size is expected to be the same"
2914  << " as upperbound size";
2915 
2916  // Check collapse
2917  if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
2918  return emitOpError() << "collapse device_type attr must be define when"
2919  << " collapse attr is present";
2920 
2921  if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
2922  getCollapseAttr().getValue().size() !=
2923  getCollapseDeviceTypeAttr().getValue().size())
2924  return emitOpError() << "collapse attribute count must match collapse"
2925  << " device_type count";
2926  if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
2927  return emitOpError()
2928  << "duplicate device_type found in collapseDeviceType attribute";
2929 
2930  // Check gang
2931  if (!getGangOperands().empty()) {
2932  if (!getGangOperandsArgType())
2933  return emitOpError() << "gangOperandsArgType attribute must be defined"
2934  << " when gang operands are present";
2935 
2936  if (getGangOperands().size() !=
2937  getGangOperandsArgTypeAttr().getValue().size())
2938  return emitOpError() << "gangOperandsArgType attribute count must match"
2939  << " gangOperands count";
2940  }
2941  if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
2942  return emitOpError() << "duplicate device_type found in gang attribute";
2943 
2945  *this, getGangOperands(), getGangOperandsSegmentsAttr(),
2946  getGangOperandsDeviceTypeAttr(), "gang")))
2947  return failure();
2948 
2949  // Check worker
2950  if (failed(checkDeviceTypes(getWorkerAttr())))
2951  return emitOpError() << "duplicate device_type found in worker attribute";
2952  if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
2953  return emitOpError() << "duplicate device_type found in "
2954  "workerNumOperandsDeviceType attribute";
2955  if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
2956  getWorkerNumOperandsDeviceTypeAttr(),
2957  "worker")))
2958  return failure();
2959 
2960  // Check vector
2961  if (failed(checkDeviceTypes(getVectorAttr())))
2962  return emitOpError() << "duplicate device_type found in vector attribute";
2963  if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
2964  return emitOpError() << "duplicate device_type found in "
2965  "vectorOperandsDeviceType attribute";
2966  if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
2967  getVectorOperandsDeviceTypeAttr(),
2968  "vector")))
2969  return failure();
2970 
2972  *this, getTileOperands(), getTileOperandsSegmentsAttr(),
2973  getTileOperandsDeviceTypeAttr(), "tile")))
2974  return failure();
2975 
2976  // auto, independent and seq attribute are mutually exclusive.
2977  llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
2978  if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
2979  hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
2980  hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
2981  return emitError() << "only one of auto, independent, seq can be present "
2982  "at the same time";
2983  }
2984 
2985  // Check that at least one of auto, independent, or seq is present
2986  // for the device-independent default clauses.
2987  auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
2988  return attr.getValue() == mlir::acc::DeviceType::None;
2989  };
2990  bool hasDefaultSeq =
2991  getSeqAttr()
2992  ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2993  hasDeviceNone)
2994  : false;
2995  bool hasDefaultIndependent =
2996  getIndependentAttr()
2997  ? llvm::any_of(
2998  getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2999  hasDeviceNone)
3000  : false;
3001  bool hasDefaultAuto =
3002  getAuto_Attr()
3003  ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3004  hasDeviceNone)
3005  : false;
3006  if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3007  return emitError()
3008  << "at least one of auto, independent, seq must be present";
3009  }
3010 
3011  // Gang, worker and vector are incompatible with seq.
3012  if (getSeqAttr()) {
3013  for (auto attr : getSeqAttr()) {
3014  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3015  if (hasVector(deviceTypeAttr.getValue()) ||
3016  getVectorValue(deviceTypeAttr.getValue()) ||
3017  hasWorker(deviceTypeAttr.getValue()) ||
3018  getWorkerValue(deviceTypeAttr.getValue()) ||
3019  hasGang(deviceTypeAttr.getValue()) ||
3020  getGangValue(mlir::acc::GangArgType::Num,
3021  deviceTypeAttr.getValue()) ||
3022  getGangValue(mlir::acc::GangArgType::Dim,
3023  deviceTypeAttr.getValue()) ||
3024  getGangValue(mlir::acc::GangArgType::Static,
3025  deviceTypeAttr.getValue()))
3026  return emitError() << "gang, worker or vector cannot appear with seq";
3027  }
3028  }
3029 
3030  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
3031  *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
3032  "privatizations", false)))
3033  return failure();
3034 
3035  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
3036  *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
3037  "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
3038  return failure();
3039 
3040  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
3041  *this, getReductionRecipes(), getReductionOperands(), "reduction",
3042  "reductions", false)))
3043  return failure();
3044 
3045  if (getCombined().has_value() &&
3046  (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3047  getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3048  getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3049  return emitError("unexpected combined constructs attribute");
3050  }
3051 
3052  // Check non-empty body().
3053  if (getRegion().empty())
3054  return emitError("expected non-empty body.");
3055 
3056  // When it is container-like - it is expected to hold a loop-like operation.
3057  if (isContainerLike()) {
3058  // Obtain the maximum collapse count - we use this to check that there
3059  // are enough loops contained.
3060  uint64_t collapseCount = getCollapseValue().value_or(1);
3061  if (getCollapseAttr()) {
3062  for (auto collapseEntry : getCollapseAttr()) {
3063  auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3064  if (intAttr.getValue().getZExtValue() > collapseCount)
3065  collapseCount = intAttr.getValue().getZExtValue();
3066  }
3067  }
3068 
3069  // We want to check that we find enough loop-like operations inside.
3070  // PreOrder walk allows us to walk in a breadth-first manner at each nesting
3071  // level.
3072  mlir::Operation *expectedParent = this->getOperation();
3073  bool foundSibling = false;
3074  getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
3075  if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3076  // This effectively checks that we are not looking at a sibling loop.
3077  if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3078  expectedParent) {
3079  foundSibling = true;
3080  return mlir::WalkResult::interrupt();
3081  }
3082 
3083  collapseCount--;
3084  expectedParent = op;
3085  }
3086  // We found enough contained loops.
3087  if (collapseCount == 0)
3088  return mlir::WalkResult::interrupt();
3089  return mlir::WalkResult::advance();
3090  });
3091 
3092  if (foundSibling)
3093  return emitError("found sibling loops inside container-like acc.loop");
3094  if (collapseCount != 0)
3095  return emitError("failed to find enough loop-like operations inside "
3096  "container-like acc.loop");
3097  }
3098 
3099  return success();
3100 }
3101 
3102 unsigned LoopOp::getNumDataOperands() {
3103  return getReductionOperands().size() + getPrivateOperands().size() +
3104  getFirstprivateOperands().size();
3105 }
3106 
3107 Value LoopOp::getDataOperand(unsigned i) {
3108  unsigned numOptional =
3109  getLowerbound().size() + getUpperbound().size() + getStep().size();
3110  numOptional += getGangOperands().size();
3111  numOptional += getVectorOperands().size();
3112  numOptional += getWorkerNumOperands().size();
3113  numOptional += getTileOperands().size();
3114  numOptional += getCacheOperands().size();
3115  return getOperand(numOptional + i);
3116 }
3117 
3118 bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
3119 
3120 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3121  return hasDeviceType(getAuto_(), deviceType);
3122 }
3123 
3124 bool LoopOp::hasIndependent() {
3125  return hasIndependent(mlir::acc::DeviceType::None);
3126 }
3127 
3128 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3129  return hasDeviceType(getIndependent(), deviceType);
3130 }
3131 
3132 bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
3133 
3134 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3135  return hasDeviceType(getSeq(), deviceType);
3136 }
3137 
3138 mlir::Value LoopOp::getVectorValue() {
3139  return getVectorValue(mlir::acc::DeviceType::None);
3140 }
3141 
3142 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3143  return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
3144  getVectorOperands(), deviceType);
3145 }
3146 
3147 bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
3148 
3149 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3150  return hasDeviceType(getVector(), deviceType);
3151 }
3152 
3153 mlir::Value LoopOp::getWorkerValue() {
3154  return getWorkerValue(mlir::acc::DeviceType::None);
3155 }
3156 
3157 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3158  return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
3159  getWorkerNumOperands(), deviceType);
3160 }
3161 
3162 bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
3163 
3164 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3165  return hasDeviceType(getWorker(), deviceType);
3166 }
3167 
3168 mlir::Operation::operand_range LoopOp::getTileValues() {
3169  return getTileValues(mlir::acc::DeviceType::None);
3170 }
3171 
3173 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3174  return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
3175  getTileOperandsSegments(), deviceType);
3176 }
3177 
3178 std::optional<int64_t> LoopOp::getCollapseValue() {
3179  return getCollapseValue(mlir::acc::DeviceType::None);
3180 }
3181 
3182 std::optional<int64_t>
3183 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3184  if (!getCollapseAttr())
3185  return std::nullopt;
3186  if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3187  auto intAttr =
3188  mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3189  return intAttr.getValue().getZExtValue();
3190  }
3191  return std::nullopt;
3192 }
3193 
3194 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3195  return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3196 }
3197 
3198 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3199  mlir::acc::DeviceType deviceType) {
3200  if (getGangOperands().empty())
3201  return {};
3202  if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
3203  int32_t nbOperandsBefore = 0;
3204  for (unsigned i = 0; i < *pos; ++i)
3205  nbOperandsBefore += (*getGangOperandsSegments())[i];
3207  getGangOperands()
3208  .drop_front(nbOperandsBefore)
3209  .take_front((*getGangOperandsSegments())[*pos]);
3210 
3211  int32_t argTypeIdx = nbOperandsBefore;
3212  for (auto value : values) {
3213  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3214  (*getGangOperandsArgType())[argTypeIdx]);
3215  if (gangArgTypeAttr.getValue() == gangArgType)
3216  return value;
3217  ++argTypeIdx;
3218  }
3219  }
3220  return {};
3221 }
3222 
3223 bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
3224 
3225 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3226  return hasDeviceType(getGang(), deviceType);
3227 }
3228 
3229 llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
3230  return {&getRegion()};
3231 }
3232 
3233 /// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
3234 /// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
3235 /// `(` ssa-id-and-type-list `)`
3236 /// region
3237 ParseResult
3240  SmallVectorImpl<Type> &lowerboundType,
3242  SmallVectorImpl<Type> &upperboundType,
3244  SmallVectorImpl<Type> &stepType) {
3245 
3246  SmallVector<OpAsmParser::Argument> inductionVars;
3247  if (succeeded(
3248  parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
3249  if (parser.parseLParen() ||
3250  parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
3251  /*allowType=*/true) ||
3252  parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
3253  parser.parseOperandList(lowerbound, inductionVars.size(),
3255  parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
3256  parser.parseKeyword("to") || parser.parseLParen() ||
3257  parser.parseOperandList(upperbound, inductionVars.size(),
3259  parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
3260  parser.parseKeyword("step") || parser.parseLParen() ||
3261  parser.parseOperandList(step, inductionVars.size(),
3263  parser.parseColonTypeList(stepType) || parser.parseRParen())
3264  return failure();
3265  }
3266  return parser.parseRegion(region, inductionVars);
3267 }
3268 
3270  ValueRange lowerbound, TypeRange lowerboundType,
3271  ValueRange upperbound, TypeRange upperboundType,
3272  ValueRange steps, TypeRange stepType) {
3273  ValueRange regionArgs = region.front().getArguments();
3274  if (!regionArgs.empty()) {
3275  p << acc::LoopOp::getControlKeyword() << "(";
3276  llvm::interleaveComma(regionArgs, p,
3277  [&p](Value v) { p << v << " : " << v.getType(); });
3278  p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
3279  << upperbound << " : " << upperboundType << ") " << " step (" << steps
3280  << " : " << stepType << ") ";
3281  }
3282  p.printRegion(region, /*printEntryBlockArgs=*/false);
3283 }
3284 
3285 void acc::LoopOp::addSeq(MLIRContext *context,
3286  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3287  setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3288  effectiveDeviceTypes));
3289 }
3290 
3291 void acc::LoopOp::addIndependent(
3292  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3293  setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3294  context, getIndependentAttr(), effectiveDeviceTypes));
3295 }
3296 
3297 void acc::LoopOp::addAuto(MLIRContext *context,
3298  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3299  setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3300  effectiveDeviceTypes));
3301 }
3302 
3303 void acc::LoopOp::setCollapseForDeviceTypes(
3304  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3305  llvm::APInt value) {
3307  llvm::SmallVector<mlir::Attribute> newDeviceTypes;
3308 
3309  assert((getCollapseAttr() == nullptr) ==
3310  (getCollapseDeviceTypeAttr() == nullptr));
3311  assert(value.getBitWidth() == 64);
3312 
3313  if (getCollapseAttr()) {
3314  for (const auto &existing :
3315  llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3316  newValues.push_back(std::get<0>(existing));
3317  newDeviceTypes.push_back(std::get<1>(existing));
3318  }
3319  }
3320 
3321  if (effectiveDeviceTypes.empty()) {
3322  // If the effective device-types list is empty, this is before there are any
3323  // being applied by device_type, so this should be added as a 'none'.
3324  newValues.push_back(
3325  mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3326  newDeviceTypes.push_back(
3328  } else {
3329  for (DeviceType dt : effectiveDeviceTypes) {
3330  newValues.push_back(
3331  mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3332  newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3333  }
3334  }
3335 
3336  setCollapseAttr(ArrayAttr::get(context, newValues));
3337  setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3338 }
3339 
3340 void acc::LoopOp::setTileForDeviceTypes(
3341  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3342  ValueRange values) {
3343  llvm::SmallVector<int32_t> segments;
3344  if (getTileOperandsSegments())
3345  llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3346 
3347  setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3348  context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3349  getTileOperandsMutable(), segments));
3350 
3351  setTileOperandsSegments(segments);
3352 }
3353 
3354 void acc::LoopOp::addVectorOperand(
3355  MLIRContext *context, mlir::Value newValue,
3356  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3357  setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3358  context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3359  newValue, getVectorOperandsMutable()));
3360 }
3361 
3362 void acc::LoopOp::addEmptyVector(
3363  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3364  setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3365  effectiveDeviceTypes));
3366 }
3367 
3368 void acc::LoopOp::addWorkerNumOperand(
3369  MLIRContext *context, mlir::Value newValue,
3370  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3371  setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3372  context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3373  newValue, getWorkerNumOperandsMutable()));
3374 }
3375 
3376 void acc::LoopOp::addEmptyWorker(
3377  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3378  setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3379  effectiveDeviceTypes));
3380 }
3381 
3382 void acc::LoopOp::addEmptyGang(
3383  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3384  setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3385  effectiveDeviceTypes));
3386 }
3387 
3388 bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3389  auto hasDevice = [=](DeviceTypeAttr attr) -> bool {
3390  return attr.getValue() == dt;
3391  };
3392  auto testFromArr = [=](ArrayAttr arr) -> bool {
3393  return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3394  };
3395 
3396  if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3397  return true;
3398  if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3399  return true;
3400  if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3401  return true;
3402 
3403  return false;
3404 }
3405 
3406 bool acc::LoopOp::hasDefaultGangWorkerVector() {
3407  return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3408  hasGang() || getGangValue(GangArgType::Num) ||
3409  getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3410 }
3411 
3412 acc::LoopParMode
3413 acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3414  if (hasSeq(deviceType))
3415  return LoopParMode::loop_seq;
3416  if (hasAuto(deviceType))
3417  return LoopParMode::loop_auto;
3418  if (hasIndependent(deviceType))
3419  return LoopParMode::loop_independent;
3420  if (hasSeq())
3421  return LoopParMode::loop_seq;
3422  if (hasAuto())
3423  return LoopParMode::loop_auto;
3424  assert(hasIndependent() &&
3425  "loop must have default auto, seq, or independent");
3426  return LoopParMode::loop_independent;
3427 }
3428 
3429 void acc::LoopOp::addGangOperands(
3430  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3431  llvm::ArrayRef<GangArgType> argTypes, mlir::ValueRange values) {
3432  llvm::SmallVector<int32_t> segments;
3433  if (std::optional<ArrayRef<int32_t>> existingSegments =
3434  getGangOperandsSegments())
3435  llvm::copy(*existingSegments, std::back_inserter(segments));
3436 
3437  unsigned beforeCount = segments.size();
3438 
3439  setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3440  context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3441  getGangOperandsMutable(), segments));
3442 
3443  setGangOperandsSegments(segments);
3444 
3445  // This is a bit of extra work to make sure we update the 'types' correctly by
3446  // adding to the types collection the correct number of times. We could
3447  // potentially add something similar to the
3448  // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
3449  // excessive for a one-off case.
3450  unsigned numAdded = segments.size() - beforeCount;
3451 
3452  if (numAdded > 0) {
3454  if (getGangOperandsArgTypeAttr())
3455  llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3456 
3457  for (auto i : llvm::index_range(0u, numAdded)) {
3458  llvm::transform(argTypes, std::back_inserter(gangTypes),
3459  [=](mlir::acc::GangArgType gangTy) {
3460  return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3461  });
3462  (void)i;
3463  }
3464 
3465  setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3466  }
3467 }
3468 
3469 void acc::LoopOp::addPrivatization(MLIRContext *context,
3470  mlir::acc::PrivateOp op,
3471  mlir::acc::PrivateRecipeOp recipe) {
3472  getPrivateOperandsMutable().append(op.getResult());
3473 
3475 
3476  if (getPrivatizationRecipesAttr())
3477  llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
3478 
3479  recipes.push_back(
3480  mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
3481  setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
3482 }
3483 
3484 void acc::LoopOp::addFirstPrivatization(
3485  MLIRContext *context, mlir::acc::FirstprivateOp op,
3486  mlir::acc::FirstprivateRecipeOp recipe) {
3487  getFirstprivateOperandsMutable().append(op.getResult());
3488 
3490 
3491  if (getFirstprivatizationRecipesAttr())
3492  llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
3493 
3494  recipes.push_back(
3495  mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
3496  setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
3497 }
3498 
3499 void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
3500  mlir::acc::ReductionRecipeOp recipe) {
3501  getReductionOperandsMutable().append(op.getResult());
3502 
3504 
3505  if (getReductionRecipesAttr())
3506  llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
3507 
3508  recipes.push_back(
3509  mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
3510  setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
3511 }
3512 
3513 //===----------------------------------------------------------------------===//
3514 // DataOp
3515 //===----------------------------------------------------------------------===//
3516 
3517 LogicalResult acc::DataOp::verify() {
3518  // 2.6.5. Data Construct restriction
3519  // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
3520  // attach, or default clause must appear on a data construct.
3521  if (getOperands().empty() && !getDefaultAttr())
3522  return emitError("at least one operand or the default attribute "
3523  "must appear on the data operation");
3524 
3525  for (mlir::Value operand : getDataClauseOperands())
3526  if (isa<BlockArgument>(operand) ||
3527  !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3528  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3529  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3530  operand.getDefiningOp()))
3531  return emitError("expect data entry/exit operation or acc.getdeviceptr "
3532  "as defining op");
3533 
3534  if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this)))
3535  return failure();
3536 
3537  return success();
3538 }
3539 
3540 unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
3541 
3542 Value DataOp::getDataOperand(unsigned i) {
3543  unsigned numOptional = getIfCond() ? 1 : 0;
3544  numOptional += getAsyncOperands().size() ? 1 : 0;
3545  numOptional += getWaitOperands().size();
3546  return getOperand(numOptional + i);
3547 }
3548 
3549 bool acc::DataOp::hasAsyncOnly() {
3550  return hasAsyncOnly(mlir::acc::DeviceType::None);
3551 }
3552 
3553 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3554  return hasDeviceType(getAsyncOnly(), deviceType);
3555 }
3556 
3557 mlir::Value DataOp::getAsyncValue() {
3558  return getAsyncValue(mlir::acc::DeviceType::None);
3559 }
3560 
3561 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3563  getAsyncOperands(), deviceType);
3564 }
3565 
3566 bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
3567 
3568 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3569  return hasDeviceType(getWaitOnly(), deviceType);
3570 }
3571 
3572 mlir::Operation::operand_range DataOp::getWaitValues() {
3573  return getWaitValues(mlir::acc::DeviceType::None);
3574 }
3575 
3577 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3579  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3580  getHasWaitDevnum(), deviceType);
3581 }
3582 
3583 mlir::Value DataOp::getWaitDevnum() {
3584  return getWaitDevnum(mlir::acc::DeviceType::None);
3585 }
3586 
3587 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3588  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3589  getWaitOperandsSegments(), getHasWaitDevnum(),
3590  deviceType);
3591 }
3592 
3593 void acc::DataOp::addAsyncOnly(
3594  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3595  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3596  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3597 }
3598 
3599 void acc::DataOp::addAsyncOperand(
3600  MLIRContext *context, mlir::Value newValue,
3601  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3602  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3603  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3604  getAsyncOperandsMutable()));
3605 }
3606 
3607 void acc::DataOp::addWaitOnly(MLIRContext *context,
3608  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3609  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3610  effectiveDeviceTypes));
3611 }
3612 
3613 void acc::DataOp::addWaitOperands(
3614  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3615  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3616 
3617  llvm::SmallVector<int32_t> segments;
3618  if (getWaitOperandsSegments())
3619  llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3620 
3621  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3622  context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3623  getWaitOperandsMutable(), segments));
3624  setWaitOperandsSegments(segments);
3625 
3627  if (getHasWaitDevnumAttr())
3628  llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3629  hasDevnums.insert(
3630  hasDevnums.end(),
3631  std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
3632  mlir::BoolAttr::get(context, hasDevnum));
3633  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3634 }
3635 
3636 //===----------------------------------------------------------------------===//
3637 // ExitDataOp
3638 //===----------------------------------------------------------------------===//
3639 
3640 LogicalResult acc::ExitDataOp::verify() {
3641  // 2.6.6. Data Exit Directive restriction
3642  // At least one copyout, delete, or detach clause must appear on an exit data
3643  // directive.
3644  if (getDataClauseOperands().empty())
3645  return emitError("at least one operand must be present in dataOperands on "
3646  "the exit data operation");
3647 
3648  // The async attribute represent the async clause without value. Therefore the
3649  // attribute and operand cannot appear at the same time.
3650  if (getAsyncOperand() && getAsync())
3651  return emitError("async attribute cannot appear with asyncOperand");
3652 
3653  // The wait attribute represent the wait clause without values. Therefore the
3654  // attribute and operands cannot appear at the same time.
3655  if (!getWaitOperands().empty() && getWait())
3656  return emitError("wait attribute cannot appear with waitOperands");
3657 
3658  if (getWaitDevnum() && getWaitOperands().empty())
3659  return emitError("wait_devnum cannot appear without waitOperands");
3660 
3661  return success();
3662 }
3663 
3664 unsigned ExitDataOp::getNumDataOperands() {
3665  return getDataClauseOperands().size();
3666 }
3667 
3668 Value ExitDataOp::getDataOperand(unsigned i) {
3669  unsigned numOptional = getIfCond() ? 1 : 0;
3670  numOptional += getAsyncOperand() ? 1 : 0;
3671  numOptional += getWaitDevnum() ? 1 : 0;
3672  return getOperand(getWaitOperands().size() + numOptional + i);
3673 }
3674 
3675 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3676  MLIRContext *context) {
3677  results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
3678 }
3679 
3680 void ExitDataOp::addAsyncOnly(MLIRContext *context,
3681  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3682  assert(effectiveDeviceTypes.empty());
3683  assert(!getAsyncAttr());
3684  assert(!getAsyncOperand());
3685 
3686  setAsyncAttr(mlir::UnitAttr::get(context));
3687 }
3688 
3689 void ExitDataOp::addAsyncOperand(
3690  MLIRContext *context, mlir::Value newValue,
3691  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3692  assert(effectiveDeviceTypes.empty());
3693  assert(!getAsyncAttr());
3694  assert(!getAsyncOperand());
3695 
3696  getAsyncOperandMutable().append(newValue);
3697 }
3698 
3699 void ExitDataOp::addWaitOnly(MLIRContext *context,
3700  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3701  assert(effectiveDeviceTypes.empty());
3702  assert(!getWaitAttr());
3703  assert(getWaitOperands().empty());
3704  assert(!getWaitDevnum());
3705 
3706  setWaitAttr(mlir::UnitAttr::get(context));
3707 }
3708 
3709 void ExitDataOp::addWaitOperands(
3710  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3711  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3712  assert(effectiveDeviceTypes.empty());
3713  assert(!getWaitAttr());
3714  assert(getWaitOperands().empty());
3715  assert(!getWaitDevnum());
3716 
3717  // if hasDevnum, the first value is the devnum. The 'rest' go into the
3718  // operands list.
3719  if (hasDevnum) {
3720  getWaitDevnumMutable().append(newValues.front());
3721  newValues = newValues.drop_front();
3722  }
3723 
3724  getWaitOperandsMutable().append(newValues);
3725 }
3726 
3727 //===----------------------------------------------------------------------===//
3728 // EnterDataOp
3729 //===----------------------------------------------------------------------===//
3730 
3731 LogicalResult acc::EnterDataOp::verify() {
3732  // 2.6.6. Data Enter Directive restriction
3733  // At least one copyin, create, or attach clause must appear on an enter data
3734  // directive.
3735  if (getDataClauseOperands().empty())
3736  return emitError("at least one operand must be present in dataOperands on "
3737  "the enter data operation");
3738 
3739  // The async attribute represent the async clause without value. Therefore the
3740  // attribute and operand cannot appear at the same time.
3741  if (getAsyncOperand() && getAsync())
3742  return emitError("async attribute cannot appear with asyncOperand");
3743 
3744  // The wait attribute represent the wait clause without values. Therefore the
3745  // attribute and operands cannot appear at the same time.
3746  if (!getWaitOperands().empty() && getWait())
3747  return emitError("wait attribute cannot appear with waitOperands");
3748 
3749  if (getWaitDevnum() && getWaitOperands().empty())
3750  return emitError("wait_devnum cannot appear without waitOperands");
3751 
3752  for (mlir::Value operand : getDataClauseOperands())
3753  if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3754  operand.getDefiningOp()))
3755  return emitError("expect data entry operation as defining op");
3756 
3757  return success();
3758 }
3759 
3760 unsigned EnterDataOp::getNumDataOperands() {
3761  return getDataClauseOperands().size();
3762 }
3763 
3764 Value EnterDataOp::getDataOperand(unsigned i) {
3765  unsigned numOptional = getIfCond() ? 1 : 0;
3766  numOptional += getAsyncOperand() ? 1 : 0;
3767  numOptional += getWaitDevnum() ? 1 : 0;
3768  return getOperand(getWaitOperands().size() + numOptional + i);
3769 }
3770 
3771 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3772  MLIRContext *context) {
3773  results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
3774 }
3775 
3776 void EnterDataOp::addAsyncOnly(
3777  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3778  assert(effectiveDeviceTypes.empty());
3779  assert(!getAsyncAttr());
3780  assert(!getAsyncOperand());
3781 
3782  setAsyncAttr(mlir::UnitAttr::get(context));
3783 }
3784 
3785 void EnterDataOp::addAsyncOperand(
3786  MLIRContext *context, mlir::Value newValue,
3787  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3788  assert(effectiveDeviceTypes.empty());
3789  assert(!getAsyncAttr());
3790  assert(!getAsyncOperand());
3791 
3792  getAsyncOperandMutable().append(newValue);
3793 }
3794 
3795 void EnterDataOp::addWaitOnly(MLIRContext *context,
3796  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3797  assert(effectiveDeviceTypes.empty());
3798  assert(!getWaitAttr());
3799  assert(getWaitOperands().empty());
3800  assert(!getWaitDevnum());
3801 
3802  setWaitAttr(mlir::UnitAttr::get(context));
3803 }
3804 
3805 void EnterDataOp::addWaitOperands(
3806  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3807  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3808  assert(effectiveDeviceTypes.empty());
3809  assert(!getWaitAttr());
3810  assert(getWaitOperands().empty());
3811  assert(!getWaitDevnum());
3812 
3813  // if hasDevnum, the first value is the devnum. The 'rest' go into the
3814  // operands list.
3815  if (hasDevnum) {
3816  getWaitDevnumMutable().append(newValues.front());
3817  newValues = newValues.drop_front();
3818  }
3819 
3820  getWaitOperandsMutable().append(newValues);
3821 }
3822 
3823 //===----------------------------------------------------------------------===//
3824 // AtomicReadOp
3825 //===----------------------------------------------------------------------===//
3826 
3827 LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
3828 
3829 //===----------------------------------------------------------------------===//
3830 // AtomicWriteOp
3831 //===----------------------------------------------------------------------===//
3832 
3833 LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
3834 
3835 //===----------------------------------------------------------------------===//
3836 // AtomicUpdateOp
3837 //===----------------------------------------------------------------------===//
3838 
3839 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3840  PatternRewriter &rewriter) {
3841  if (op.isNoOp()) {
3842  rewriter.eraseOp(op);
3843  return success();
3844  }
3845 
3846  if (Value writeVal = op.getWriteOpVal()) {
3847  rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
3848  return success();
3849  }
3850 
3851  return failure();
3852 }
3853 
3854 LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
3855 
3856 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3857 
3858 //===----------------------------------------------------------------------===//
3859 // AtomicCaptureOp
3860 //===----------------------------------------------------------------------===//
3861 
3862 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3863  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3864  return op;
3865  return dyn_cast<AtomicReadOp>(getSecondOp());
3866 }
3867 
3868 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3869  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3870  return op;
3871  return dyn_cast<AtomicWriteOp>(getSecondOp());
3872 }
3873 
3874 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3875  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3876  return op;
3877  return dyn_cast<AtomicUpdateOp>(getSecondOp());
3878 }
3879 
3880 LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
3881 
3882 //===----------------------------------------------------------------------===//
3883 // DeclareEnterOp
3884 //===----------------------------------------------------------------------===//
3885 
3886 template <typename Op>
3887 static LogicalResult
3889  bool requireAtLeastOneOperand = true) {
3890  if (operands.empty() && requireAtLeastOneOperand)
3891  return emitError(
3892  op->getLoc(),
3893  "at least one operand must appear on the declare operation");
3894 
3895  for (mlir::Value operand : operands) {
3896  if (isa<BlockArgument>(operand) ||
3897  !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3898  acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
3899  acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
3900  operand.getDefiningOp()))
3901  return op.emitError(
3902  "expect valid declare data entry operation or acc.getdeviceptr "
3903  "as defining op");
3904 
3905  mlir::Value var{getVar(operand.getDefiningOp())};
3906  assert(var && "declare operands can only be data entry operations which "
3907  "must have var");
3908  (void)var;
3909  std::optional<mlir::acc::DataClause> dataClauseOptional{
3910  getDataClause(operand.getDefiningOp())};
3911  assert(dataClauseOptional.has_value() &&
3912  "declare operands can only be data entry operations which must have "
3913  "dataClause");
3914  (void)dataClauseOptional;
3915  }
3916 
3917  return success();
3918 }
3919 
3920 LogicalResult acc::DeclareEnterOp::verify() {
3921  return checkDeclareOperands(*this, this->getDataClauseOperands());
3922 }
3923 
3924 //===----------------------------------------------------------------------===//
3925 // DeclareExitOp
3926 //===----------------------------------------------------------------------===//
3927 
3928 LogicalResult acc::DeclareExitOp::verify() {
3929  if (getToken())
3930  return checkDeclareOperands(*this, this->getDataClauseOperands(),
3931  /*requireAtLeastOneOperand=*/false);
3932  return checkDeclareOperands(*this, this->getDataClauseOperands());
3933 }
3934 
3935 //===----------------------------------------------------------------------===//
3936 // DeclareOp
3937 //===----------------------------------------------------------------------===//
3938 
3939 LogicalResult acc::DeclareOp::verify() {
3940  return checkDeclareOperands(*this, this->getDataClauseOperands());
3941 }
3942 
3943 //===----------------------------------------------------------------------===//
3944 // RoutineOp
3945 //===----------------------------------------------------------------------===//
3946 
3947 static unsigned getParallelismForDeviceType(acc::RoutineOp op,
3948  acc::DeviceType dtype) {
3949  unsigned parallelism = 0;
3950  parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
3951  parallelism += op.hasWorker(dtype) ? 1 : 0;
3952  parallelism += op.hasVector(dtype) ? 1 : 0;
3953  parallelism += op.hasSeq(dtype) ? 1 : 0;
3954  return parallelism;
3955 }
3956 
3957 LogicalResult acc::RoutineOp::verify() {
3958  unsigned baseParallelism =
3960 
3961  if (baseParallelism > 1)
3962  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
3963  "be present at the same time";
3964 
3965  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
3966  ++dtypeInt) {
3967  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
3968  if (dtype == acc::DeviceType::None)
3969  continue;
3970  unsigned parallelism = getParallelismForDeviceType(*this, dtype);
3971 
3972  if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
3973  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
3974  "be present at the same time";
3975  }
3976 
3977  return success();
3978 }
3979 
3980 static ParseResult parseBindName(OpAsmParser &parser,
3981  mlir::ArrayAttr &bindIdName,
3982  mlir::ArrayAttr &bindStrName,
3983  mlir::ArrayAttr &deviceIdTypes,
3984  mlir::ArrayAttr &deviceStrTypes) {
3985  llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
3986  llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
3987  llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
3988  llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
3989 
3990  if (failed(parser.parseCommaSeparatedList([&]() {
3991  mlir::Attribute newAttr;
3992  bool isSymbolRefAttr;
3993  auto parseResult = parser.parseAttribute(newAttr);
3994  if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
3995  bindIdNameAttrs.push_back(symbolRefAttr);
3996  isSymbolRefAttr = true;
3997  } else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
3998  bindStrNameAttrs.push_back(stringAttr);
3999  isSymbolRefAttr = false;
4000  }
4001  if (parseResult)
4002  return failure();
4003  if (failed(parser.parseOptionalLSquare())) {
4004  if (isSymbolRefAttr) {
4005  deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4006  parser.getContext(), mlir::acc::DeviceType::None));
4007  } else {
4008  deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4009  parser.getContext(), mlir::acc::DeviceType::None));
4010  }
4011  } else {
4012  if (isSymbolRefAttr) {
4013  if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4014  parser.parseRSquare())
4015  return failure();
4016  } else {
4017  if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4018  parser.parseRSquare())
4019  return failure();
4020  }
4021  }
4022  return success();
4023  })))
4024  return failure();
4025 
4026  bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4027  bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4028  deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4029  deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4030 
4031  return success();
4032 }
4033 
4035  std::optional<mlir::ArrayAttr> bindIdName,
4036  std::optional<mlir::ArrayAttr> bindStrName,
4037  std::optional<mlir::ArrayAttr> deviceIdTypes,
4038  std::optional<mlir::ArrayAttr> deviceStrTypes) {
4039  // Create combined vectors for all bind names and device types
4041  llvm::SmallVector<mlir::Attribute> allDeviceTypes;
4042 
4043  // Append bindIdName and deviceIdTypes
4044  if (hasDeviceTypeValues(deviceIdTypes)) {
4045  allBindNames.append(bindIdName->begin(), bindIdName->end());
4046  allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4047  }
4048 
4049  // Append bindStrName and deviceStrTypes
4050  if (hasDeviceTypeValues(deviceStrTypes)) {
4051  allBindNames.append(bindStrName->begin(), bindStrName->end());
4052  allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4053  }
4054 
4055  // Print the combined sequence
4056  if (!allBindNames.empty())
4057  llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4058  [&](const auto &pair) {
4059  p << std::get<0>(pair);
4060  printSingleDeviceType(p, std::get<1>(pair));
4061  });
4062 }
4063 
4064 static ParseResult parseRoutineGangClause(OpAsmParser &parser,
4065  mlir::ArrayAttr &gang,
4066  mlir::ArrayAttr &gangDim,
4067  mlir::ArrayAttr &gangDimDeviceTypes) {
4068 
4069  llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
4070  gangDimDeviceTypeAttrs;
4071  bool needCommaBeforeOperands = false;
4072 
4073  // Gang keyword only
4074  if (failed(parser.parseOptionalLParen())) {
4075  gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4077  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4078  return success();
4079  }
4080 
4081  // Parse keyword only attributes
4082  if (succeeded(parser.parseOptionalLSquare())) {
4083  if (failed(parser.parseCommaSeparatedList([&]() {
4084  if (parser.parseAttribute(gangAttrs.emplace_back()))
4085  return failure();
4086  return success();
4087  })))
4088  return failure();
4089  if (parser.parseRSquare())
4090  return failure();
4091  needCommaBeforeOperands = true;
4092  }
4093 
4094  if (needCommaBeforeOperands && failed(parser.parseComma()))
4095  return failure();
4096 
4097  if (failed(parser.parseCommaSeparatedList([&]() {
4098  if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4099  parser.parseColon() ||
4100  parser.parseAttribute(gangDimAttrs.emplace_back()))
4101  return failure();
4102  if (succeeded(parser.parseOptionalLSquare())) {
4103  if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4104  parser.parseRSquare())
4105  return failure();
4106  } else {
4107  gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4108  parser.getContext(), mlir::acc::DeviceType::None));
4109  }
4110  return success();
4111  })))
4112  return failure();
4113 
4114  if (failed(parser.parseRParen()))
4115  return failure();
4116 
4117  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4118  gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4119  gangDimDeviceTypes =
4120  ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4121 
4122  return success();
4123 }
4124 
4126  std::optional<mlir::ArrayAttr> gang,
4127  std::optional<mlir::ArrayAttr> gangDim,
4128  std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4129 
4130  if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
4131  gang->size() == 1) {
4132  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4133  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4134  return;
4135  }
4136 
4137  p << "(";
4138 
4139  printDeviceTypes(p, gang);
4140 
4141  if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
4142  p << ", ";
4143 
4144  if (hasDeviceTypeValues(gangDimDeviceTypes))
4145  llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4146  [&](const auto &pair) {
4147  p << acc::RoutineOp::getGangDimKeyword() << ": ";
4148  p << std::get<0>(pair);
4149  printSingleDeviceType(p, std::get<1>(pair));
4150  });
4151 
4152  p << ")";
4153 }
4154 
4155 static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
4156  mlir::ArrayAttr &deviceTypes) {
4158  // Keyword only
4159  if (failed(parser.parseOptionalLParen())) {
4160  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4162  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4163  return success();
4164  }
4165 
4166  // Parse device type attributes
4167  if (succeeded(parser.parseOptionalLSquare())) {
4168  if (failed(parser.parseCommaSeparatedList([&]() {
4169  if (parser.parseAttribute(attributes.emplace_back()))
4170  return failure();
4171  return success();
4172  })))
4173  return failure();
4174  if (parser.parseRSquare() || parser.parseRParen())
4175  return failure();
4176  }
4177  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4178  return success();
4179 }
4180 
4181 static void
4183  std::optional<mlir::ArrayAttr> deviceTypes) {
4184 
4185  if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
4186  auto deviceTypeAttr =
4187  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4188  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4189  return;
4190  }
4191 
4192  if (!hasDeviceTypeValues(deviceTypes))
4193  return;
4194 
4195  p << "([";
4196  llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
4197  auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4198  p << dTypeAttr;
4199  });
4200  p << "])";
4201 }
4202 
4203 bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
4204 
4205 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4206  return hasDeviceType(getWorker(), deviceType);
4207 }
4208 
4209 bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
4210 
4211 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4212  return hasDeviceType(getVector(), deviceType);
4213 }
4214 
4215 bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
4216 
4217 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4218  return hasDeviceType(getSeq(), deviceType);
4219 }
4220 
4221 std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4222 RoutineOp::getBindNameValue() {
4223  return getBindNameValue(mlir::acc::DeviceType::None);
4224 }
4225 
4226 std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4227 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4228  if (!hasDeviceTypeValues(getBindIdNameDeviceType()) &&
4229  !hasDeviceTypeValues(getBindStrNameDeviceType())) {
4230  return std::nullopt;
4231  }
4232 
4233  if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
4234  auto attr = (*getBindIdName())[*pos];
4235  auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4236  assert(symbolRefAttr && "expected SymbolRef");
4237  return symbolRefAttr;
4238  }
4239 
4240  if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
4241  auto attr = (*getBindStrName())[*pos];
4242  auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4243  assert(stringAttr && "expected String");
4244  return stringAttr;
4245  }
4246 
4247  return std::nullopt;
4248 }
4249 
4250 bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
4251 
4252 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4253  return hasDeviceType(getGang(), deviceType);
4254 }
4255 
4256 std::optional<int64_t> RoutineOp::getGangDimValue() {
4257  return getGangDimValue(mlir::acc::DeviceType::None);
4258 }
4259 
4260 std::optional<int64_t>
4261 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4262  if (!hasDeviceTypeValues(getGangDimDeviceType()))
4263  return std::nullopt;
4264  if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
4265  auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4266  return intAttr.getInt();
4267  }
4268  return std::nullopt;
4269 }
4270 
4271 //===----------------------------------------------------------------------===//
4272 // InitOp
4273 //===----------------------------------------------------------------------===//
4274 
4275 LogicalResult acc::InitOp::verify() {
4276  Operation *currOp = *this;
4277  while ((currOp = currOp->getParentOp()))
4278  if (isComputeOperation(currOp))
4279  return emitOpError("cannot be nested in a compute operation");
4280  return success();
4281 }
4282 
4283 void acc::InitOp::addDeviceType(MLIRContext *context,
4284  mlir::acc::DeviceType deviceType) {
4286  if (getDeviceTypesAttr())
4287  llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4288 
4289  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4290  setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4291 }
4292 
4293 //===----------------------------------------------------------------------===//
4294 // ShutdownOp
4295 //===----------------------------------------------------------------------===//
4296 
4297 LogicalResult acc::ShutdownOp::verify() {
4298  Operation *currOp = *this;
4299  while ((currOp = currOp->getParentOp()))
4300  if (isComputeOperation(currOp))
4301  return emitOpError("cannot be nested in a compute operation");
4302  return success();
4303 }
4304 
4305 void acc::ShutdownOp::addDeviceType(MLIRContext *context,
4306  mlir::acc::DeviceType deviceType) {
4308  if (getDeviceTypesAttr())
4309  llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4310 
4311  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4312  setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4313 }
4314 
4315 //===----------------------------------------------------------------------===//
4316 // SetOp
4317 //===----------------------------------------------------------------------===//
4318 
4319 LogicalResult acc::SetOp::verify() {
4320  Operation *currOp = *this;
4321  while ((currOp = currOp->getParentOp()))
4322  if (isComputeOperation(currOp))
4323  return emitOpError("cannot be nested in a compute operation");
4324  if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4325  return emitOpError("at least one default_async, device_num, or device_type "
4326  "operand must appear");
4327  return success();
4328 }
4329 
4330 //===----------------------------------------------------------------------===//
4331 // UpdateOp
4332 //===----------------------------------------------------------------------===//
4333 
4334 LogicalResult acc::UpdateOp::verify() {
4335  // At least one of host or device should have a value.
4336  if (getDataClauseOperands().empty())
4337  return emitError("at least one value must be present in dataOperands");
4338 
4340  getAsyncOperandsDeviceTypeAttr(),
4341  "async")))
4342  return failure();
4343 
4345  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4346  getWaitOperandsDeviceTypeAttr(), "wait")))
4347  return failure();
4348 
4349  if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this)))
4350  return failure();
4351 
4352  for (mlir::Value operand : getDataClauseOperands())
4353  if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4354  operand.getDefiningOp()))
4355  return emitError("expect data entry/exit operation or acc.getdeviceptr "
4356  "as defining op");
4357 
4358  return success();
4359 }
4360 
4361 unsigned UpdateOp::getNumDataOperands() {
4362  return getDataClauseOperands().size();
4363 }
4364 
4365 Value UpdateOp::getDataOperand(unsigned i) {
4366  unsigned numOptional = getAsyncOperands().size();
4367  numOptional += getIfCond() ? 1 : 0;
4368  return getOperand(getWaitOperands().size() + numOptional + i);
4369 }
4370 
4371 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
4372  MLIRContext *context) {
4373  results.add<RemoveConstantIfCondition<UpdateOp>>(context);
4374 }
4375 
4376 bool UpdateOp::hasAsyncOnly() {
4377  return hasAsyncOnly(mlir::acc::DeviceType::None);
4378 }
4379 
4380 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4381  return hasDeviceType(getAsyncOnly(), deviceType);
4382 }
4383 
4384 mlir::Value UpdateOp::getAsyncValue() {
4385  return getAsyncValue(mlir::acc::DeviceType::None);
4386 }
4387 
4388 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4390  return {};
4391 
4392  if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
4393  return getAsyncOperands()[*pos];
4394 
4395  return {};
4396 }
4397 
4398 bool UpdateOp::hasWaitOnly() {
4399  return hasWaitOnly(mlir::acc::DeviceType::None);
4400 }
4401 
4402 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4403  return hasDeviceType(getWaitOnly(), deviceType);
4404 }
4405 
4406 mlir::Operation::operand_range UpdateOp::getWaitValues() {
4407  return getWaitValues(mlir::acc::DeviceType::None);
4408 }
4409 
4411 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4413  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4414  getHasWaitDevnum(), deviceType);
4415 }
4416 
4417 mlir::Value UpdateOp::getWaitDevnum() {
4418  return getWaitDevnum(mlir::acc::DeviceType::None);
4419 }
4420 
4421 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4422  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
4423  getWaitOperandsSegments(), getHasWaitDevnum(),
4424  deviceType);
4425 }
4426 
4427 void UpdateOp::addAsyncOnly(MLIRContext *context,
4428  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4429  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4430  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4431 }
4432 
4433 void UpdateOp::addAsyncOperand(
4434  MLIRContext *context, mlir::Value newValue,
4435  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4436  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4437  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4438  getAsyncOperandsMutable()));
4439 }
4440 
4441 void UpdateOp::addWaitOnly(MLIRContext *context,
4442  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4443  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4444  effectiveDeviceTypes));
4445 }
4446 
4447 void UpdateOp::addWaitOperands(
4448  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4449  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4450 
4451  llvm::SmallVector<int32_t> segments;
4452  if (getWaitOperandsSegments())
4453  llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4454 
4455  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4456  context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4457  getWaitOperandsMutable(), segments));
4458  setWaitOperandsSegments(segments);
4459 
4461  if (getHasWaitDevnumAttr())
4462  llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4463  hasDevnums.insert(
4464  hasDevnums.end(),
4465  std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
4466  mlir::BoolAttr::get(context, hasDevnum));
4467  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4468 }
4469 
4470 //===----------------------------------------------------------------------===//
4471 // WaitOp
4472 //===----------------------------------------------------------------------===//
4473 
4474 LogicalResult acc::WaitOp::verify() {
4475  // The async attribute represent the async clause without value. Therefore the
4476  // attribute and operand cannot appear at the same time.
4477  if (getAsyncOperand() && getAsync())
4478  return emitError("async attribute cannot appear with asyncOperand");
4479 
4480  if (getWaitDevnum() && getWaitOperands().empty())
4481  return emitError("wait_devnum cannot appear without waitOperands");
4482 
4483  return success();
4484 }
4485 
4486 #define GET_OP_CLASSES
4487 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
4488 
4489 #define GET_ATTRDEF_CLASSES
4490 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
4491 
4492 #define GET_TYPEDEF_CLASSES
4493 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
4494 
4495 //===----------------------------------------------------------------------===//
4496 // acc dialect utilities
4497 //===----------------------------------------------------------------------===//
4498 
4501  auto varPtr{llvm::TypeSwitch<mlir::Operation *,
4503  accDataClauseOp)
4504  .Case<ACC_DATA_ENTRY_OPS>(
4505  [&](auto entry) { return entry.getVarPtr(); })
4506  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4507  [&](auto exit) { return exit.getVarPtr(); })
4508  .Default([&](mlir::Operation *) {
4510  })};
4511  return varPtr;
4512 }
4513 
4515  auto varPtr{
4517  .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
4518  .Default([&](mlir::Operation *) { return mlir::Value(); })};
4519  return varPtr;
4520 }
4521 
4523  auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
4524  .Case<ACC_DATA_ENTRY_OPS>(
4525  [&](auto entry) { return entry.getVarType(); })
4526  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4527  [&](auto exit) { return exit.getVarType(); })
4528  .Default([&](mlir::Operation *) { return mlir::Type(); })};
4529  return varType;
4530 }
4531 
4534  auto accPtr{llvm::TypeSwitch<mlir::Operation *,
4536  accDataClauseOp)
4537  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4538  [&](auto dataClause) { return dataClause.getAccPtr(); })
4539  .Default([&](mlir::Operation *) {
4541  })};
4542  return accPtr;
4543 }
4544 
4546  auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
4548  [&](auto dataClause) { return dataClause.getAccVar(); })
4549  .Default([&](mlir::Operation *) { return mlir::Value(); })};
4550  return accPtr;
4551 }
4552 
4554  auto varPtrPtr{
4556  .Case<ACC_DATA_ENTRY_OPS>(
4557  [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
4558  .Default([&](mlir::Operation *) { return mlir::Value(); })};
4559  return varPtrPtr;
4560 }
4561 
4566  accDataClauseOp)
4567  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4569  dataClause.getBounds().begin(), dataClause.getBounds().end());
4570  })
4571  .Default([&](mlir::Operation *) {
4573  })};
4574  return bounds;
4575 }
4576 
4580  accDataClauseOp)
4581  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4583  dataClause.getAsyncOperands().begin(),
4584  dataClause.getAsyncOperands().end());
4585  })
4586  .Default([&](mlir::Operation *) {
4588  });
4589 }
4590 
4591 mlir::ArrayAttr
4594  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4595  return dataClause.getAsyncOperandsDeviceTypeAttr();
4596  })
4597  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4598 }
4599 
4600 mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
4603  [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
4604  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4605 }
4606 
4607 std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
4608  auto name{
4610  .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
4611  .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
4612  return {};
4613  })};
4614  return name;
4615 }
4616 
4617 std::optional<mlir::acc::DataClause>
4619  auto dataClause{
4621  accDataEntryOp)
4622  .Case<ACC_DATA_ENTRY_OPS>(
4623  [&](auto entry) { return entry.getDataClause(); })
4624  .Default([&](mlir::Operation *) { return std::nullopt; })};
4625  return dataClause;
4626 }
4627 
4629  auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
4630  .Case<ACC_DATA_ENTRY_OPS>(
4631  [&](auto entry) { return entry.getImplicit(); })
4632  .Default([&](mlir::Operation *) { return false; })};
4633  return implicit;
4634 }
4635 
4637  auto dataOperands{
4640  [&](auto entry) { return entry.getDataClauseOperands(); })
4641  .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
4642  return dataOperands;
4643 }
4644 
4647  auto dataOperands{
4650  [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
4651  .Default([&](mlir::Operation *) { return nullptr; })};
4652  return dataOperands;
4653 }
4654 
4656  mlir::Operation *parentOp = region.getParentOp();
4657  while (parentOp) {
4658  if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp)) {
4659  return parentOp;
4660  }
4661  parentOp = parentOp->getParentOp();
4662  }
4663  return nullptr;
4664 }
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:137
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2366
@ None
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition: OpenACC.cpp:4125
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition: OpenACC.cpp:930
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition: OpenACC.cpp:2873
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition: OpenACC.cpp:1471
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
Definition: OpenACC.cpp:3980
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition: OpenACC.cpp:2887
static bool isComputeOperation(Operation *op)
Definition: OpenACC.cpp:944
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition: OpenACC.cpp:1993
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition: OpenACC.cpp:499
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition: OpenACC.cpp:468
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition: OpenACC.cpp:2004
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition: OpenACC.cpp:1909
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition: OpenACC.cpp:294
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:4182
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition: OpenACC.cpp:2682
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition: OpenACC.cpp:2244
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition: OpenACC.cpp:3888
static LogicalResult checkVarAndAccVar(Op op)
Definition: OpenACC.cpp:428
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
Definition: OpenACC.cpp:2198
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:312
static LogicalResult checkVarAndVarType(Op op)
Definition: OpenACC.cpp:410
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
Definition: OpenACC.cpp:444
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition: OpenACC.cpp:3238
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition: OpenACC.cpp:1400
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:2040
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:1553
static LogicalResult checkNoModifier(Op op)
Definition: OpenACC.cpp:436
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition: OpenACC.cpp:477
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:336
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:1779
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition: OpenACC.cpp:453
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition: OpenACC.cpp:3269
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:4155
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition: OpenACC.cpp:4064
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1892
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:2067
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
Definition: OpenACC.cpp:2183
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:1846
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
Definition: OpenACC.cpp:1384
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:368
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
Definition: OpenACC.cpp:2159
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition: OpenACC.cpp:540
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition: OpenACC.cpp:2701
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition: OpenACC.cpp:1156
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
Definition: OpenACC.cpp:2228
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition: OpenACC.cpp:1823
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:323
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
Definition: OpenACC.cpp:1415
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition: OpenACC.cpp:2140
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:298
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition: OpenACC.cpp:2828
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:352
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition: OpenACC.cpp:2078
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition: OpenACC.cpp:511
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition: OpenACC.cpp:388
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition: OpenACC.cpp:1481
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition: OpenACC.cpp:3947
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1829
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition: OpenACC.cpp:2264
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
Definition: OpenACC.cpp:1364
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
Definition: OpenACC.cpp:4034
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition: OpenACC.h:68
#define ACC_DATA_ENTRY_OPS
Definition: OpenACC.h:45
#define ACC_DATA_EXIT_OPS
Definition: OpenACC.h:53
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
Definition: Builders.h:385
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Definition: Builders.h:390
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:831
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:837
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:129
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:855
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition: OpenACC.cpp:4545
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition: OpenACC.cpp:4514
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition: OpenACC.cpp:4533
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition: OpenACC.cpp:4618
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition: OpenACC.cpp:4646
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition: OpenACC.cpp:4563
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition: OpenACC.cpp:4636
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition: OpenACC.cpp:4607
mlir::Operation * getEnclosingComputeOp(mlir::Region &region)
Used to obtain the enclosing compute construct operation that contains the provided region.
Definition: OpenACC.cpp:4655
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether a data operation is implicit.
Definition: OpenACC.cpp:4628
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition: OpenACC.cpp:4578
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition: OpenACC.cpp:4553
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:4600
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
Definition: OpenACC.cpp:4522
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition: OpenACC.cpp:4500
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:4592
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.