MLIR  21.0.0git
OpenACC.cpp
Go to the documentation of this file.
1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
13 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
19 #include "mlir/Support/LLVM.h"
21 #include "llvm/ADT/SmallSet.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/LogicalResult.h"
24 
25 using namespace mlir;
26 using namespace acc;
27 
28 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
31 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
32 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
33 
34 namespace {
35 
36 static bool isScalarLikeType(Type type) {
37  return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
38 }
39 
40 struct MemRefPointerLikeModel
41  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
42  MemRefType> {
43  Type getElementType(Type pointer) const {
44  return cast<MemRefType>(pointer).getElementType();
45  }
46  mlir::acc::VariableTypeCategory
47  getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
48  Type varType) const {
49  if (auto mappableTy = dyn_cast<MappableType>(varType)) {
50  return mappableTy.getTypeCategory(varPtr);
51  }
52  auto memrefTy = cast<MemRefType>(pointer);
53  if (!memrefTy.hasRank()) {
54  // This memref is unranked - aka it could have any rank, including a
55  // rank of 0 which could mean scalar. For now, return uncategorized.
56  return mlir::acc::VariableTypeCategory::uncategorized;
57  }
58 
59  if (memrefTy.getRank() == 0) {
60  if (isScalarLikeType(memrefTy.getElementType())) {
61  return mlir::acc::VariableTypeCategory::scalar;
62  }
63  // Zero-rank non-scalar - need further analysis to determine the type
64  // category. For now, return uncategorized.
65  return mlir::acc::VariableTypeCategory::uncategorized;
66  }
67 
68  // It has a rank - must be an array.
69  assert(memrefTy.getRank() > 0 && "rank expected to be positive");
70  return mlir::acc::VariableTypeCategory::array;
71  }
72 };
73 
74 struct LLVMPointerPointerLikeModel
75  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
76  LLVM::LLVMPointerType> {
77  Type getElementType(Type pointer) const { return Type(); }
78 };
79 
80 /// Helper function for any of the times we need to modify an ArrayAttr based on
81 /// a device type list. Returns a new ArrayAttr with all of the
82 /// existingDeviceTypes, plus the effective new ones (or an added none if the new
83 /// list is empty).
84 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
85  MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
86  llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
  // Start from a copy of the existing device types, when there are any.
88  if (existingDeviceTypes)
89  llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
90 
  // Per the header comment, an empty new list records a single added `none`.
91  if (newDeviceTypes.empty())
92  deviceTypes.push_back(
94 
  // Append one DeviceTypeAttr per requested device type.
95  for (DeviceType DT : newDeviceTypes)
96  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, DT));
97 
98  return mlir::ArrayAttr::get(context, deviceTypes);
99 }
100 
101 /// Helper function for any of the times we need to add operands that are
102 /// affected by a device type list. Returns a new ArrayAttr with all of the
103 /// existingDeviceTypes, plus the effective new ones (or an added none, if the
104 /// new list is empty). Additionally, appends `arguments` to `argCollection`
105 /// once per added device-type entry and pushes `arguments.size()` onto
106 /// `segments` for each such entry, even if the caller does not use segments.
107 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
108  MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
109  llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
110  mlir::MutableOperandRange argCollection,
111  llvm::SmallVector<int32_t> &segments) {
  // Keep the device types that were already present.
113  if (existingDeviceTypes)
114  llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
115 
  // An empty list still contributes one copy of the arguments, recorded under
  // the added `none` entry described in the header comment.
116  if (newDeviceTypes.empty()) {
117  argCollection.append(arguments);
118  segments.push_back(arguments.size());
119  deviceTypes.push_back(
121  }
122 
  // Each requested device type gets its own copy of the arguments and its own
  // segment size, so the operand list lines up with the device type array.
123  for (DeviceType DT : newDeviceTypes) {
124  argCollection.append(arguments);
125  segments.push_back(arguments.size());
126  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, DT));
127  }
128 
129  return mlir::ArrayAttr::get(context, deviceTypes);
130 }
131 
132 /// Overload for when the 'segments' aren't needed: forwards to the overload
/// that tracks segment sizes and collects them into a local that is not
/// returned to the caller.
133 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
134  MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
135  llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
136  mlir::MutableOperandRange argCollection) {
138  return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
139  newDeviceTypes, arguments,
140  argCollection, segments);
141 }
142 } // namespace
143 
144 //===----------------------------------------------------------------------===//
145 // OpenACC operations
146 //===----------------------------------------------------------------------===//
147 
/// Registers the dialect's operations, attributes, and types, then attaches the
/// OpenACC PointerLikeType external models to MemRefType and LLVMPointerType.
148 void OpenACCDialect::initialize() {
  // The GET_*_LIST includes expand to the generated lists of ops, attributes,
  // and types.
149  addOperations<
150 #define GET_OP_LIST
151 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
152  >();
153  addAttributes<
154 #define GET_ATTRDEF_LIST
155 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
156  >();
157  addTypes<
158 #define GET_TYPEDEF_LIST
159 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
160  >();
161 
162  // By attaching interfaces here, we make the OpenACC dialect dependent on
163  // the other dialects. This is probably better than having dialects like LLVM
164  // and memref be dependent on OpenACC.
165  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
166  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
167  *getContext());
168 }
169 
170 //===----------------------------------------------------------------------===//
171 // device_type support helpers
172 //===----------------------------------------------------------------------===//
173 
174 static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
175  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
176  return true;
177  return false;
178 }
179 
180 static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
181  mlir::acc::DeviceType deviceType) {
182  if (!hasDeviceTypeValues(arrayAttr))
183  return false;
184 
185  for (auto attr : *arrayAttr) {
186  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
187  if (deviceTypeAttr.getValue() == deviceType)
188  return true;
189  }
190 
191  return false;
192 }
193 
195  std::optional<mlir::ArrayAttr> deviceTypes) {
  // Nothing is printed when there are no device types.
196  if (!hasDeviceTypeValues(deviceTypes))
197  return;
198 
  // Otherwise print every attribute as a bracketed, comma-separated list.
199  p << "[";
200  llvm::interleaveComma(*deviceTypes, p,
201  [&](mlir::Attribute attr) { p << attr; });
202  p << "]";
203 }
204 
205 static std::optional<unsigned> findSegment(ArrayAttr segments,
206  mlir::acc::DeviceType deviceType) {
207  unsigned segmentIdx = 0;
208  for (auto attr : segments) {
209  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
210  if (deviceTypeAttr.getValue() == deviceType)
211  return std::make_optional(segmentIdx);
212  ++segmentIdx;
213  }
214  return std::nullopt;
215 }
216 
218 getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
220  std::optional<llvm::ArrayRef<int32_t>> segments,
221  mlir::acc::DeviceType deviceType) {
  // Returns the slice of `range` that belongs to `deviceType`. The operand
  // counts of all earlier segments are skipped, and then the matching
  // segment's count is taken. An empty slice is returned when there is no
  // device type array or no matching entry.
222  if (!arrayAttr)
223  return range.take_front(0);
224  if (auto pos = findSegment(*arrayAttr, deviceType)) {
    // NOTE(review): `*segments` is dereferenced without checking that the
    // optional is engaged; callers presumably guarantee segments exist
    // whenever arrayAttr does. Confirm.
225  int32_t nbOperandsBefore = 0;
226  for (unsigned i = 0; i < *pos; ++i)
227  nbOperandsBefore += (*segments)[i];
228  return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
229  }
230  return range.take_front(0);
231 }
232 
/// Returns the first wait operand recorded for `deviceType` when that device
/// type is flagged in `hasWaitDevnum`, and a null Value otherwise.
233 static mlir::Value
234 getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
236  std::optional<llvm::ArrayRef<int32_t>> segments,
237  std::optional<mlir::ArrayAttr> hasWaitDevnum,
238  mlir::acc::DeviceType deviceType) {
239  if (!hasDeviceTypeValues(deviceTypeAttr))
240  return {};
241  if (auto pos = findSegment(*deviceTypeAttr, deviceType))
    // NOTE(review): this tests whether the Attribute at *pos is non-null, not
    // whether it is a true BoolAttr. getWaitValuesWithoutDevnum reads the
    // BoolAttr value instead, so a `false` entry may be treated as "has
    // devnum" here. Also, `hasWaitDevnum` is dereferenced without an
    // engaged/non-null check. Confirm intent.
242  if (hasWaitDevnum->getValue()[*pos])
243  return getValuesFromSegments(deviceTypeAttr, operands, segments,
244  deviceType)
245  .front();
246  return {};
247 }
248 
250 getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
252  std::optional<llvm::ArrayRef<int32_t>> segments,
253  std::optional<mlir::ArrayAttr> hasWaitDevnum,
254  mlir::acc::DeviceType deviceType) {
  // Returns the wait operands for `deviceType`. When `hasWaitDevnum` marks
  // that device type as true, the leading operand is dropped.
255  auto range =
256  getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
257  if (range.empty())
258  return range;
259  if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
260  if (hasWaitDevnum && *hasWaitDevnum) {
      // NOTE(review): if the element is not a BoolAttr, dyn_cast yields null
      // and getValue() below dereferences it. Confirm that the verifier
      // guarantees BoolAttr elements.
261  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
262  if (boolAttr.getValue())
263  return range.drop_front(1); // first value is devnum
264  }
265  }
266  return range;
267 }
268 
269 template <typename Op>
270 static LogicalResult checkWaitAndAsyncConflict(Op op) {
271  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
272  ++dtypeInt) {
273  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
274 
275  // The asyncOnly attribute represent the async clause without value.
276  // Therefore the attribute and operand cannot appear at the same time.
277  if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
278  op.hasAsyncOnly(dtype))
279  return op.emitError(
280  "asyncOnly attribute cannot appear with asyncOperand");
281 
282  // The wait attribute represent the wait clause without values. Therefore
283  // the attribute and operands cannot appear at the same time.
284  if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
285  op.hasWaitOnly(dtype))
286  return op.emitError("wait attribute cannot appear with waitOperands");
287  }
288  return success();
289 }
290 
291 template <typename Op>
292 static LogicalResult checkVarAndVarType(Op op) {
293  if (!op.getVar())
294  return op.emitError("must have var operand");
295 
296  if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
297  mlir::isa<mlir::acc::MappableType>(op.getVar().getType())) {
298  // TODO: If a type implements both interfaces (mappable and pointer-like),
299  // it is unclear which semantics to apply without additional info which
300  // would need captured in the data operation. For now restrict this case
301  // unless a compelling reason to support disambiguating between the two.
302  return op.emitError("var must be mappable or pointer-like (not both)");
303  }
304 
305  if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
306  !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
307  return op.emitError("var must be mappable or pointer-like");
308 
309  if (mlir::isa<mlir::acc::MappableType>(op.getVar().getType()) &&
310  op.getVarType() != op.getVar().getType())
311  return op.emitError("varType must match when var is mappable");
312 
313  return success();
314 }
315 
316 template <typename Op>
317 static LogicalResult checkVarAndAccVar(Op op) {
318  if (op.getVar().getType() != op.getAccVar().getType())
319  return op.emitError("input and output types must match");
320 
321  return success();
322 }
323 
/// Parses `varPtr(` or `var(` followed by an operand. Either keyword is
/// accepted regardless of the operand's type. The closing `)` is not consumed
/// here; parseVarPtrType parses it after the type.
324 static ParseResult parseVar(mlir::OpAsmParser &parser,
326  // Either `var` or `varPtr` keyword is required.
327  if (failed(parser.parseOptionalKeyword("varPtr"))) {
328  if (failed(parser.parseKeyword("var")))
329  return failure();
330  }
331  if (failed(parser.parseLParen()))
332  return failure();
333  if (failed(parser.parseOperand(var)))
334  return failure();
335 
336  return success();
337 }
338 
340  mlir::Value var) {
  // Use the `varPtr` keyword for pointer-like operands and `var` otherwise.
  // The closing `)` is printed later by printVarPtrType.
341  if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
342  p << "varPtr(";
343  else
344  p << "var(";
345  p.printOperand(var);
346 }
347 
/// Parses `accPtr(<operand> : <type>)` or `accVar(<operand> : <type>)`, storing
/// the type in `accVarType`. Either keyword is accepted regardless of the type.
348 static ParseResult parseAccVar(mlir::OpAsmParser &parser,
350  mlir::Type &accVarType) {
351  // Either `accVar` or `accPtr` keyword is required.
352  if (failed(parser.parseOptionalKeyword("accPtr"))) {
353  if (failed(parser.parseKeyword("accVar")))
354  return failure();
355  }
356  if (failed(parser.parseLParen()))
357  return failure();
358  if (failed(parser.parseOperand(var)))
359  return failure();
360  if (failed(parser.parseColon()))
361  return failure();
362  if (failed(parser.parseType(accVarType)))
363  return failure();
364  if (failed(parser.parseRParen()))
365  return failure();
366 
367  return success();
368 }
369 
371  mlir::Value accVar, mlir::Type accVarType) {
  // Prints `accPtr(<operand> : <type>)` for pointer-like values and
  // `accVar(<operand> : <type>)` otherwise, the inverse of parseAccVar.
372  if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
373  p << "accPtr(";
374  else
375  p << "accVar(";
376  p.printOperand(accVar);
377  p << " : ";
378  p.printType(accVarType);
379  p << ")";
380 }
381 
382 static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
383  mlir::Type &varPtrType,
384  mlir::TypeAttr &varTypeAttr) {
385  if (failed(parser.parseType(varPtrType)))
386  return failure();
387  if (failed(parser.parseRParen()))
388  return failure();
389 
390  if (succeeded(parser.parseOptionalKeyword("varType"))) {
391  if (failed(parser.parseLParen()))
392  return failure();
393  mlir::Type varType;
394  if (failed(parser.parseType(varType)))
395  return failure();
396  varTypeAttr = mlir::TypeAttr::get(varType);
397  if (failed(parser.parseRParen()))
398  return failure();
399  } else {
400  // Set `varType` from the element type of the type of `varPtr`.
401  if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
402  varTypeAttr = mlir::TypeAttr::get(
403  mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
404  else
405  varTypeAttr = mlir::TypeAttr::get(varPtrType);
406  }
407 
408  return success();
409 }
410 
412  mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
  // Prints `<type>)` and, when needed, ` varType(<type>)`, mirroring the
  // syntax accepted by parseVarPtrType.
413  p.printType(varPtrType);
414  p << ")";
415 
416  // Print the `varType` only if it differs from the default parseVarPtrType
417  // derives: the element type for pointer-like types, else the type itself.
418  mlir::Type varType = varTypeAttr.getValue();
419  mlir::Type typeToCheckAgainst =
420  mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
421  ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
422  : varPtrType;
423  if (typeToCheckAgainst != varType) {
424  p << " varType(";
425  p.printType(varType);
426  p << ")";
427  }
428 }
429 
430 //===----------------------------------------------------------------------===//
431 // DataBoundsOp
432 //===----------------------------------------------------------------------===//
433 LogicalResult acc::DataBoundsOp::verify() {
434  auto extent = getExtent();
435  auto upperbound = getUpperbound();
436  if (!extent && !upperbound)
437  return emitError("expected extent or upperbound.");
438  return success();
439 }
440 
441 //===----------------------------------------------------------------------===//
442 // PrivateOp
443 //===----------------------------------------------------------------------===//
444 LogicalResult acc::PrivateOp::verify() {
445  if (getDataClause() != acc::DataClause::acc_private)
446  return emitError(
447  "data clause associated with private operation must match its intent");
448  if (failed(checkVarAndVarType(*this)))
449  return failure();
450  return success();
451 }
452 
453 //===----------------------------------------------------------------------===//
454 // FirstprivateOp
455 //===----------------------------------------------------------------------===//
456 LogicalResult acc::FirstprivateOp::verify() {
457  if (getDataClause() != acc::DataClause::acc_firstprivate)
458  return emitError("data clause associated with firstprivate operation must "
459  "match its intent");
460  if (failed(checkVarAndVarType(*this)))
461  return failure();
462  return success();
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // ReductionOp
467 //===----------------------------------------------------------------------===//
468 LogicalResult acc::ReductionOp::verify() {
469  if (getDataClause() != acc::DataClause::acc_reduction)
470  return emitError("data clause associated with reduction operation must "
471  "match its intent");
472  if (failed(checkVarAndVarType(*this)))
473  return failure();
474  return success();
475 }
476 
477 //===----------------------------------------------------------------------===//
478 // DevicePtrOp
479 //===----------------------------------------------------------------------===//
480 LogicalResult acc::DevicePtrOp::verify() {
481  if (getDataClause() != acc::DataClause::acc_deviceptr)
482  return emitError("data clause associated with deviceptr operation must "
483  "match its intent");
484  if (failed(checkVarAndVarType(*this)))
485  return failure();
486  if (failed(checkVarAndAccVar(*this)))
487  return failure();
488  return success();
489 }
490 
491 //===----------------------------------------------------------------------===//
492 // PresentOp
493 //===----------------------------------------------------------------------===//
494 LogicalResult acc::PresentOp::verify() {
495  if (getDataClause() != acc::DataClause::acc_present)
496  return emitError(
497  "data clause associated with present operation must match its intent");
498  if (failed(checkVarAndVarType(*this)))
499  return failure();
500  if (failed(checkVarAndAccVar(*this)))
501  return failure();
502  return success();
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // CopyinOp
507 //===----------------------------------------------------------------------===//
508 LogicalResult acc::CopyinOp::verify() {
509  // Test for all clauses this operation can be decomposed from:
510  if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
511  getDataClause() != acc::DataClause::acc_copyin_readonly &&
512  getDataClause() != acc::DataClause::acc_copy &&
513  getDataClause() != acc::DataClause::acc_reduction)
514  return emitError(
515  "data clause associated with copyin operation must match its intent"
516  " or specify original clause this operation was decomposed from");
517  if (failed(checkVarAndVarType(*this)))
518  return failure();
519  if (failed(checkVarAndAccVar(*this)))
520  return failure();
521  return success();
522 }
523 
524 bool acc::CopyinOp::isCopyinReadonly() {
525  return getDataClause() == acc::DataClause::acc_copyin_readonly;
526 }
527 
528 //===----------------------------------------------------------------------===//
529 // CreateOp
530 //===----------------------------------------------------------------------===//
531 LogicalResult acc::CreateOp::verify() {
532  // Test for all clauses this operation can be decomposed from:
533  if (getDataClause() != acc::DataClause::acc_create &&
534  getDataClause() != acc::DataClause::acc_create_zero &&
535  getDataClause() != acc::DataClause::acc_copyout &&
536  getDataClause() != acc::DataClause::acc_copyout_zero)
537  return emitError(
538  "data clause associated with create operation must match its intent"
539  " or specify original clause this operation was decomposed from");
540  if (failed(checkVarAndVarType(*this)))
541  return failure();
542  if (failed(checkVarAndAccVar(*this)))
543  return failure();
544  return success();
545 }
546 
547 bool acc::CreateOp::isCreateZero() {
548  // The zero modifier is encoded in the data clause.
549  return getDataClause() == acc::DataClause::acc_create_zero ||
550  getDataClause() == acc::DataClause::acc_copyout_zero;
551 }
552 
553 //===----------------------------------------------------------------------===//
554 // NoCreateOp
555 //===----------------------------------------------------------------------===//
556 LogicalResult acc::NoCreateOp::verify() {
557  if (getDataClause() != acc::DataClause::acc_no_create)
558  return emitError("data clause associated with no_create operation must "
559  "match its intent");
560  if (failed(checkVarAndVarType(*this)))
561  return failure();
562  if (failed(checkVarAndAccVar(*this)))
563  return failure();
564  return success();
565 }
566 
567 //===----------------------------------------------------------------------===//
568 // AttachOp
569 //===----------------------------------------------------------------------===//
570 LogicalResult acc::AttachOp::verify() {
571  if (getDataClause() != acc::DataClause::acc_attach)
572  return emitError(
573  "data clause associated with attach operation must match its intent");
574  if (failed(checkVarAndVarType(*this)))
575  return failure();
576  if (failed(checkVarAndAccVar(*this)))
577  return failure();
578  return success();
579 }
580 
581 //===----------------------------------------------------------------------===//
582 // DeclareDeviceResidentOp
583 //===----------------------------------------------------------------------===//
584 
585 LogicalResult acc::DeclareDeviceResidentOp::verify() {
586  if (getDataClause() != acc::DataClause::acc_declare_device_resident)
587  return emitError("data clause associated with device_resident operation "
588  "must match its intent");
589  if (failed(checkVarAndVarType(*this)))
590  return failure();
591  if (failed(checkVarAndAccVar(*this)))
592  return failure();
593  return success();
594 }
595 
596 //===----------------------------------------------------------------------===//
597 // DeclareLinkOp
598 //===----------------------------------------------------------------------===//
599 
600 LogicalResult acc::DeclareLinkOp::verify() {
601  if (getDataClause() != acc::DataClause::acc_declare_link)
602  return emitError(
603  "data clause associated with link operation must match its intent");
604  if (failed(checkVarAndVarType(*this)))
605  return failure();
606  if (failed(checkVarAndAccVar(*this)))
607  return failure();
608  return success();
609 }
610 
611 //===----------------------------------------------------------------------===//
612 // CopyoutOp
613 //===----------------------------------------------------------------------===//
614 LogicalResult acc::CopyoutOp::verify() {
615  // Test for all clauses this operation can be decomposed from:
616  if (getDataClause() != acc::DataClause::acc_copyout &&
617  getDataClause() != acc::DataClause::acc_copyout_zero &&
618  getDataClause() != acc::DataClause::acc_copy &&
619  getDataClause() != acc::DataClause::acc_reduction)
620  return emitError(
621  "data clause associated with copyout operation must match its intent"
622  " or specify original clause this operation was decomposed from");
623  if (!getVar() || !getAccVar())
624  return emitError("must have both host and device pointers");
625  if (failed(checkVarAndVarType(*this)))
626  return failure();
627  if (failed(checkVarAndAccVar(*this)))
628  return failure();
629  return success();
630 }
631 
632 bool acc::CopyoutOp::isCopyoutZero() {
633  return getDataClause() == acc::DataClause::acc_copyout_zero;
634 }
635 
636 //===----------------------------------------------------------------------===//
637 // DeleteOp
638 //===----------------------------------------------------------------------===//
639 LogicalResult acc::DeleteOp::verify() {
640  // Test for all clauses this operation can be decomposed from:
641  if (getDataClause() != acc::DataClause::acc_delete &&
642  getDataClause() != acc::DataClause::acc_create &&
643  getDataClause() != acc::DataClause::acc_create_zero &&
644  getDataClause() != acc::DataClause::acc_copyin &&
645  getDataClause() != acc::DataClause::acc_copyin_readonly &&
646  getDataClause() != acc::DataClause::acc_present &&
647  getDataClause() != acc::DataClause::acc_no_create &&
648  getDataClause() != acc::DataClause::acc_declare_device_resident &&
649  getDataClause() != acc::DataClause::acc_declare_link)
650  return emitError(
651  "data clause associated with delete operation must match its intent"
652  " or specify original clause this operation was decomposed from");
653  if (!getAccVar())
654  return emitError("must have device pointer");
655  return success();
656 }
657 
658 //===----------------------------------------------------------------------===//
659 // DetachOp
660 //===----------------------------------------------------------------------===//
661 LogicalResult acc::DetachOp::verify() {
662  // Test for all clauses this operation can be decomposed from:
663  if (getDataClause() != acc::DataClause::acc_detach &&
664  getDataClause() != acc::DataClause::acc_attach)
665  return emitError(
666  "data clause associated with detach operation must match its intent"
667  " or specify original clause this operation was decomposed from");
668  if (!getAccVar())
669  return emitError("must have device pointer");
670  return success();
671 }
672 
673 //===----------------------------------------------------------------------===//
674 // UpdateHostOp
675 //===----------------------------------------------------------------------===//
676 LogicalResult acc::UpdateHostOp::verify() {
677  // Test for all clauses this operation can be decomposed from:
678  if (getDataClause() != acc::DataClause::acc_update_host &&
679  getDataClause() != acc::DataClause::acc_update_self)
680  return emitError(
681  "data clause associated with host operation must match its intent"
682  " or specify original clause this operation was decomposed from");
683  if (!getVar() || !getAccVar())
684  return emitError("must have both host and device pointers");
685  if (failed(checkVarAndVarType(*this)))
686  return failure();
687  if (failed(checkVarAndAccVar(*this)))
688  return failure();
689  return success();
690 }
691 
692 //===----------------------------------------------------------------------===//
693 // UpdateDeviceOp
694 //===----------------------------------------------------------------------===//
695 LogicalResult acc::UpdateDeviceOp::verify() {
696  // Test for all clauses this operation can be decomposed from:
697  if (getDataClause() != acc::DataClause::acc_update_device)
698  return emitError(
699  "data clause associated with device operation must match its intent"
700  " or specify original clause this operation was decomposed from");
701  if (failed(checkVarAndVarType(*this)))
702  return failure();
703  if (failed(checkVarAndAccVar(*this)))
704  return failure();
705  return success();
706 }
707 
708 //===----------------------------------------------------------------------===//
709 // UseDeviceOp
710 //===----------------------------------------------------------------------===//
711 LogicalResult acc::UseDeviceOp::verify() {
712  // Test for all clauses this operation can be decomposed from:
713  if (getDataClause() != acc::DataClause::acc_use_device)
714  return emitError(
715  "data clause associated with use_device operation must match its intent"
716  " or specify original clause this operation was decomposed from");
717  if (failed(checkVarAndVarType(*this)))
718  return failure();
719  if (failed(checkVarAndAccVar(*this)))
720  return failure();
721  return success();
722 }
723 
724 //===----------------------------------------------------------------------===//
725 // CacheOp
726 //===----------------------------------------------------------------------===//
727 LogicalResult acc::CacheOp::verify() {
728  // Test for all clauses this operation can be decomposed from:
729  if (getDataClause() != acc::DataClause::acc_cache &&
730  getDataClause() != acc::DataClause::acc_cache_readonly)
731  return emitError(
732  "data clause associated with cache operation must match its intent"
733  " or specify original clause this operation was decomposed from");
734  if (failed(checkVarAndVarType(*this)))
735  return failure();
736  if (failed(checkVarAndAccVar(*this)))
737  return failure();
738  return success();
739 }
740 
741 template <typename StructureOp>
742 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
743  unsigned nRegions = 1) {
744 
745  SmallVector<Region *, 2> regions;
746  for (unsigned i = 0; i < nRegions; ++i)
747  regions.push_back(state.addRegion());
748 
749  for (Region *region : regions)
750  if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
751  return failure();
752 
753  return success();
754 }
755 
756 static bool isComputeOperation(Operation *op) {
757  return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
758 }
759 
760 namespace {
761 /// Pattern to erase region-less operations whose `ifCond` is a constant false,
762 /// and to drop the `ifCond` operand from the operation when it is a constant
763 /// true.
764 template <typename OpTy>
765 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
767 
768  LogicalResult matchAndRewrite(OpTy op,
769  PatternRewriter &rewriter) const override {
770  // Early return if there is no condition.
771  Value ifCond = op.getIfCond();
772  if (!ifCond)
773  return failure();
774 
  // Only integer constants are handled; any other condition is left alone.
775  IntegerAttr constAttr;
776  if (!matchPattern(ifCond, m_Constant(&constAttr)))
777  return failure();
  // Nonzero: the condition always holds, so only the operand is removed.
  // Zero: the operation never executes and is erased.
778  if (constAttr.getInt())
779  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
780  else
781  rewriter.eraseOp(op);
782 
783  return success();
784  }
785 };
786 
787 /// Replaces the given op with the contents of the given single-block region,
788 /// using the operands of the block terminator to replace operation results.
789 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
790  Region &region, ValueRange blockArgs = {}) {
791  assert(llvm::hasSingleElement(region) && "expected single-region block");
792  Block *block = &region.front();
793  Operation *terminator = block->getTerminator();
794  ValueRange results = terminator->getOperands();
795  rewriter.inlineBlockBefore(block, op, blockArgs);
796  rewriter.replaceOp(op, results);
797  rewriter.eraseOp(terminator);
798 }
799 
800 /// Pattern for operations with a region: when `ifCond` is a constant false the
801 /// operation is replaced by the inlined contents of its region, and when it is
802 /// a constant true the `ifCond` operand is dropped from the operation.
803 template <typename OpTy>
804 struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
806 
807  LogicalResult matchAndRewrite(OpTy op,
808  PatternRewriter &rewriter) const override {
809  // Early return if there is no condition.
810  Value ifCond = op.getIfCond();
811  if (!ifCond)
812  return failure();
813 
  // Only integer constants are handled; any other condition is left alone.
814  IntegerAttr constAttr;
815  if (!matchPattern(ifCond, m_Constant(&constAttr)))
816  return failure();
  // Nonzero: drop the operand. Zero: splice the region body in place of the
  // op via replaceOpWithRegion.
817  if (constAttr.getInt())
818  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
819  else
820  replaceOpWithRegion(rewriter, op, op.getRegion());
821 
822  return success();
823  }
824 };
825 
826 } // namespace
827 
828 //===----------------------------------------------------------------------===//
829 // PrivateRecipeOp
830 //===----------------------------------------------------------------------===//
831 
832 static LogicalResult verifyInitLikeSingleArgRegion(
833  Operation *op, Region &region, StringRef regionType, StringRef regionName,
834  Type type, bool verifyYield, bool optional = false) {
835  if (optional && region.empty())
836  return success();
837 
838  if (region.empty())
839  return op->emitOpError() << "expects non-empty " << regionName << " region";
840  Block &firstBlock = region.front();
841  if (firstBlock.getNumArguments() < 1 ||
842  firstBlock.getArgument(0).getType() != type)
843  return op->emitOpError() << "expects " << regionName
844  << " region first "
845  "argument of the "
846  << regionType << " type";
847 
848  if (verifyYield) {
849  for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
850  if (yieldOp.getOperands().size() != 1 ||
851  yieldOp.getOperands().getTypes()[0] != type)
852  return op->emitOpError() << "expects " << regionName
853  << " region to "
854  "yield a value of the "
855  << regionType << " type";
856  }
857  }
858  return success();
859 }
860 
/// Verifies the regions of an `acc.private.recipe`:
///  - the `init` region must be non-empty and take the privatization type as
///    its first argument;
///  - the `destroy` region is optional; when present it must satisfy the same
///    first-argument requirement.
/// Yielded values are not checked for either region.
LogicalResult acc::PrivateRecipeOp::verifyRegions() {
  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
                                           "privatization", "init", getType(),
                                           /*verifyYield=*/false)))
    return failure();
  // NOTE(review): the opening `if (failed(verifyInitLikeSingleArgRegion(` of
  // the following call appears to be missing from this listing (note the
  // three closing parentheses); confirm against the original source.
          *this, getDestroyRegion(), "privatization", "destroy", getType(),
          /*verifyYield=*/false, /*optional=*/true)))
    return failure();
  return success();
}
872 
873 //===----------------------------------------------------------------------===//
874 // FirstprivateRecipeOp
875 //===----------------------------------------------------------------------===//
876 
877 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
878  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
879  "privatization", "init", getType(),
880  /*verifyYield=*/false)))
881  return failure();
882 
883  if (getCopyRegion().empty())
884  return emitOpError() << "expects non-empty copy region";
885 
886  Block &firstBlock = getCopyRegion().front();
887  if (firstBlock.getNumArguments() < 2 ||
888  firstBlock.getArgument(0).getType() != getType())
889  return emitOpError() << "expects copy region with two arguments of the "
890  "privatization type";
891 
892  if (getDestroyRegion().empty())
893  return success();
894 
895  if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
896  "privatization", "destroy",
897  getType(), /*verifyYield=*/false)))
898  return failure();
899 
900  return success();
901 }
902 
903 //===----------------------------------------------------------------------===//
904 // ReductionRecipeOp
905 //===----------------------------------------------------------------------===//
906 
907 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
908  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
909  "init", getType(),
910  /*verifyYield=*/false)))
911  return failure();
912 
913  if (getCombinerRegion().empty())
914  return emitOpError() << "expects non-empty combiner region";
915 
916  Block &reductionBlock = getCombinerRegion().front();
917  if (reductionBlock.getNumArguments() < 2 ||
918  reductionBlock.getArgument(0).getType() != getType() ||
919  reductionBlock.getArgument(1).getType() != getType())
920  return emitOpError() << "expects combiner region with the first two "
921  << "arguments of the reduction type";
922 
923  for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
924  if (yieldOp.getOperands().size() != 1 ||
925  yieldOp.getOperands().getTypes()[0] != getType())
926  return emitOpError() << "expects combiner region to yield a value "
927  "of the reduction type";
928  }
929 
930  return success();
931 }
932 
933 //===----------------------------------------------------------------------===//
934 // Custom parser and printer verifier for private clause
935 //===----------------------------------------------------------------------===//
936 
/// Parses a non-empty, comma-separated list of
/// `<symbol-attr> -> <operand> : <type>` entries. Each operand and type is
/// appended to `operands` / `types`, and the attributes are collected, in
/// order, into the `symbols` array attribute.
///
/// NOTE(review): the `operands` parameter and the declaration of the local
/// `attributes` vector are not visible in this listing; confirm against the
/// original source.
static ParseResult parseSymOperandList(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseAttribute(attributes.emplace_back()) ||
            parser.parseArrow() ||
            parser.parseOperand(operands.emplace_back()) ||
            parser.parseColonType(types.emplace_back()))
          return failure();
        return success();
      })))
    return failure();
  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
  return success();
}
956 
    mlir::OperandRange operands,
    mlir::TypeRange types,
    std::optional<mlir::ArrayAttr> attributes) {
  // Prints each entry as `<symbol> -> <operand> : <type>`, comma-separated.
  // This is the format accepted by parseSymOperandList. The printed type is
  // the operand's own type; `types` is not used.
  // NOTE(review): `attributes` is dereferenced without a presence check.
  // Also, the first line of this signature (function name and the printer
  // `p`) is missing from this listing; confirm against the original source.
  llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
    p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
      << std::get<1>(it).getType();
  });
}
966 
967 //===----------------------------------------------------------------------===//
968 // ParallelOp
969 //===----------------------------------------------------------------------===//
970 
971 /// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
972 template <typename Op>
973 static LogicalResult checkDataOperands(Op op,
974  const mlir::ValueRange &operands) {
975  for (mlir::Value operand : operands)
976  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
977  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
978  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
979  operand.getDefiningOp()))
980  return op.emitError(
981  "expect data entry/exit operation or acc.getdeviceptr "
982  "as defining op");
983  return success();
984 }
985 
/// Verifies a clause that pairs each operand with a symbol reference to a
/// declaration of kind `Op` (e.g. acc.private.recipe):
///  - operands and symbol references must either both be absent or have the
///    same count;
///  - no operand may appear twice;
///  - every symbol must resolve, looking outward from `op`, to an `Op`;
///  - when `checkOperandType` is set and the declaration has a type, that
///    type must equal the operand's type.
/// `operandName` and `symbolName` are only used in diagnostics.
///
/// NOTE(review): the declaration of `set` (used for duplicate detection) is
/// not visible in this listing; confirm against the original source.
template <typename Op>
static LogicalResult
checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
                    mlir::OperandRange operands, llvm::StringRef operandName,
                    llvm::StringRef symbolName, bool checkOperandType = true) {
  if (!operands.empty()) {
    if (!attributes || attributes->size() != operands.size())
      return op->emitOpError()
             << "expected as many " << symbolName << " symbol reference as "
             << operandName << " operands";
  } else {
    // Without operands, any symbol list would be dangling.
    if (attributes)
      return op->emitOpError()
             << "unexpected " << symbolName << " symbol reference";
    return success();
  }

  for (auto args : llvm::zip(operands, *attributes)) {
    mlir::Value operand = std::get<0>(args);

    // Each operand may be listed at most once.
    if (!set.insert(operand).second)
      return op->emitOpError()
             << operandName << " operand appears more than once";

    mlir::Type varType = operand.getType();
    auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
    auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
    if (!decl)
      return op->emitOpError()
             << "expected symbol reference " << symbolRef << " to point to a "
             << operandName << " declaration";

    // A declaration without a type does not constrain the operand type.
    if (checkOperandType && decl.getType() && decl.getType() != varType)
      return op->emitOpError() << "expected " << operandName << " (" << varType
                               << ") to be the same type as " << operandName
                               << " declaration (" << decl.getType() << ")";
  }

  return success();
}
1027 
1028 unsigned ParallelOp::getNumDataOperands() {
1029  return getReductionOperands().size() + getPrivateOperands().size() +
1030  getFirstprivateOperands().size() + getDataClauseOperands().size();
1031 }
1032 
1033 Value ParallelOp::getDataOperand(unsigned i) {
1034  unsigned numOptional = getAsyncOperands().size();
1035  numOptional += getNumGangs().size();
1036  numOptional += getNumWorkers().size();
1037  numOptional += getVectorLength().size();
1038  numOptional += getIfCond() ? 1 : 0;
1039  numOptional += getSelfCond() ? 1 : 0;
1040  return getOperand(getWaitOperands().size() + numOptional + i);
1041 }
1042 
1043 template <typename Op>
1044 static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
1045  ArrayAttr deviceTypes,
1046  llvm::StringRef keyword) {
1047  if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1048  return op.emitOpError() << keyword << " operands count must match "
1049  << keyword << " device_type count";
1050  return success();
1051 }
1052 
/// Verifies an operand list split into segments, one segment per device_type
/// entry:
///  - if `maxInSegment` is non-zero, no segment may hold more values than it;
///  - the segment sizes must add up to the number of operands, and a
///    non-empty operand list requires a device_type attribute;
///  - when present, the device_type attribute needs one entry per segment.
/// `keyword` names the clause in diagnostics.
///
/// NOTE(review): the line with this function's return type and name is
/// missing from this listing; confirm against the original source.
template <typename Op>
    Op op, OperandRange operands, DenseI32ArrayAttr segments,
    ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
  std::size_t numOperandsInSegments = 0;
  std::size_t nbOfSegments = 0;

  if (segments) {
    for (auto segCount : segments.asArrayRef()) {
      // A zero `maxInSegment` disables the per-segment limit.
      if (maxInSegment != 0 && segCount > maxInSegment)
        return op.emitOpError() << keyword << " expects a maximum of "
                                << maxInSegment << " values per segment";
      numOperandsInSegments += segCount;
      ++nbOfSegments;
    }
  }

  if ((numOperandsInSegments != operands.size()) ||
      (!deviceTypes && !operands.empty()))
    return op.emitOpError()
           << keyword << " operand count does not match count in segments";
  if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
    return op.emitOpError()
           << keyword << " segment count does not match device_type count";
  return success();
}
1079 
/// Verifies an `acc.parallel` operation:
///  - private, firstprivate and reduction operands each pair with recipe
///    symbols (operand types are not checked against the recipes here);
///  - num_gangs and wait operand segments match their device_type attributes
///    (num_gangs is passed a per-segment maximum of 3);
///  - num_workers, vector_length and async operands match their device_type
///    counts;
///  - wait and async clauses pass checkWaitAndAsyncConflict;
///  - data clause operands come from data entry/exit operations.
///
/// NOTE(review): the opening `if (failed(...(` lines of the num_gangs and wait
/// segment checks are missing from this listing; confirm against the original
/// source.
LogicalResult acc::ParallelOp::verify() {
  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
          *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
          "privatizations", /*checkOperandType=*/false)))
    return failure();
  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
          *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
          "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
    return failure();
  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
          *this, getReductionRecipes(), getReductionOperands(), "reduction",
          "reductions", /*checkOperandType=*/false)))
    return failure();

          *this, getNumGangs(), getNumGangsSegmentsAttr(),
          getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
    return failure();

          *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
          getWaitOperandsDeviceTypeAttr(), "wait")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
                                        getNumWorkersDeviceTypeAttr(),
                                        "num_workers")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
                                        getVectorLengthDeviceTypeAttr(),
                                        "vector_length")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
                                        getAsyncOperandsDeviceTypeAttr(),
                                        "async")))
    return failure();

  if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
    return failure();

  return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
}
1124 
/// Returns the element of `range` at the position findSegment reports for
/// `deviceType` within `arrayAttr` (presumably the index of the matching
/// entry). Returns a null value when `arrayAttr` is absent or findSegment
/// finds nothing.
///
/// NOTE(review): the `range` parameter declaration is missing from this
/// listing; confirm against the original source.
static mlir::Value
getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
                            mlir::acc::DeviceType deviceType) {
  if (!arrayAttr)
    return {};
  if (auto pos = findSegment(*arrayAttr, deviceType))
    return range[*pos];
  return {};
}
1135 
1136 bool acc::ParallelOp::hasAsyncOnly() {
1137  return hasAsyncOnly(mlir::acc::DeviceType::None);
1138 }
1139 
1140 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1141  return hasDeviceType(getAsyncOnly(), deviceType);
1142 }
1143 
1144 mlir::Value acc::ParallelOp::getAsyncValue() {
1145  return getAsyncValue(mlir::acc::DeviceType::None);
1146 }
1147 
/// Returns the async operand associated with \p deviceType (presumably null
/// when there is none).
///
/// NOTE(review): the first line of the return statement (callee and first
/// argument) is missing from this listing. By analogy with
/// getNumWorkersValue it presumably calls getValueInDeviceTypeSegment;
/// confirm against the original source.
mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
                                     getAsyncOperands(), deviceType);
}
1152 
1153 mlir::Value acc::ParallelOp::getNumWorkersValue() {
1154  return getNumWorkersValue(mlir::acc::DeviceType::None);
1155 }
1156 
/// Returns the num_workers operand selected for \p deviceType by
/// getValueInDeviceTypeSegment, or a null value when there is none.
///
/// NOTE(review): the return-type line of this definition is missing from this
/// listing; confirm against the original source.
acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
                                     deviceType);
}
1162 
1163 mlir::Value acc::ParallelOp::getVectorLengthValue() {
1164  return getVectorLengthValue(mlir::acc::DeviceType::None);
1165 }
1166 
/// Returns the vector_length operand selected for \p deviceType by
/// getValueInDeviceTypeSegment, or a null value when there is none.
///
/// NOTE(review): the return-type line of this definition is missing from this
/// listing; confirm against the original source.
acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
                                     getVectorLength(), deviceType);
}
1172 
1173 mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
1174  return getNumGangsValues(mlir::acc::DeviceType::None);
1175 }
1176 
/// Returns the num_gangs operands of the segment that getValuesFromSegments
/// selects for \p deviceType.
///
/// NOTE(review): the return-type line of this definition is missing from this
/// listing; confirm against the original source.
ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
                               getNumGangsSegments(), deviceType);
}
1182 
1183 bool acc::ParallelOp::hasWaitOnly() {
1184  return hasWaitOnly(mlir::acc::DeviceType::None);
1185 }
1186 
1187 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1188  return hasDeviceType(getWaitOnly(), deviceType);
1189 }
1190 
1191 mlir::Operation::operand_range ParallelOp::getWaitValues() {
1192  return getWaitValues(mlir::acc::DeviceType::None);
1193 }
1194 
/// Returns the wait operands associated with \p deviceType. The helper also
/// receives the hasWaitDevnum flags, presumably so that a devnum value can be
/// excluded from the result; confirm against the helper.
///
/// NOTE(review): the return-type line and the `return <helper>(` line are
/// missing from this listing; confirm against the original source.
ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
      getHasWaitDevnum(), deviceType);
}
1201 
1202 mlir::Value ParallelOp::getWaitDevnum() {
1203  return getWaitDevnum(mlir::acc::DeviceType::None);
1204 }
1205 
1206 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1207  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1208  getWaitOperandsSegments(), getHasWaitDevnum(),
1209  deviceType);
1210 }
1211 
/// Convenience builder that takes only operand lists. Every attribute
/// argument of the full builder is passed as null: device_type, segment,
/// keyword-only, self, recipe-symbol, default and combined attributes.
/// `gangPrivateOperands` and `gangFirstPrivateOperands` are forwarded next to
/// the (null) privatization and firstprivatization recipe attributes.
void ParallelOp::build(mlir::OpBuilder &odsBuilder,
                       mlir::OperationState &odsState,
                       mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
                       mlir::ValueRange vectorLength,
                       mlir::ValueRange asyncOperands,
                       mlir::ValueRange waitOperands, mlir::Value ifCond,
                       mlir::Value selfCond, mlir::ValueRange reductionOperands,
                       mlir::ValueRange gangPrivateOperands,
                       mlir::ValueRange gangFirstPrivateOperands,
                       mlir::ValueRange dataClauseOperands) {

  // The argument order must follow the generated builder exactly.
  ParallelOp::build(
      odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
      /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
      /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
      /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
      /*numGangsDeviceType=*/nullptr, numWorkers,
      /*numWorkersDeviceType=*/nullptr, vectorLength,
      /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
      /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
      gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
      /*firstprivatizations=*/nullptr, dataClauseOperands,
      /*defaultAttr=*/nullptr, /*combined=*/nullptr);
}
1236 
1237 void acc::ParallelOp::addNumWorkersOperand(
1238  MLIRContext *context, mlir::Value newValue,
1239  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1240  setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1241  context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1242  getNumWorkersMutable()));
1243 }
1244 void acc::ParallelOp::addVectorLengthOperand(
1245  MLIRContext *context, mlir::Value newValue,
1246  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1247  setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1248  context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1249  getVectorLengthMutable()));
1250 }
1251 
1252 void acc::ParallelOp::addAsyncOnly(
1253  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1254  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1255  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1256 }
1257 
1258 void acc::ParallelOp::addAsyncOperand(
1259  MLIRContext *context, mlir::Value newValue,
1260  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1261  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1262  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1263  getAsyncOperandsMutable()));
1264 }
1265 
1266 void acc::ParallelOp::addNumGangsOperands(
1267  MLIRContext *context, mlir::ValueRange newValues,
1268  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1269  llvm::SmallVector<int32_t> segments;
1270  if (getNumGangsSegments())
1271  llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1272 
1273  setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1274  context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1275  getNumGangsMutable(), segments));
1276 
1277  setNumGangsSegments(segments);
1278 }
1279 void acc::ParallelOp::addWaitOnly(
1280  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1281  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1282  effectiveDeviceTypes));
1283 }
/// Adds a group of wait operands (\p newValues) for \p effectiveDeviceTypes.
/// The operand list, its device_type attribute and its segment sizes are
/// updated via addDeviceTypeAffectedOperandHelper. Then one \p hasDevnum flag
/// is appended to hasWaitDevnum per effective device type (at least one flag
/// when \p effectiveDeviceTypes is empty).
///
/// NOTE(review): the declaration of the local `hasDevnums` vector is missing
/// from this listing; confirm against the original source.
void acc::ParallelOp::addWaitOperands(
    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {

  llvm::SmallVector<int32_t> segments;
  if (getWaitOperandsSegments())
    llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));

  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
      getWaitOperandsMutable(), segments));
  setWaitOperandsSegments(segments);

  if (getHasWaitDevnumAttr())
    llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
  // At least one flag is added even when no device types are given
  // (presumably covering the default device type).
  hasDevnums.insert(
      hasDevnums.end(),
      std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
      mlir::BoolAttr::get(context, hasDevnum));
  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
}
1306 
/// Parses the num_gangs operand groups:
///   `{%a : t, ...} [#acc.device_type<...>], {...}, ...`
/// Each brace group becomes one segment, and its operand count is recorded in
/// `segments`. A group without a trailing `[...]` gets a default DeviceTypeAttr
/// (presumably DeviceType::None, as in parseDeviceTypeOperands).
///
/// NOTE(review): this listing is missing the `operands` parameter, the local
/// `attributes`/`seg` declarations, the head of the inner parse lambda, and
/// the argument line of the default `DeviceTypeAttr::get` call; confirm
/// against the original source.
static ParseResult parseNumGangs(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::DenseI32ArrayAttr &segments) {

  do {
    if (failed(parser.parseLBrace()))
      return failure();

    int32_t crtOperandsSize = operands.size();
    if (failed(parser.parseCommaSeparatedList(
            if (parser.parseOperand(operands.emplace_back()) ||
                parser.parseColonType(types.emplace_back()))
              return failure();
            return success();
          })))
      return failure();
    // Number of operands parsed inside this brace group.
    seg.push_back(operands.size() - crtOperandsSize);

    if (failed(parser.parseRBrace()))
      return failure();

    if (succeeded(parser.parseOptionalLSquare())) {
      if (parser.parseAttribute(attributes.emplace_back()) ||
          parser.parseRSquare())
        return failure();
    } else {
      attributes.push_back(mlir::acc::DeviceTypeAttr::get(
    }
  } while (succeeded(parser.parseOptionalComma()));

  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);

  return success();
}
1350 
  // Prints ` [<attr>]` unless the device type is DeviceType::None, which is
  // the implicit default and is therefore elided.
  // NOTE(review): `dyn_cast` yields a null attribute when `attr` is not a
  // DeviceTypeAttr, and `getValue()` is then called on it unchecked; consider
  // `cast<>` or a null check. The signature line of this function is missing
  // from this listing.
  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
  if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
    p << " [" << attr << "]";
}
1356 
    mlir::OperandRange operands, mlir::TypeRange types,
    std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::DenseI32ArrayAttr> segments) {
  // Prints one `{%a : t, ...}` group per device_type entry, using `segments`
  // to know how many operands belong to each group. Each group is followed by
  // its device type unless that type is None. `types` is unused; each operand
  // supplies its own type.
  // NOTE(review): `deviceTypes` and `segments` are dereferenced without a
  // presence check. The signature line with the function name and `p` is
  // missing from this listing.
  unsigned opIdx = 0;
  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
    p << "{";
    llvm::interleaveComma(
        llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
          p << operands[opIdx] << " : " << operands[opIdx].getType();
          ++opIdx;
        });
    p << "}";
    printSingleDeviceType(p, it.value());
  });
}
1373 
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::DenseI32ArrayAttr &segments) {
  // Parses `{%a : t, ...} [#acc.device_type<...>], {...}, ...`: one segment
  // per brace group, with its operand count recorded in `segments`. A group
  // without `[...]` gets a default DeviceTypeAttr (argument line missing
  // below; presumably DeviceType::None).
  // NOTE(review): this listing is missing the function-name line, the
  // `operands` parameter, the `attributes`/`seg` declarations, the head of
  // the inner parse lambda and the default `DeviceTypeAttr::get` arguments;
  // confirm against the original source.

  do {
    if (failed(parser.parseLBrace()))
      return failure();

    int32_t crtOperandsSize = operands.size();

    if (failed(parser.parseCommaSeparatedList(
            if (parser.parseOperand(operands.emplace_back()) ||
                parser.parseColonType(types.emplace_back()))
              return failure();
            return success();
          })))
      return failure();

    // Number of operands parsed inside this brace group.
    seg.push_back(operands.size() - crtOperandsSize);

    if (failed(parser.parseRBrace()))
      return failure();

    if (succeeded(parser.parseOptionalLSquare())) {
      if (parser.parseAttribute(attributes.emplace_back()) ||
          parser.parseRSquare())
        return failure();
    } else {
      attributes.push_back(mlir::acc::DeviceTypeAttr::get(
    }
  } while (succeeded(parser.parseOptionalComma()));

  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);

  return success();
}
1419 
    mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::DenseI32ArrayAttr> segments) {
  // Prints one `{%a : t, ...}` group per device_type entry, sized by
  // `segments`, each followed by its device type unless it is None. This is
  // the format read by the segment-aware device-type operand parser. `types`
  // is unused; each operand supplies its own type.
  // NOTE(review): `deviceTypes`/`segments` are dereferenced without checks,
  // and the first signature lines (name, `p`, `operands`) are missing from
  // this listing.
  unsigned opIdx = 0;
  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
    p << "{";
    llvm::interleaveComma(
        llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
          p << operands[opIdx] << " : " << operands[opIdx].getType();
          ++opIdx;
        });
    p << "}";
    printSingleDeviceType(p, it.value());
  });
}
1436 
/// Parses a wait clause in one of these forms:
///  - bare keyword (no `(`): `keywordOnly` receives a single default device
///    type entry;
///  - `( [dt, ...] , {devnum: %d : t, %a : t, ...} [dt], {...}, ... )`: an
///    optional bracketed list of keyword-only device types, then brace groups
///    of operands. A group may start with `devnum:` and may be followed by a
///    device type.
/// Each group adds one segment size to `segments`, one BoolAttr to `hasDevNum`
/// (whether `devnum:` was present) and one device type to `deviceTypes`.
///
/// NOTE(review): after a bracketed keyword-only list a comma is mandatory, so
/// `([dt])` with no operand groups is rejected. printWaitClause appears able
/// to print that form; confirm round-tripping.
/// NOTE(review): this listing is missing the `operands` parameter, the `seg`
/// declaration, the head of the inner parse lambda and the argument lines of
/// the default `DeviceTypeAttr::get` calls.
static ParseResult parseWaitClause(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
    mlir::ArrayAttr &keywordOnly) {
  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;

  bool needCommaBeforeOperands = false;

  // Keyword only
  if (failed(parser.parseOptionalLParen())) {
    keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
    keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
    return success();
  }

  // Parse keyword only attributes
  if (succeeded(parser.parseOptionalLSquare())) {
    if (failed(parser.parseCommaSeparatedList([&]() {
          if (parser.parseAttribute(keywordAttrs.emplace_back()))
            return failure();
          return success();
        })))
      return failure();
    if (parser.parseRSquare())
      return failure();
    needCommaBeforeOperands = true;
  }

  if (needCommaBeforeOperands && failed(parser.parseComma()))
    return failure();

  do {
    if (failed(parser.parseLBrace()))
      return failure();

    int32_t crtOperandsSize = operands.size();

    // Optional `devnum:` prefix marks the group as carrying a devnum value.
    if (succeeded(parser.parseOptionalKeyword("devnum"))) {
      if (failed(parser.parseColon()))
        return failure();
      devnum.push_back(BoolAttr::get(parser.getContext(), true));
    } else {
      devnum.push_back(BoolAttr::get(parser.getContext(), false));
    }

    if (failed(parser.parseCommaSeparatedList(
            if (parser.parseOperand(operands.emplace_back()) ||
                parser.parseColonType(types.emplace_back()))
              return failure();
            return success();
          })))
      return failure();

    // Number of operands parsed inside this brace group.
    seg.push_back(operands.size() - crtOperandsSize);

    if (failed(parser.parseRBrace()))
      return failure();

    if (succeeded(parser.parseOptionalLSquare())) {
      if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
          parser.parseRSquare())
        return failure();
    } else {
      deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
    }
  } while (succeeded(parser.parseOptionalComma()));

  if (failed(parser.parseRParen()))
    return failure();

  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
  hasDevNum = ArrayAttr::get(parser.getContext(), devnum);

  return success();
}
1520 
1521 static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
1522  if (!hasDeviceTypeValues(attrs))
1523  return false;
1524  if (attrs->size() != 1)
1525  return false;
1526  if (auto deviceTypeAttr =
1527  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1528  return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
1529  return false;
1530 }
1531 
    mlir::OperandRange operands, mlir::TypeRange types,
    std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::DenseI32ArrayAttr> segments,
    std::optional<mlir::ArrayAttr> hasDevNum,
    std::optional<mlir::ArrayAttr> keywordOnly) {
  // Prints the wait clause in the form read by parseWaitClause. With no
  // operands and only the default keyword-only entry, nothing is printed (the
  // bare keyword). Otherwise it prints `(`, the keyword-only device types, a
  // separating `, ` when both lists have values, then one
  // `{[devnum: ]%a : t, ...}` group per device type, and `)`.
  // NOTE(review): with non-default keyword-only types and no operand groups,
  // this prints the keyword-only list directly followed by `)`, which
  // parseWaitClause rejects (it requires a comma after the bracketed list);
  // confirm round-tripping. The signature line with the function name and `p`
  // is missing from this listing.

  if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
    return;

  p << "(";

  printDeviceTypes(p, keywordOnly);
  if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
    p << ", ";

  if (hasDeviceTypeValues(deviceTypes)) {
    unsigned opIdx = 0;
    llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
      p << "{";
      auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
      if (boolAttr && boolAttr.getValue())
        p << "devnum: ";
      llvm::interleaveComma(
          llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
            p << operands[opIdx] << " : " << operands[opIdx].getType();
            ++opIdx;
          });
      p << "}";
      printSingleDeviceType(p, it.value());
    });
  }

  p << ")";
}
1567 
/// Parses a comma-separated list of `%a : t [#acc.device_type<...>]` entries.
/// An entry without `[...]` is tagged with DeviceType::None. The per-entry
/// device types are collected into `deviceTypes`.
///
/// NOTE(review): the `operands` parameter and the declaration of the local
/// `attributes` vector are missing from this listing; confirm against the
/// original source.
static ParseResult parseDeviceTypeOperands(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseOperand(operands.emplace_back()) ||
            parser.parseColonType(types.emplace_back()))
          return failure();
        if (succeeded(parser.parseOptionalLSquare())) {
          if (parser.parseAttribute(attributes.emplace_back()) ||
              parser.parseRSquare())
            return failure();
        } else {
          attributes.push_back(mlir::acc::DeviceTypeAttr::get(
              parser.getContext(), mlir::acc::DeviceType::None));
        }
        return success();
      })))
    return failure();
  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
  return success();
}
1593 
/// Prints `%a : t` for each operand, paired positionally with `deviceTypes`,
/// each followed by its device type unless that type is None. Prints nothing
/// when `deviceTypes` holds no values. `types` is unused; each operand
/// supplies its own type.
///
/// NOTE(review): the line with the function name and the leading parameters
/// (`p`, `op`) is missing from this listing; confirm against the original.
static void
    mlir::OperandRange operands, mlir::TypeRange types,
    std::optional<mlir::ArrayAttr> deviceTypes) {
  if (!hasDeviceTypeValues(deviceTypes))
    return;
  llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
    p << std::get<1>(it) << " : " << std::get<1>(it).getType();
    printSingleDeviceType(p, std::get<0>(it));
  });
}
1605 
1607  mlir::OpAsmParser &parser,
1609  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1610  mlir::ArrayAttr &keywordOnlyDeviceType) {
1611 
1612  llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
1613  bool needCommaBeforeOperands = false;
1614 
1615  if (failed(parser.parseOptionalLParen())) {
1616  // Keyword only
1617  keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1619  keywordOnlyDeviceType =
1620  ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
1621  return success();
1622  }
1623 
1624  // Parse keyword only attributes
1625  if (succeeded(parser.parseOptionalLSquare())) {
1626  // Parse keyword only attributes
1627  if (failed(parser.parseCommaSeparatedList([&]() {
1628  if (parser.parseAttribute(
1629  keywordOnlyDeviceTypeAttributes.emplace_back()))
1630  return failure();
1631  return success();
1632  })))
1633  return failure();
1634  if (parser.parseRSquare())
1635  return failure();
1636  needCommaBeforeOperands = true;
1637  }
1638 
1639  if (needCommaBeforeOperands && failed(parser.parseComma()))
1640  return failure();
1641 
1643  if (failed(parser.parseCommaSeparatedList([&]() {
1644  if (parser.parseOperand(operands.emplace_back()) ||
1645  parser.parseColonType(types.emplace_back()))
1646  return failure();
1647  if (succeeded(parser.parseOptionalLSquare())) {
1648  if (parser.parseAttribute(attributes.emplace_back()) ||
1649  parser.parseRSquare())
1650  return failure();
1651  } else {
1652  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1653  parser.getContext(), mlir::acc::DeviceType::None));
1654  }
1655  return success();
1656  })))
1657  return failure();
1658 
1659  if (failed(parser.parseRParen()))
1660  return failure();
1661 
1662  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1663  attributes.end());
1664  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1665  return success();
1666 }
1667 
    mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
  // Prints nothing (the bare keyword) when there are no operands and the
  // keyword-only list holds just DeviceType::None. Otherwise prints `(`, the
  // keyword-only device types, `, ` when both lists have values, the operands
  // via printDeviceTypeOperands, and `)`.
  // NOTE(review): with non-default keyword-only types and no operands, this
  // prints the keyword-only list directly followed by `)`. The matching
  // parser requires a comma after a bracketed keyword-only list; confirm
  // round-tripping. The first signature lines (name, `p`, `op`, `operands`)
  // are missing from this listing.

  if (operands.begin() == operands.end() &&
      hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
    return;
  }

  p << "(";
  printDeviceTypes(p, keywordOnlyDeviceTypes);
  if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
      hasDeviceTypeValues(deviceTypes))
    p << ", ";
  printDeviceTypeOperands(p, op, operands, types, deviceTypes);
  p << ")";
}
1686 
/// Parses either a bare keyword, which sets `attr` to a UnitAttr and leaves
/// `operand` empty, or `(%v : type)`, which fills `operand`/`operandType`
/// and leaves `attr` unset.
///
/// NOTE(review): the declaration of the local `op` operand is missing from
/// this listing; confirm against the original source.
static ParseResult parseOperandWithKeywordOnly(
    mlir::OpAsmParser &parser,
    std::optional<OpAsmParser::UnresolvedOperand> &operand,
    mlir::Type &operandType, mlir::UnitAttr &attr) {
  // Keyword only
  if (failed(parser.parseOptionalLParen())) {
    attr = mlir::UnitAttr::get(parser.getContext());
    return success();
  }

  if (failed(parser.parseOperand(op)))
    return failure();
  operand = op;
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseType(operandType)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();

  return success();
}
1710 
    mlir::Operation *op,
    std::optional<mlir::Value> operand,
    mlir::Type operandType,
    mlir::UnitAttr attr) {
  // Keyword-only form (`attr` set): print nothing. Otherwise print
  // `(%v : type)`.
  // NOTE(review): `*operand` assumes an operand is present whenever `attr`
  // is unset. The signature line with the function name and `p` is missing
  // from this listing.
  if (attr)
    return;

  p << "(";
  p.printOperand(*operand);
  p << " : ";
  p.printType(operandType);
  p << ")";
}
1725 
/// Parses either a bare keyword, which sets `attr` to a UnitAttr, or
/// `(%a, %b, ... : t1, t2, ...)`, which collects operands and types as two
/// separate lists. This function does not check that the two lists have the
/// same length.
///
/// NOTE(review): the `operands` parameter is missing from this listing;
/// confirm against the original source.
static ParseResult parseOperandsWithKeywordOnly(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
  // Keyword only
  if (failed(parser.parseOptionalLParen())) {
    attr = mlir::UnitAttr::get(parser.getContext());
    return success();
  }

  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseOperand(operands.emplace_back()))
          return failure();
        return success();
      })))
    return failure();
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseType(types.emplace_back()))
          return failure();
        return success();
      })))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();

  return success();
}
1755 
    mlir::Operation *op,
    mlir::OperandRange operands,
    mlir::TypeRange types,
    mlir::UnitAttr attr) {
  // Keyword-only form (`attr` set): print nothing. Otherwise print
  // `(%a, %b, ... : t1, t2, ...)`, the format read by
  // parseOperandsWithKeywordOnly.
  // NOTE(review): the signature line with the function name and `p` is
  // missing from this listing.
  if (attr)
    return;

  p << "(";
  llvm::interleaveComma(operands, p, [&](auto it) { p << it; });
  p << " : ";
  llvm::interleaveComma(types, p, [&](auto it) { p << it; });
  p << ")";
}
1770 
/// Parses the compute-construct keyword of a combined construct:
/// `kernels`, `parallel` or `serial`, mapped to
/// CombinedConstructsType::KernelsLoop / ParallelLoop / SerialLoop. Any other
/// token produces "expected compute construct name".
///
/// NOTE(review): the line with the function name and `parser` parameter, and
/// the `attr = ...::get(` line of each branch, are missing from this listing;
/// confirm against the original source.
static ParseResult
    mlir::acc::CombinedConstructsTypeAttr &attr) {
  if (succeeded(parser.parseOptionalKeyword("kernels"))) {
        parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
  } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
        parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
  } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
        parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
  } else {
    parser.emitError(parser.getCurrentLocation(),
                     "expected compute construct name");
    return failure();
  }
  return success();
}
1790 
/// Prints the compute-construct keyword (`kernels`, `parallel` or `serial`)
/// for a combined construct. Prints nothing when `attr` is null.
///
/// NOTE(review): the line with the function name and leading parameters is
/// missing from this listing; confirm against the original source.
static void
    mlir::acc::CombinedConstructsTypeAttr attr) {
  if (attr) {
    switch (attr.getValue()) {
    case mlir::acc::CombinedConstructsType::KernelsLoop:
      p << "kernels";
      break;
    case mlir::acc::CombinedConstructsType::ParallelLoop:
      p << "parallel";
      break;
    case mlir::acc::CombinedConstructsType::SerialLoop:
      p << "serial";
      break;
    };
  }
}
1808 
1809 //===----------------------------------------------------------------------===//
1810 // SerialOp
1811 //===----------------------------------------------------------------------===//
1812 
1813 unsigned SerialOp::getNumDataOperands() {
1814  return getReductionOperands().size() + getPrivateOperands().size() +
1815  getFirstprivateOperands().size() + getDataClauseOperands().size();
1816 }
1817 
1818 Value SerialOp::getDataOperand(unsigned i) {
1819  unsigned numOptional = getAsyncOperands().size();
1820  numOptional += getIfCond() ? 1 : 0;
1821  numOptional += getSelfCond() ? 1 : 0;
1822  return getOperand(getWaitOperands().size() + numOptional + i);
1823 }
1824 
1825 bool acc::SerialOp::hasAsyncOnly() {
1826  return hasAsyncOnly(mlir::acc::DeviceType::None);
1827 }
1828 
1829 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1830  return hasDeviceType(getAsyncOnly(), deviceType);
1831 }
1832 
1833 mlir::Value acc::SerialOp::getAsyncValue() {
1834  return getAsyncValue(mlir::acc::DeviceType::None);
1835 }
1836 
1837 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1839  getAsyncOperands(), deviceType);
1840 }
1841 
1842 bool acc::SerialOp::hasWaitOnly() {
1843  return hasWaitOnly(mlir::acc::DeviceType::None);
1844 }
1845 
1846 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1847  return hasDeviceType(getWaitOnly(), deviceType);
1848 }
1849 
1850 mlir::Operation::operand_range SerialOp::getWaitValues() {
1851  return getWaitValues(mlir::acc::DeviceType::None);
1852 }
1853 
1855 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1857  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1858  getHasWaitDevnum(), deviceType);
1859 }
1860 
1861 mlir::Value SerialOp::getWaitDevnum() {
1862  return getWaitDevnum(mlir::acc::DeviceType::None);
1863 }
1864 
1865 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1866  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1867  getWaitOperandsSegments(), getHasWaitDevnum(),
1868  deviceType);
1869 }
1870 
1871 LogicalResult acc::SerialOp::verify() {
1872  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1873  *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
1874  "privatizations", /*checkOperandType=*/false)))
1875  return failure();
1876  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1877  *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
1878  "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
1879  return failure();
1880  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1881  *this, getReductionRecipes(), getReductionOperands(), "reduction",
1882  "reductions", false)))
1883  return failure();
1884 
1886  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1887  getWaitOperandsDeviceTypeAttr(), "wait")))
1888  return failure();
1889 
1890  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1891  getAsyncOperandsDeviceTypeAttr(),
1892  "async")))
1893  return failure();
1894 
1895  if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
1896  return failure();
1897 
1898  return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
1899 }
1900 
1901 void acc::SerialOp::addAsyncOnly(
1902  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1903  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1904  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1905 }
1906 
1907 void acc::SerialOp::addAsyncOperand(
1908  MLIRContext *context, mlir::Value newValue,
1909  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1910  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1911  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1912  getAsyncOperandsMutable()));
1913 }
1914 
1915 void acc::SerialOp::addWaitOnly(
1916  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1917  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1918  effectiveDeviceTypes));
1919 }
1920 void acc::SerialOp::addWaitOperands(
1921  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
1922  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1923 
1924  llvm::SmallVector<int32_t> segments;
1925  if (getWaitOperandsSegments())
1926  llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1927 
1928  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1929  context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1930  getWaitOperandsMutable(), segments));
1931  setWaitOperandsSegments(segments);
1932 
1934  if (getHasWaitDevnumAttr())
1935  llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1936  hasDevnums.insert(
1937  hasDevnums.end(),
1938  std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
1939  mlir::BoolAttr::get(context, hasDevnum));
1940  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
1941 }
1942 
1943 //===----------------------------------------------------------------------===//
1944 // KernelsOp
1945 //===----------------------------------------------------------------------===//
1946 
1947 unsigned KernelsOp::getNumDataOperands() {
1948  return getDataClauseOperands().size();
1949 }
1950 
1951 Value KernelsOp::getDataOperand(unsigned i) {
1952  unsigned numOptional = getAsyncOperands().size();
1953  numOptional += getWaitOperands().size();
1954  numOptional += getNumGangs().size();
1955  numOptional += getNumWorkers().size();
1956  numOptional += getVectorLength().size();
1957  numOptional += getIfCond() ? 1 : 0;
1958  numOptional += getSelfCond() ? 1 : 0;
1959  return getOperand(numOptional + i);
1960 }
1961 
1962 bool acc::KernelsOp::hasAsyncOnly() {
1963  return hasAsyncOnly(mlir::acc::DeviceType::None);
1964 }
1965 
1966 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1967  return hasDeviceType(getAsyncOnly(), deviceType);
1968 }
1969 
1970 mlir::Value acc::KernelsOp::getAsyncValue() {
1971  return getAsyncValue(mlir::acc::DeviceType::None);
1972 }
1973 
1974 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1976  getAsyncOperands(), deviceType);
1977 }
1978 
1979 mlir::Value acc::KernelsOp::getNumWorkersValue() {
1980  return getNumWorkersValue(mlir::acc::DeviceType::None);
1981 }
1982 
1984 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1985  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1986  deviceType);
1987 }
1988 
1989 mlir::Value acc::KernelsOp::getVectorLengthValue() {
1990  return getVectorLengthValue(mlir::acc::DeviceType::None);
1991 }
1992 
1994 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1995  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1996  getVectorLength(), deviceType);
1997 }
1998 
1999 mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
2000  return getNumGangsValues(mlir::acc::DeviceType::None);
2001 }
2002 
2004 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2005  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2006  getNumGangsSegments(), deviceType);
2007 }
2008 
2009 bool acc::KernelsOp::hasWaitOnly() {
2010  return hasWaitOnly(mlir::acc::DeviceType::None);
2011 }
2012 
2013 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2014  return hasDeviceType(getWaitOnly(), deviceType);
2015 }
2016 
2017 mlir::Operation::operand_range KernelsOp::getWaitValues() {
2018  return getWaitValues(mlir::acc::DeviceType::None);
2019 }
2020 
2022 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2024  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2025  getHasWaitDevnum(), deviceType);
2026 }
2027 
2028 mlir::Value KernelsOp::getWaitDevnum() {
2029  return getWaitDevnum(mlir::acc::DeviceType::None);
2030 }
2031 
2032 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2033  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2034  getWaitOperandsSegments(), getHasWaitDevnum(),
2035  deviceType);
2036 }
2037 
2038 LogicalResult acc::KernelsOp::verify() {
2040  *this, getNumGangs(), getNumGangsSegmentsAttr(),
2041  getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
2042  return failure();
2043 
2045  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2046  getWaitOperandsDeviceTypeAttr(), "wait")))
2047  return failure();
2048 
2049  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
2050  getNumWorkersDeviceTypeAttr(),
2051  "num_workers")))
2052  return failure();
2053 
2054  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
2055  getVectorLengthDeviceTypeAttr(),
2056  "vector_length")))
2057  return failure();
2058 
2059  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
2060  getAsyncOperandsDeviceTypeAttr(),
2061  "async")))
2062  return failure();
2063 
2064  if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
2065  return failure();
2066 
2067  return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
2068 }
2069 
2070 void acc::KernelsOp::addNumWorkersOperand(
2071  MLIRContext *context, mlir::Value newValue,
2072  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2073  setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2074  context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2075  getNumWorkersMutable()));
2076 }
2077 
2078 void acc::KernelsOp::addVectorLengthOperand(
2079  MLIRContext *context, mlir::Value newValue,
2080  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2081  setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2082  context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2083  getVectorLengthMutable()));
2084 }
2085 void acc::KernelsOp::addAsyncOnly(
2086  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2087  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2088  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2089 }
2090 
2091 void acc::KernelsOp::addAsyncOperand(
2092  MLIRContext *context, mlir::Value newValue,
2093  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2094  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2095  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2096  getAsyncOperandsMutable()));
2097 }
2098 
2099 void acc::KernelsOp::addNumGangsOperands(
2100  MLIRContext *context, mlir::ValueRange newValues,
2101  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2102  llvm::SmallVector<int32_t> segments;
2103  if (getNumGangsSegmentsAttr())
2104  llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2105 
2106  setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2107  context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2108  getNumGangsMutable(), segments));
2109 
2110  setNumGangsSegments(segments);
2111 }
2112 
2113 void acc::KernelsOp::addWaitOnly(
2114  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2115  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2116  effectiveDeviceTypes));
2117 }
2118 void acc::KernelsOp::addWaitOperands(
2119  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2120  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2121 
2122  llvm::SmallVector<int32_t> segments;
2123  if (getWaitOperandsSegments())
2124  llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2125 
2126  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2127  context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2128  getWaitOperandsMutable(), segments));
2129  setWaitOperandsSegments(segments);
2130 
2132  if (getHasWaitDevnumAttr())
2133  llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2134  hasDevnums.insert(
2135  hasDevnums.end(),
2136  std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2137  mlir::BoolAttr::get(context, hasDevnum));
2138  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2139 }
2140 
2141 //===----------------------------------------------------------------------===//
2142 // HostDataOp
2143 //===----------------------------------------------------------------------===//
2144 
2145 LogicalResult acc::HostDataOp::verify() {
2146  if (getDataClauseOperands().empty())
2147  return emitError("at least one operand must appear on the host_data "
2148  "operation");
2149 
2150  for (mlir::Value operand : getDataClauseOperands())
2151  if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2152  return emitError("expect data entry operation as defining op");
2153  return success();
2154 }
2155 
2156 void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2157  MLIRContext *context) {
2158  results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2159 }
2160 
2161 //===----------------------------------------------------------------------===//
2162 // LoopOp
2163 //===----------------------------------------------------------------------===//
2164 
2165 static ParseResult parseGangValue(
2166  OpAsmParser &parser, llvm::StringRef keyword,
2169  llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
2170  bool &needCommaBetweenValues, bool &newValue) {
2171  if (succeeded(parser.parseOptionalKeyword(keyword))) {
2172  if (parser.parseEqual())
2173  return failure();
2174  if (parser.parseOperand(operands.emplace_back()) ||
2175  parser.parseColonType(types.emplace_back()))
2176  return failure();
2177  attributes.push_back(gangArgType);
2178  needCommaBetweenValues = true;
2179  newValue = true;
2180  }
2181  return success();
2182 }
2183 
2184 static ParseResult parseGangClause(
2185  OpAsmParser &parser,
2187  llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
2188  mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
2189  mlir::ArrayAttr &gangOnlyDeviceType) {
2190  llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
2191  llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
2192  llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
2194  bool needCommaBetweenValues = false;
2195  bool needCommaBeforeOperands = false;
2196 
2197  if (failed(parser.parseOptionalLParen())) {
2198  // Gang only keyword
2199  gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2201  gangOnlyDeviceType =
2202  ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
2203  return success();
2204  }
2205 
2206  // Parse gang only attributes
2207  if (succeeded(parser.parseOptionalLSquare())) {
2208  // Parse gang only attributes
2209  if (failed(parser.parseCommaSeparatedList([&]() {
2210  if (parser.parseAttribute(
2211  gangOnlyDeviceTypeAttributes.emplace_back()))
2212  return failure();
2213  return success();
2214  })))
2215  return failure();
2216  if (parser.parseRSquare())
2217  return failure();
2218  needCommaBeforeOperands = true;
2219  }
2220 
2221  auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2222  mlir::acc::GangArgType::Num);
2223  auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2224  mlir::acc::GangArgType::Dim);
2225  auto argStatic = mlir::acc::GangArgTypeAttr::get(
2226  parser.getContext(), mlir::acc::GangArgType::Static);
2227 
2228  do {
2229  if (needCommaBeforeOperands) {
2230  needCommaBeforeOperands = false;
2231  continue;
2232  }
2233 
2234  if (failed(parser.parseLBrace()))
2235  return failure();
2236 
2237  int32_t crtOperandsSize = gangOperands.size();
2238  while (true) {
2239  bool newValue = false;
2240  bool needValue = false;
2241  if (needCommaBetweenValues) {
2242  if (succeeded(parser.parseOptionalComma()))
2243  needValue = true; // expect a new value after comma.
2244  else
2245  break;
2246  }
2247 
2248  if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
2249  gangOperands, gangOperandsType,
2250  gangArgTypeAttributes, argNum,
2251  needCommaBetweenValues, newValue)))
2252  return failure();
2253  if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
2254  gangOperands, gangOperandsType,
2255  gangArgTypeAttributes, argDim,
2256  needCommaBetweenValues, newValue)))
2257  return failure();
2258  if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
2259  gangOperands, gangOperandsType,
2260  gangArgTypeAttributes, argStatic,
2261  needCommaBetweenValues, newValue)))
2262  return failure();
2263 
2264  if (!newValue && needValue) {
2265  parser.emitError(parser.getCurrentLocation(),
2266  "new value expected after comma");
2267  return failure();
2268  }
2269 
2270  if (!newValue)
2271  break;
2272  }
2273 
2274  if (gangOperands.empty())
2275  return parser.emitError(
2276  parser.getCurrentLocation(),
2277  "expect at least one of num, dim or static values");
2278 
2279  if (failed(parser.parseRBrace()))
2280  return failure();
2281 
2282  if (succeeded(parser.parseOptionalLSquare())) {
2283  if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
2284  parser.parseRSquare())
2285  return failure();
2286  } else {
2287  deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2289  }
2290 
2291  seg.push_back(gangOperands.size() - crtOperandsSize);
2292 
2293  } while (succeeded(parser.parseOptionalComma()));
2294 
2295  if (failed(parser.parseRParen()))
2296  return failure();
2297 
2298  llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
2299  gangArgTypeAttributes.end());
2300  gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
2301  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
2302 
2304  gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
2305  gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
2306 
2307  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2308  return success();
2309 }
2310 
2312  mlir::OperandRange operands, mlir::TypeRange types,
2313  std::optional<mlir::ArrayAttr> gangArgTypes,
2314  std::optional<mlir::ArrayAttr> deviceTypes,
2315  std::optional<mlir::DenseI32ArrayAttr> segments,
2316  std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
2317 
2318  if (operands.begin() == operands.end() &&
2319  hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
2320  return;
2321  }
2322 
2323  p << "(";
2324 
2325  printDeviceTypes(p, gangOnlyDeviceTypes);
2326 
2327  if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
2328  hasDeviceTypeValues(deviceTypes))
2329  p << ", ";
2330 
2331  if (hasDeviceTypeValues(deviceTypes)) {
2332  unsigned opIdx = 0;
2333  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2334  p << "{";
2335  llvm::interleaveComma(
2336  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2337  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2338  (*gangArgTypes)[opIdx]);
2339  if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
2340  p << LoopOp::getGangNumKeyword();
2341  else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
2342  p << LoopOp::getGangDimKeyword();
2343  else if (gangArgTypeAttr.getValue() ==
2344  mlir::acc::GangArgType::Static)
2345  p << LoopOp::getGangStaticKeyword();
2346  p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
2347  ++opIdx;
2348  });
2349  p << "}";
2350  printSingleDeviceType(p, it.value());
2351  });
2352  }
2353  p << ")";
2354 }
2355 
2357  std::optional<mlir::ArrayAttr> segments,
2358  llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2359  if (!segments)
2360  return false;
2361  for (auto attr : *segments) {
2362  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2363  if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2364  return true;
2365  }
2366  return false;
2367 }
2368 
2369 /// Check for duplicates in the DeviceType array attribute.
2370 LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
2371  llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2372  if (!deviceTypes)
2373  return success();
2374  for (auto attr : deviceTypes) {
2375  auto deviceTypeAttr =
2376  mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
2377  if (!deviceTypeAttr)
2378  return failure();
2379  if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
2380  return failure();
2381  }
2382  return success();
2383 }
2384 
2385 LogicalResult acc::LoopOp::verify() {
2386  if (getUpperbound().size() != getStep().size())
2387  return emitError() << "number of upperbounds expected to be the same as "
2388  "number of steps";
2389 
2390  if (getUpperbound().size() != getLowerbound().size())
2391  return emitError() << "number of upperbounds expected to be the same as "
2392  "number of lowerbounds";
2393 
2394  if (!getUpperbound().empty() && getInclusiveUpperbound() &&
2395  (getUpperbound().size() != getInclusiveUpperbound()->size()))
2396  return emitError() << "inclusiveUpperbound size is expected to be the same"
2397  << " as upperbound size";
2398 
2399  // Check collapse
2400  if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
2401  return emitOpError() << "collapse device_type attr must be define when"
2402  << " collapse attr is present";
2403 
2404  if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
2405  getCollapseAttr().getValue().size() !=
2406  getCollapseDeviceTypeAttr().getValue().size())
2407  return emitOpError() << "collapse attribute count must match collapse"
2408  << " device_type count";
2409  if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
2410  return emitOpError()
2411  << "duplicate device_type found in collapseDeviceType attribute";
2412 
2413  // Check gang
2414  if (!getGangOperands().empty()) {
2415  if (!getGangOperandsArgType())
2416  return emitOpError() << "gangOperandsArgType attribute must be defined"
2417  << " when gang operands are present";
2418 
2419  if (getGangOperands().size() !=
2420  getGangOperandsArgTypeAttr().getValue().size())
2421  return emitOpError() << "gangOperandsArgType attribute count must match"
2422  << " gangOperands count";
2423  }
2424  if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
2425  return emitOpError() << "duplicate device_type found in gang attribute";
2426 
2428  *this, getGangOperands(), getGangOperandsSegmentsAttr(),
2429  getGangOperandsDeviceTypeAttr(), "gang")))
2430  return failure();
2431 
2432  // Check worker
2433  if (failed(checkDeviceTypes(getWorkerAttr())))
2434  return emitOpError() << "duplicate device_type found in worker attribute";
2435  if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
2436  return emitOpError() << "duplicate device_type found in "
2437  "workerNumOperandsDeviceType attribute";
2438  if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
2439  getWorkerNumOperandsDeviceTypeAttr(),
2440  "worker")))
2441  return failure();
2442 
2443  // Check vector
2444  if (failed(checkDeviceTypes(getVectorAttr())))
2445  return emitOpError() << "duplicate device_type found in vector attribute";
2446  if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
2447  return emitOpError() << "duplicate device_type found in "
2448  "vectorOperandsDeviceType attribute";
2449  if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
2450  getVectorOperandsDeviceTypeAttr(),
2451  "vector")))
2452  return failure();
2453 
2455  *this, getTileOperands(), getTileOperandsSegmentsAttr(),
2456  getTileOperandsDeviceTypeAttr(), "tile")))
2457  return failure();
2458 
2459  // auto, independent and seq attribute are mutually exclusive.
2460  llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
2461  if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
2462  hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
2463  hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
2464  return emitError() << "only one of auto, independent, seq can be present "
2465  "at the same time";
2466  }
2467 
2468  // Check that at least one of auto, independent, or seq is present
2469  // for the device-independent default clauses.
2470  auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
2471  return attr.getValue() == mlir::acc::DeviceType::None;
2472  };
2473  bool hasDefaultSeq =
2474  getSeqAttr()
2475  ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2476  hasDeviceNone)
2477  : false;
2478  bool hasDefaultIndependent =
2479  getIndependentAttr()
2480  ? llvm::any_of(
2481  getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2482  hasDeviceNone)
2483  : false;
2484  bool hasDefaultAuto =
2485  getAuto_Attr()
2486  ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2487  hasDeviceNone)
2488  : false;
2489  if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
2490  return emitError()
2491  << "at least one of auto, independent, seq must be present";
2492  }
2493 
2494  // Gang, worker and vector are incompatible with seq.
2495  if (getSeqAttr()) {
2496  for (auto attr : getSeqAttr()) {
2497  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2498  if (hasVector(deviceTypeAttr.getValue()) ||
2499  getVectorValue(deviceTypeAttr.getValue()) ||
2500  hasWorker(deviceTypeAttr.getValue()) ||
2501  getWorkerValue(deviceTypeAttr.getValue()) ||
2502  hasGang(deviceTypeAttr.getValue()) ||
2503  getGangValue(mlir::acc::GangArgType::Num,
2504  deviceTypeAttr.getValue()) ||
2505  getGangValue(mlir::acc::GangArgType::Dim,
2506  deviceTypeAttr.getValue()) ||
2507  getGangValue(mlir::acc::GangArgType::Static,
2508  deviceTypeAttr.getValue()))
2509  return emitError() << "gang, worker or vector cannot appear with seq";
2510  }
2511  }
2512 
2513  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
2514  *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
2515  "privatizations", false)))
2516  return failure();
2517 
2518  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
2519  *this, getReductionRecipes(), getReductionOperands(), "reduction",
2520  "reductions", false)))
2521  return failure();
2522 
2523  if (getCombined().has_value() &&
2524  (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
2525  getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
2526  getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
2527  return emitError("unexpected combined constructs attribute");
2528  }
2529 
2530  // Check non-empty body().
2531  if (getRegion().empty())
2532  return emitError("expected non-empty body.");
2533 
2534  // When it is container-like - it is expected to hold a loop-like operation.
2535  if (isContainerLike()) {
2536  // Obtain the maximum collapse count - we use this to check that there
2537  // are enough loops contained.
2538  uint64_t collapseCount = getCollapseValue().value_or(1);
2539  if (getCollapseAttr()) {
2540  for (auto collapseEntry : getCollapseAttr()) {
2541  auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
2542  if (intAttr.getValue().getZExtValue() > collapseCount)
2543  collapseCount = intAttr.getValue().getZExtValue();
2544  }
2545  }
2546 
2547  // We want to check that we find enough loop-like operations inside.
2548  // PreOrder walk allows us to walk in a breadth-first manner at each nesting
2549  // level.
2550  mlir::Operation *expectedParent = this->getOperation();
2551  bool foundSibling = false;
2552  getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
2553  if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
2554  // This effectively checks that we are not looking at a sibling loop.
2555  if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
2556  expectedParent) {
2557  foundSibling = true;
2558  return mlir::WalkResult::interrupt();
2559  }
2560 
2561  collapseCount--;
2562  expectedParent = op;
2563  }
2564  // We found enough contained loops.
2565  if (collapseCount == 0)
2566  return mlir::WalkResult::interrupt();
2567  return mlir::WalkResult::advance();
2568  });
2569 
2570  if (foundSibling)
2571  return emitError("found sibling loops inside container-like acc.loop");
2572  if (collapseCount != 0)
2573  return emitError("failed to find enough loop-like operations inside "
2574  "container-like acc.loop");
2575  }
2576 
2577  return success();
2578 }
2579 
2580 unsigned LoopOp::getNumDataOperands() {
2581  return getReductionOperands().size() + getPrivateOperands().size();
2582 }
2583 
2584 Value LoopOp::getDataOperand(unsigned i) {
2585  unsigned numOptional =
2586  getLowerbound().size() + getUpperbound().size() + getStep().size();
2587  numOptional += getGangOperands().size();
2588  numOptional += getVectorOperands().size();
2589  numOptional += getWorkerNumOperands().size();
2590  numOptional += getTileOperands().size();
2591  numOptional += getCacheOperands().size();
2592  return getOperand(numOptional + i);
2593 }
2594 
2595 bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
2596 
2597 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
2598  return hasDeviceType(getAuto_(), deviceType);
2599 }
2600 
2601 bool LoopOp::hasIndependent() {
2602  return hasIndependent(mlir::acc::DeviceType::None);
2603 }
2604 
2605 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
2606  return hasDeviceType(getIndependent(), deviceType);
2607 }
2608 
2609 bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
2610 
2611 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
2612  return hasDeviceType(getSeq(), deviceType);
2613 }
2614 
2615 mlir::Value LoopOp::getVectorValue() {
2616  return getVectorValue(mlir::acc::DeviceType::None);
2617 }
2618 
2619 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
2620  return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
2621  getVectorOperands(), deviceType);
2622 }
2623 
2624 bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
2625 
2626 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
2627  return hasDeviceType(getVector(), deviceType);
2628 }
2629 
2630 mlir::Value LoopOp::getWorkerValue() {
2631  return getWorkerValue(mlir::acc::DeviceType::None);
2632 }
2633 
2634 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
2635  return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
2636  getWorkerNumOperands(), deviceType);
2637 }
2638 
2639 bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2640 
2641 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2642  return hasDeviceType(getWorker(), deviceType);
2643 }
2644 
2645 mlir::Operation::operand_range LoopOp::getTileValues() {
2646  return getTileValues(mlir::acc::DeviceType::None);
2647 }
2648 
2650 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
2651  return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
2652  getTileOperandsSegments(), deviceType);
2653 }
2654 
2655 std::optional<int64_t> LoopOp::getCollapseValue() {
2656  return getCollapseValue(mlir::acc::DeviceType::None);
2657 }
2658 
2659 std::optional<int64_t>
2660 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2661  if (!getCollapseAttr())
2662  return std::nullopt;
2663  if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2664  auto intAttr =
2665  mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2666  return intAttr.getValue().getZExtValue();
2667  }
2668  return std::nullopt;
2669 }
2670 
2671 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2672  return getGangValue(gangArgType, mlir::acc::DeviceType::None);
2673 }
2674 
2675 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2676  mlir::acc::DeviceType deviceType) {
2677  if (getGangOperands().empty())
2678  return {};
2679  if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
2680  int32_t nbOperandsBefore = 0;
2681  for (unsigned i = 0; i < *pos; ++i)
2682  nbOperandsBefore += (*getGangOperandsSegments())[i];
2684  getGangOperands()
2685  .drop_front(nbOperandsBefore)
2686  .take_front((*getGangOperandsSegments())[*pos]);
2687 
2688  int32_t argTypeIdx = nbOperandsBefore;
2689  for (auto value : values) {
2690  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2691  (*getGangOperandsArgType())[argTypeIdx]);
2692  if (gangArgTypeAttr.getValue() == gangArgType)
2693  return value;
2694  ++argTypeIdx;
2695  }
2696  }
2697  return {};
2698 }
2699 
2700 bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2701 
2702 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2703  return hasDeviceType(getGang(), deviceType);
2704 }
2705 
2706 llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
2707  return {&getRegion()};
2708 }
2709 
2710 /// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
2711 /// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
2712 /// `(` ssa-id-and-type-list `)`
2713 /// region
2714 ParseResult
2717  SmallVectorImpl<Type> &lowerboundType,
2719  SmallVectorImpl<Type> &upperboundType,
2721  SmallVectorImpl<Type> &stepType) {
2722 
2723  SmallVector<OpAsmParser::Argument> inductionVars;
2724  if (succeeded(
2725  parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
2726  if (parser.parseLParen() ||
2727  parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
2728  /*allowType=*/true) ||
2729  parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
2730  parser.parseOperandList(lowerbound, inductionVars.size(),
2732  parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
2733  parser.parseKeyword("to") || parser.parseLParen() ||
2734  parser.parseOperandList(upperbound, inductionVars.size(),
2736  parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
2737  parser.parseKeyword("step") || parser.parseLParen() ||
2738  parser.parseOperandList(step, inductionVars.size(),
2740  parser.parseColonTypeList(stepType) || parser.parseRParen())
2741  return failure();
2742  }
2743  return parser.parseRegion(region, inductionVars);
2744 }
2745 
2747  ValueRange lowerbound, TypeRange lowerboundType,
2748  ValueRange upperbound, TypeRange upperboundType,
2749  ValueRange steps, TypeRange stepType) {
2750  ValueRange regionArgs = region.front().getArguments();
2751  if (!regionArgs.empty()) {
2752  p << acc::LoopOp::getControlKeyword() << "(";
2753  llvm::interleaveComma(regionArgs, p,
2754  [&p](Value v) { p << v << " : " << v.getType(); });
2755  p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
2756  << upperbound << " : " << upperboundType << ") " << " step (" << steps
2757  << " : " << stepType << ") ";
2758  }
2759  p.printRegion(region, /*printEntryBlockArgs=*/false);
2760 }
2761 
2762 void acc::LoopOp::addSeq(MLIRContext *context,
2763  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2764  setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
2765  effectiveDeviceTypes));
2766 }
2767 
2768 void acc::LoopOp::addIndependent(
2769  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2770  setIndependentAttr(addDeviceTypeAffectedOperandHelper(
2771  context, getIndependentAttr(), effectiveDeviceTypes));
2772 }
2773 
2774 void acc::LoopOp::addAuto(MLIRContext *context,
2775  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2776  setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
2777  effectiveDeviceTypes));
2778 }
2779 
2780 void acc::LoopOp::setCollapseForDeviceTypes(
2781  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
2782  llvm::APInt value) {
2784  llvm::SmallVector<mlir::Attribute> newDeviceTypes;
2785 
2786  assert((getCollapseAttr() == nullptr) ==
2787  (getCollapseDeviceTypeAttr() == nullptr));
2788  assert(value.getBitWidth() == 64);
2789 
2790  if (getCollapseAttr()) {
2791  for (const auto &existing :
2792  llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
2793  newValues.push_back(std::get<0>(existing));
2794  newDeviceTypes.push_back(std::get<1>(existing));
2795  }
2796  }
2797 
2798  if (effectiveDeviceTypes.empty()) {
2799  // If the effective device-types list is empty, this is before there are any
2800  // being applied by device_type, so this should be added as a 'none'.
2801  newValues.push_back(
2802  mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
2803  newDeviceTypes.push_back(
2805  } else {
2806  for (DeviceType DT : effectiveDeviceTypes) {
2807  newValues.push_back(
2808  mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
2809  newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, DT));
2810  }
2811  }
2812 
2813  setCollapseAttr(ArrayAttr::get(context, newValues));
2814  setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
2815 }
2816 
2817 void acc::LoopOp::setTileForDeviceTypes(
2818  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
2819  ValueRange values) {
2820  llvm::SmallVector<int32_t> segments;
2821  if (getTileOperandsSegments())
2822  llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
2823 
2824  setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2825  context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
2826  getTileOperandsMutable(), segments));
2827 
2828  setTileOperandsSegments(segments);
2829 }
2830 
2831 void acc::LoopOp::addVectorOperand(
2832  MLIRContext *context, mlir::Value newValue,
2833  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2834  setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2835  context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
2836  newValue, getVectorOperandsMutable()));
2837 }
2838 
2839 void acc::LoopOp::addEmptyVector(
2840  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2841  setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
2842  effectiveDeviceTypes));
2843 }
2844 
2845 void acc::LoopOp::addWorkerNumOperand(
2846  MLIRContext *context, mlir::Value newValue,
2847  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2848  setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2849  context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
2850  newValue, getWorkerNumOperandsMutable()));
2851 }
2852 
2853 void acc::LoopOp::addEmptyWorker(
2854  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2855  setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
2856  effectiveDeviceTypes));
2857 }
2858 
2859 void acc::LoopOp::addEmptyGang(
2860  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2861  setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
2862  effectiveDeviceTypes));
2863 }
2864 
2865 bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
2866  auto hasDevice = [=](DeviceTypeAttr attr) -> bool {
2867  return attr.getValue() == dt;
2868  };
2869  auto testFromArr = [=](ArrayAttr arr) -> bool {
2870  return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
2871  };
2872 
2873  if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
2874  return true;
2875  if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
2876  return true;
2877  if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
2878  return true;
2879 
2880  return false;
2881 }
2882 
2883 bool acc::LoopOp::hasDefaultGangWorkerVector() {
2884  return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
2885  hasGang() || getGangValue(GangArgType::Num) ||
2886  getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
2887 }
2888 
2889 void acc::LoopOp::addGangOperands(
2890  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
2891  llvm::ArrayRef<GangArgType> argTypes, mlir::ValueRange values) {
2892  llvm::SmallVector<int32_t> segments;
2893  if (std::optional<ArrayRef<int32_t>> existingSegments =
2894  getGangOperandsSegments())
2895  llvm::copy(*existingSegments, std::back_inserter(segments));
2896 
2897  unsigned beforeCount = segments.size();
2898 
2899  setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2900  context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
2901  getGangOperandsMutable(), segments));
2902 
2903  setGangOperandsSegments(segments);
2904 
2905  // This is a bit of extra work to make sure we update the 'types' correctly by
2906  // adding to the types collection the correct number of times. We could
2907  // potentially add something similar to the
2908  // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
2909  // excessive for a one-off case.
2910  unsigned numAdded = segments.size() - beforeCount;
2911 
2912  if (numAdded > 0) {
2914  if (getGangOperandsArgTypeAttr())
2915  llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
2916 
2917  for (auto i : llvm::index_range(0u, numAdded)) {
2918  llvm::transform(argTypes, std::back_inserter(gangTypes),
2919  [=](mlir::acc::GangArgType gangTy) {
2920  return mlir::acc::GangArgTypeAttr::get(context, gangTy);
2921  });
2922  (void)i;
2923  }
2924 
2925  setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
2926  }
2927 }
2928 
2929 //===----------------------------------------------------------------------===//
2930 // DataOp
2931 //===----------------------------------------------------------------------===//
2932 
2933 LogicalResult acc::DataOp::verify() {
2934  // 2.6.5. Data Construct restriction
2935  // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
2936  // attach, or default clause must appear on a data construct.
2937  if (getOperands().empty() && !getDefaultAttr())
2938  return emitError("at least one operand or the default attribute "
2939  "must appear on the data operation");
2940 
2941  for (mlir::Value operand : getDataClauseOperands())
2942  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2943  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2944  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2945  operand.getDefiningOp()))
2946  return emitError("expect data entry/exit operation or acc.getdeviceptr "
2947  "as defining op");
2948 
2949  if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this)))
2950  return failure();
2951 
2952  return success();
2953 }
2954 
2955 unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
2956 
2957 Value DataOp::getDataOperand(unsigned i) {
2958  unsigned numOptional = getIfCond() ? 1 : 0;
2959  numOptional += getAsyncOperands().size() ? 1 : 0;
2960  numOptional += getWaitOperands().size();
2961  return getOperand(numOptional + i);
2962 }
2963 
2964 bool acc::DataOp::hasAsyncOnly() {
2965  return hasAsyncOnly(mlir::acc::DeviceType::None);
2966 }
2967 
2968 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2969  return hasDeviceType(getAsyncOnly(), deviceType);
2970 }
2971 
2972 mlir::Value DataOp::getAsyncValue() {
2973  return getAsyncValue(mlir::acc::DeviceType::None);
2974 }
2975 
2976 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2978  getAsyncOperands(), deviceType);
2979 }
2980 
2981 bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
2982 
2983 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2984  return hasDeviceType(getWaitOnly(), deviceType);
2985 }
2986 
2987 mlir::Operation::operand_range DataOp::getWaitValues() {
2988  return getWaitValues(mlir::acc::DeviceType::None);
2989 }
2990 
2992 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2994  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2995  getHasWaitDevnum(), deviceType);
2996 }
2997 
2998 mlir::Value DataOp::getWaitDevnum() {
2999  return getWaitDevnum(mlir::acc::DeviceType::None);
3000 }
3001 
3002 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3003  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3004  getWaitOperandsSegments(), getHasWaitDevnum(),
3005  deviceType);
3006 }
3007 
3008 void acc::DataOp::addAsyncOnly(
3009  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3010  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3011  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3012 }
3013 
3014 void acc::DataOp::addAsyncOperand(
3015  MLIRContext *context, mlir::Value newValue,
3016  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3017  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3018  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3019  getAsyncOperandsMutable()));
3020 }
3021 
3022 void acc::DataOp::addWaitOnly(MLIRContext *context,
3023  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3024  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3025  effectiveDeviceTypes));
3026 }
3027 
3028 void acc::DataOp::addWaitOperands(
3029  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3030  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3031 
3032  llvm::SmallVector<int32_t> segments;
3033  if (getWaitOperandsSegments())
3034  llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3035 
3036  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3037  context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3038  getWaitOperandsMutable(), segments));
3039  setWaitOperandsSegments(segments);
3040 
3042  if (getHasWaitDevnumAttr())
3043  llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3044  hasDevnums.insert(
3045  hasDevnums.end(),
3046  std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
3047  mlir::BoolAttr::get(context, hasDevnum));
3048  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3049 }
3050 
3051 //===----------------------------------------------------------------------===//
3052 // ExitDataOp
3053 //===----------------------------------------------------------------------===//
3054 
3055 LogicalResult acc::ExitDataOp::verify() {
3056  // 2.6.6. Data Exit Directive restriction
3057  // At least one copyout, delete, or detach clause must appear on an exit data
3058  // directive.
3059  if (getDataClauseOperands().empty())
3060  return emitError("at least one operand must be present in dataOperands on "
3061  "the exit data operation");
3062 
3063  // The async attribute represent the async clause without value. Therefore the
3064  // attribute and operand cannot appear at the same time.
3065  if (getAsyncOperand() && getAsync())
3066  return emitError("async attribute cannot appear with asyncOperand");
3067 
3068  // The wait attribute represent the wait clause without values. Therefore the
3069  // attribute and operands cannot appear at the same time.
3070  if (!getWaitOperands().empty() && getWait())
3071  return emitError("wait attribute cannot appear with waitOperands");
3072 
3073  if (getWaitDevnum() && getWaitOperands().empty())
3074  return emitError("wait_devnum cannot appear without waitOperands");
3075 
3076  return success();
3077 }
3078 
3079 unsigned ExitDataOp::getNumDataOperands() {
3080  return getDataClauseOperands().size();
3081 }
3082 
3083 Value ExitDataOp::getDataOperand(unsigned i) {
3084  unsigned numOptional = getIfCond() ? 1 : 0;
3085  numOptional += getAsyncOperand() ? 1 : 0;
3086  numOptional += getWaitDevnum() ? 1 : 0;
3087  return getOperand(getWaitOperands().size() + numOptional + i);
3088 }
3089 
3090 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3091  MLIRContext *context) {
3092  results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
3093 }
3094 
3095 //===----------------------------------------------------------------------===//
3096 // EnterDataOp
3097 //===----------------------------------------------------------------------===//
3098 
3099 LogicalResult acc::EnterDataOp::verify() {
3100  // 2.6.6. Data Enter Directive restriction
3101  // At least one copyin, create, or attach clause must appear on an enter data
3102  // directive.
3103  if (getDataClauseOperands().empty())
3104  return emitError("at least one operand must be present in dataOperands on "
3105  "the enter data operation");
3106 
3107  // The async attribute represent the async clause without value. Therefore the
3108  // attribute and operand cannot appear at the same time.
3109  if (getAsyncOperand() && getAsync())
3110  return emitError("async attribute cannot appear with asyncOperand");
3111 
3112  // The wait attribute represent the wait clause without values. Therefore the
3113  // attribute and operands cannot appear at the same time.
3114  if (!getWaitOperands().empty() && getWait())
3115  return emitError("wait attribute cannot appear with waitOperands");
3116 
3117  if (getWaitDevnum() && getWaitOperands().empty())
3118  return emitError("wait_devnum cannot appear without waitOperands");
3119 
3120  for (mlir::Value operand : getDataClauseOperands())
3121  if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3122  operand.getDefiningOp()))
3123  return emitError("expect data entry operation as defining op");
3124 
3125  return success();
3126 }
3127 
3128 unsigned EnterDataOp::getNumDataOperands() {
3129  return getDataClauseOperands().size();
3130 }
3131 
3132 Value EnterDataOp::getDataOperand(unsigned i) {
3133  unsigned numOptional = getIfCond() ? 1 : 0;
3134  numOptional += getAsyncOperand() ? 1 : 0;
3135  numOptional += getWaitDevnum() ? 1 : 0;
3136  return getOperand(getWaitOperands().size() + numOptional + i);
3137 }
3138 
3139 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3140  MLIRContext *context) {
3141  results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
3142 }
3143 
3144 //===----------------------------------------------------------------------===//
3145 // AtomicReadOp
3146 //===----------------------------------------------------------------------===//
3147 
3148 LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
3149 
3150 //===----------------------------------------------------------------------===//
3151 // AtomicWriteOp
3152 //===----------------------------------------------------------------------===//
3153 
3154 LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
3155 
3156 //===----------------------------------------------------------------------===//
3157 // AtomicUpdateOp
3158 //===----------------------------------------------------------------------===//
3159 
3160 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3161  PatternRewriter &rewriter) {
3162  if (op.isNoOp()) {
3163  rewriter.eraseOp(op);
3164  return success();
3165  }
3166 
3167  if (Value writeVal = op.getWriteOpVal()) {
3168  rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
3169  return success();
3170  }
3171 
3172  return failure();
3173 }
3174 
3175 LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
3176 
3177 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3178 
3179 //===----------------------------------------------------------------------===//
3180 // AtomicCaptureOp
3181 //===----------------------------------------------------------------------===//
3182 
3183 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3184  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3185  return op;
3186  return dyn_cast<AtomicReadOp>(getSecondOp());
3187 }
3188 
3189 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3190  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3191  return op;
3192  return dyn_cast<AtomicWriteOp>(getSecondOp());
3193 }
3194 
3195 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3196  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3197  return op;
3198  return dyn_cast<AtomicUpdateOp>(getSecondOp());
3199 }
3200 
3201 LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
3202 
3203 //===----------------------------------------------------------------------===//
3204 // DeclareEnterOp
3205 //===----------------------------------------------------------------------===//
3206 
3207 template <typename Op>
3208 static LogicalResult
3210  bool requireAtLeastOneOperand = true) {
3211  if (operands.empty() && requireAtLeastOneOperand)
3212  return emitError(
3213  op->getLoc(),
3214  "at least one operand must appear on the declare operation");
3215 
3216  for (mlir::Value operand : operands) {
3217  if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3218  acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
3219  acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
3220  operand.getDefiningOp()))
3221  return op.emitError(
3222  "expect valid declare data entry operation or acc.getdeviceptr "
3223  "as defining op");
3224 
3225  mlir::Value var{getVar(operand.getDefiningOp())};
3226  assert(var && "declare operands can only be data entry operations which "
3227  "must have var");
3228  (void)var;
3229  std::optional<mlir::acc::DataClause> dataClauseOptional{
3230  getDataClause(operand.getDefiningOp())};
3231  assert(dataClauseOptional.has_value() &&
3232  "declare operands can only be data entry operations which must have "
3233  "dataClause");
3234  (void)dataClauseOptional;
3235  }
3236 
3237  return success();
3238 }
3239 
3240 LogicalResult acc::DeclareEnterOp::verify() {
3241  return checkDeclareOperands(*this, this->getDataClauseOperands());
3242 }
3243 
3244 //===----------------------------------------------------------------------===//
3245 // DeclareExitOp
3246 //===----------------------------------------------------------------------===//
3247 
3248 LogicalResult acc::DeclareExitOp::verify() {
3249  if (getToken())
3250  return checkDeclareOperands(*this, this->getDataClauseOperands(),
3251  /*requireAtLeastOneOperand=*/false);
3252  return checkDeclareOperands(*this, this->getDataClauseOperands());
3253 }
3254 
3255 //===----------------------------------------------------------------------===//
3256 // DeclareOp
3257 //===----------------------------------------------------------------------===//
3258 
3259 LogicalResult acc::DeclareOp::verify() {
3260  return checkDeclareOperands(*this, this->getDataClauseOperands());
3261 }
3262 
3263 //===----------------------------------------------------------------------===//
3264 // RoutineOp
3265 //===----------------------------------------------------------------------===//
3266 
3267 static unsigned getParallelismForDeviceType(acc::RoutineOp op,
3268  acc::DeviceType dtype) {
3269  unsigned parallelism = 0;
3270  parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
3271  parallelism += op.hasWorker(dtype) ? 1 : 0;
3272  parallelism += op.hasVector(dtype) ? 1 : 0;
3273  parallelism += op.hasSeq(dtype) ? 1 : 0;
3274  return parallelism;
3275 }
3276 
3277 LogicalResult acc::RoutineOp::verify() {
3278  unsigned baseParallelism =
3280 
3281  if (baseParallelism > 1)
3282  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
3283  "be present at the same time";
3284 
3285  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
3286  ++dtypeInt) {
3287  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
3288  if (dtype == acc::DeviceType::None)
3289  continue;
3290  unsigned parallelism = getParallelismForDeviceType(*this, dtype);
3291 
3292  if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
3293  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
3294  "be present at the same time";
3295  }
3296 
3297  return success();
3298 }
3299 
3300 static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
3301  mlir::ArrayAttr &deviceTypes) {
3302  llvm::SmallVector<mlir::Attribute> bindNameAttrs;
3303  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
3304 
3305  if (failed(parser.parseCommaSeparatedList([&]() {
3306  if (parser.parseAttribute(bindNameAttrs.emplace_back()))
3307  return failure();
3308  if (failed(parser.parseOptionalLSquare())) {
3309  deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3310  parser.getContext(), mlir::acc::DeviceType::None));
3311  } else {
3312  if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
3313  parser.parseRSquare())
3314  return failure();
3315  }
3316  return success();
3317  })))
3318  return failure();
3319 
3320  bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
3321  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
3322 
3323  return success();
3324 }
3325 
3327  std::optional<mlir::ArrayAttr> bindName,
3328  std::optional<mlir::ArrayAttr> deviceTypes) {
// Prints each bind name followed by its device type, comma-separated; this is
// the inverse of parseBindName. Both optionals are dereferenced unchecked.
// NOTE(review): this presumably relies on the printer being invoked only when
// a bind name is present. It also relies on printSingleDeviceType (defined
// elsewhere) handling the `none` device type that parseBindName records for
// names without brackets; confirm.
3329  llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
3330  [&](const auto &pair) {
3331  p << std::get<0>(pair);
3332  printSingleDeviceType(p, std::get<1>(pair));
3333  });
3334 }
3335 
3336 static ParseResult parseRoutineGangClause(OpAsmParser &parser,
3337  mlir::ArrayAttr &gang,
3338  mlir::ArrayAttr &gangDim,
3339  mlir::ArrayAttr &gangDimDeviceTypes) {
3340 
3341  llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
3342  gangDimDeviceTypeAttrs;
3343  bool needCommaBeforeOperands = false;
3344 
3345  // Gang keyword only
3346  if (failed(parser.parseOptionalLParen())) {
3347  gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3349  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
3350  return success();
3351  }
3352 
3353  // Parse keyword only attributes
3354  if (succeeded(parser.parseOptionalLSquare())) {
3355  if (failed(parser.parseCommaSeparatedList([&]() {
3356  if (parser.parseAttribute(gangAttrs.emplace_back()))
3357  return failure();
3358  return success();
3359  })))
3360  return failure();
3361  if (parser.parseRSquare())
3362  return failure();
3363  needCommaBeforeOperands = true;
3364  }
3365 
3366  if (needCommaBeforeOperands && failed(parser.parseComma()))
3367  return failure();
3368 
3369  if (failed(parser.parseCommaSeparatedList([&]() {
3370  if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
3371  parser.parseColon() ||
3372  parser.parseAttribute(gangDimAttrs.emplace_back()))
3373  return failure();
3374  if (succeeded(parser.parseOptionalLSquare())) {
3375  if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
3376  parser.parseRSquare())
3377  return failure();
3378  } else {
3379  gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3380  parser.getContext(), mlir::acc::DeviceType::None));
3381  }
3382  return success();
3383  })))
3384  return failure();
3385 
3386  if (failed(parser.parseRParen()))
3387  return failure();
3388 
3389  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
3390  gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
3391  gangDimDeviceTypes =
3392  ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
3393 
3394  return success();
3395 }
3396 
3398  std::optional<mlir::ArrayAttr> gang,
3399  std::optional<mlir::ArrayAttr> gangDim,
3400  std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
// Prints the parenthesized part of a routine's `gang` clause: first the
// keyword-only device types (via printDeviceTypes), then the `dim: N` entries
// with their device types, separated by ", " when both are present.
3401 
// A lone keyword-only entry for the `none` device type with no dims prints
// nothing, so only the bare `gang` keyword remains.
// NOTE(review): the dyn_cast result below is dereferenced unchecked; `cast`
// would assert with a clearer message on a non-DeviceTypeAttr entry.
3402  if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
3403  gang->size() == 1) {
3404  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
3405  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
3406  return;
3407  }
3408 
3409  p << "(";
3410 
// NOTE(review): with device types but no dims this prints `(<types>)` with
// no trailing dim list. Confirm that parseRoutineGangClause accepts `)`
// directly after the `]` of that list.
3411  printDeviceTypes(p, gang);
3412 
3413  if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
3414  p << ", ";
3415 
// Each dim entry prints as `dim: <attr>` followed by its device type.
3416  if (hasDeviceTypeValues(gangDimDeviceTypes))
3417  llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
3418  [&](const auto &pair) {
3419  p << acc::RoutineOp::getGangDimKeyword() << ": ";
3420  p << std::get<0>(pair);
3421  printSingleDeviceType(p, std::get<1>(pair));
3422  });
3423 
3424  p << ")";
3425 }
3426 
3427 static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
3428  mlir::ArrayAttr &deviceTypes) {
3430  // Keyword only
3431  if (failed(parser.parseOptionalLParen())) {
3432  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
3434  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
3435  return success();
3436  }
3437 
3438  // Parse device type attributes
3439  if (succeeded(parser.parseOptionalLSquare())) {
3440  if (failed(parser.parseCommaSeparatedList([&]() {
3441  if (parser.parseAttribute(attributes.emplace_back()))
3442  return failure();
3443  return success();
3444  })))
3445  return failure();
3446  if (parser.parseRSquare() || parser.parseRParen())
3447  return failure();
3448  }
3449  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
3450  return success();
3451 }
3452 
// Prints the optional `([dt, ...])` suffix of a keyword clause; this is the
// inverse of parseDeviceTypeArrayAttr. Nothing is printed when the list is
// absent or empty, or when it holds only the `none` device type; in those
// cases the bare keyword suffices.
// NOTE(review): both dyn_cast results are dereferenced/printed unchecked;
// `cast` would assert with a clearer message on a non-DeviceTypeAttr entry.
3453 static void
3455  std::optional<mlir::ArrayAttr> deviceTypes) {
3456 
3457  if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
3458  auto deviceTypeAttr =
3459  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
3460  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
3461  return;
3462  }
3463 
3464  if (!hasDeviceTypeValues(deviceTypes))
3465  return;
3466 
3467  p << "([";
3468  llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
3469  auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3470  p << dTypeAttr;
3471  });
3472  p << "])";
3473 }
3474 
3475 bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
3476 
3477 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
3478  return hasDeviceType(getWorker(), deviceType);
3479 }
3480 
3481 bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
3482 
3483 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
3484  return hasDeviceType(getVector(), deviceType);
3485 }
3486 
3487 bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
3488 
3489 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
3490  return hasDeviceType(getSeq(), deviceType);
3491 }
3492 
3493 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
3494  return getBindNameValue(mlir::acc::DeviceType::None);
3495 }
3496 
3497 std::optional<llvm::StringRef>
3498 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
3499  if (!hasDeviceTypeValues(getBindNameDeviceType()))
3500  return std::nullopt;
3501  if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
3502  auto attr = (*getBindName())[*pos];
3503  auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
3504  return stringAttr.getValue();
3505  }
3506  return std::nullopt;
3507 }
3508 
3509 bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
3510 
3511 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
3512  return hasDeviceType(getGang(), deviceType);
3513 }
3514 
3515 std::optional<int64_t> RoutineOp::getGangDimValue() {
3516  return getGangDimValue(mlir::acc::DeviceType::None);
3517 }
3518 
3519 std::optional<int64_t>
3520 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
3521  if (!hasDeviceTypeValues(getGangDimDeviceType()))
3522  return std::nullopt;
3523  if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
3524  auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
3525  return intAttr.getInt();
3526  }
3527  return std::nullopt;
3528 }
3529 
3530 //===----------------------------------------------------------------------===//
3531 // InitOp
3532 //===----------------------------------------------------------------------===//
3533 
3534 LogicalResult acc::InitOp::verify() {
3535  Operation *currOp = *this;
3536  while ((currOp = currOp->getParentOp()))
3537  if (isComputeOperation(currOp))
3538  return emitOpError("cannot be nested in a compute operation");
3539  return success();
3540 }
3541 
3542 void acc::InitOp::addDeviceType(MLIRContext *context,
3543  mlir::acc::DeviceType deviceType) {
3545  if (getDeviceTypesAttr())
3546  llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
3547 
3548  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
3549  setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
3550 }
3551 
3552 //===----------------------------------------------------------------------===//
3553 // ShutdownOp
3554 //===----------------------------------------------------------------------===//
3555 
3556 LogicalResult acc::ShutdownOp::verify() {
3557  Operation *currOp = *this;
3558  while ((currOp = currOp->getParentOp()))
3559  if (isComputeOperation(currOp))
3560  return emitOpError("cannot be nested in a compute operation");
3561  return success();
3562 }
3563 
3564 void acc::ShutdownOp::addDeviceType(MLIRContext *context,
3565  mlir::acc::DeviceType deviceType) {
3567  if (getDeviceTypesAttr())
3568  llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
3569 
3570  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
3571  setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
3572 }
3573 
3574 //===----------------------------------------------------------------------===//
3575 // SetOp
3576 //===----------------------------------------------------------------------===//
3577 
3578 LogicalResult acc::SetOp::verify() {
3579  Operation *currOp = *this;
3580  while ((currOp = currOp->getParentOp()))
3581  if (isComputeOperation(currOp))
3582  return emitOpError("cannot be nested in a compute operation");
3583  if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
3584  return emitOpError("at least one default_async, device_num, or device_type "
3585  "operand must appear");
3586  return success();
3587 }
3588 
3589 //===----------------------------------------------------------------------===//
3590 // UpdateOp
3591 //===----------------------------------------------------------------------===//
3592 
3593 LogicalResult acc::UpdateOp::verify() {
3594  // At least one of host or device should have a value.
3595  if (getDataClauseOperands().empty())
3596  return emitError("at least one value must be present in dataOperands");
3597 
3598  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
3599  getAsyncOperandsDeviceTypeAttr(),
3600  "async")))
3601  return failure();
3602 
3604  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
3605  getWaitOperandsDeviceTypeAttr(), "wait")))
3606  return failure();
3607 
3608  if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this)))
3609  return failure();
3610 
3611  for (mlir::Value operand : getDataClauseOperands())
3612  if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
3613  operand.getDefiningOp()))
3614  return emitError("expect data entry/exit operation or acc.getdeviceptr "
3615  "as defining op");
3616 
3617  return success();
3618 }
3619 
3620 unsigned UpdateOp::getNumDataOperands() {
3621  return getDataClauseOperands().size();
3622 }
3623 
3624 Value UpdateOp::getDataOperand(unsigned i) {
3625  unsigned numOptional = getAsyncOperands().size();
3626  numOptional += getIfCond() ? 1 : 0;
3627  return getOperand(getWaitOperands().size() + numOptional + i);
3628 }
3629 
3630 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
3631  MLIRContext *context) {
3632  results.add<RemoveConstantIfCondition<UpdateOp>>(context);
3633 }
3634 
3635 bool UpdateOp::hasAsyncOnly() {
3636  return hasAsyncOnly(mlir::acc::DeviceType::None);
3637 }
3638 
3639 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3640  return hasDeviceType(getAsyncOnly(), deviceType);
3641 }
3642 
3643 mlir::Value UpdateOp::getAsyncValue() {
3644  return getAsyncValue(mlir::acc::DeviceType::None);
3645 }
3646 
3647 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3649  return {};
3650 
3651  if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
3652  return getAsyncOperands()[*pos];
3653 
3654  return {};
3655 }
3656 
3657 bool UpdateOp::hasWaitOnly() {
3658  return hasWaitOnly(mlir::acc::DeviceType::None);
3659 }
3660 
3661 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3662  return hasDeviceType(getWaitOnly(), deviceType);
3663 }
3664 
3665 mlir::Operation::operand_range UpdateOp::getWaitValues() {
3666  return getWaitValues(mlir::acc::DeviceType::None);
3667 }
3668 
3670 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3672  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3673  getHasWaitDevnum(), deviceType);
3674 }
3675 
3676 mlir::Value UpdateOp::getWaitDevnum() {
3677  return getWaitDevnum(mlir::acc::DeviceType::None);
3678 }
3679 
3680 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3681  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3682  getWaitOperandsSegments(), getHasWaitDevnum(),
3683  deviceType);
3684 }
3685 
3686 //===----------------------------------------------------------------------===//
3687 // WaitOp
3688 //===----------------------------------------------------------------------===//
3689 
3690 LogicalResult acc::WaitOp::verify() {
3691  // The async attribute represent the async clause without value. Therefore the
3692  // attribute and operand cannot appear at the same time.
3693  if (getAsyncOperand() && getAsync())
3694  return emitError("async attribute cannot appear with asyncOperand");
3695 
3696  if (getWaitDevnum() && getWaitOperands().empty())
3697  return emitError("wait_devnum cannot appear without waitOperands");
3698 
3699  return success();
3700 }
3701 
3702 #define GET_OP_CLASSES
3703 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
3704 
3705 #define GET_ATTRDEF_CLASSES
3706 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
3707 
3708 #define GET_TYPEDEF_CLASSES
3709 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
3710 
3711 //===----------------------------------------------------------------------===//
3712 // acc dialect utilities
3713 //===----------------------------------------------------------------------===//
3714 
3717  auto varPtr{llvm::TypeSwitch<mlir::Operation *,
3719  accDataClauseOp)
3720  .Case<ACC_DATA_ENTRY_OPS>(
3721  [&](auto entry) { return entry.getVarPtr(); })
3722  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3723  [&](auto exit) { return exit.getVarPtr(); })
3724  .Default([&](mlir::Operation *) {
3726  })};
3727  return varPtr;
3728 }
3729 
3731  auto varPtr{
3733  .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
3734  .Default([&](mlir::Operation *) { return mlir::Value(); })};
3735  return varPtr;
3736 }
3737 
3739  auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
3740  .Case<ACC_DATA_ENTRY_OPS>(
3741  [&](auto entry) { return entry.getVarType(); })
3742  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3743  [&](auto exit) { return exit.getVarType(); })
3744  .Default([&](mlir::Operation *) { return mlir::Type(); })};
3745  return varType;
3746 }
3747 
3750  auto accPtr{llvm::TypeSwitch<mlir::Operation *,
3752  accDataClauseOp)
3753  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
3754  [&](auto dataClause) { return dataClause.getAccPtr(); })
3755  .Default([&](mlir::Operation *) {
3757  })};
3758  return accPtr;
3759 }
3760 
3762  auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
3764  [&](auto dataClause) { return dataClause.getAccVar(); })
3765  .Default([&](mlir::Operation *) { return mlir::Value(); })};
3766  return accPtr;
3767 }
3768 
3770  auto varPtrPtr{
3772  .Case<ACC_DATA_ENTRY_OPS>(
3773  [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
3774  .Default([&](mlir::Operation *) { return mlir::Value(); })};
3775  return varPtrPtr;
3776 }
3777 
3782  accDataClauseOp)
3783  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3785  dataClause.getBounds().begin(), dataClause.getBounds().end());
3786  })
3787  .Default([&](mlir::Operation *) {
3789  })};
3790  return bounds;
3791 }
3792 
3796  accDataClauseOp)
3797  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3799  dataClause.getAsyncOperands().begin(),
3800  dataClause.getAsyncOperands().end());
3801  })
3802  .Default([&](mlir::Operation *) {
3804  });
3805 }
3806 
3807 mlir::ArrayAttr
3810  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3811  return dataClause.getAsyncOperandsDeviceTypeAttr();
3812  })
3813  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
3814 }
3815 
3816 mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
3819  [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
3820  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
3821 }
3822 
3823 std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
3824  auto name{
3826  .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
3827  .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
3828  return {};
3829  })};
3830  return name;
3831 }
3832 
3833 std::optional<mlir::acc::DataClause>
3835  auto dataClause{
3837  accDataEntryOp)
3838  .Case<ACC_DATA_ENTRY_OPS>(
3839  [&](auto entry) { return entry.getDataClause(); })
3840  .Default([&](mlir::Operation *) { return std::nullopt; })};
3841  return dataClause;
3842 }
3843 
3845  auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
3846  .Case<ACC_DATA_ENTRY_OPS>(
3847  [&](auto entry) { return entry.getImplicit(); })
3848  .Default([&](mlir::Operation *) { return false; })};
3849  return implicit;
3850 }
3851 
3853  auto dataOperands{
3856  [&](auto entry) { return entry.getDataClauseOperands(); })
3857  .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
3858  return dataOperands;
3859 }
3860 
3863  auto dataOperands{
3866  [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
3867  .Default([&](mlir::Operation *) { return nullptr; })};
3868  return dataOperands;
3869 }
3870 
3872  mlir::Operation *parentOp = region.getParentOp();
3873  while (parentOp) {
3874  if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp)) {
3875  return parentOp;
3876  }
3877  parentOp = parentOp->getParentOp();
3878  }
3879  return nullptr;
3880 }
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:114
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2263
@ None
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition: OpenACC.cpp:3397
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition: OpenACC.cpp:742
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition: OpenACC.cpp:2356
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition: OpenACC.cpp:1044
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition: OpenACC.cpp:2370
static bool isComputeOperation(Operation *op)
Definition: OpenACC.cpp:756
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition: OpenACC.cpp:1521
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition: OpenACC.cpp:370
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:3300
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition: OpenACC.cpp:339
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition: OpenACC.cpp:1532
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition: OpenACC.cpp:1437
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition: OpenACC.cpp:174
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:3454
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition: OpenACC.cpp:2165
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition: OpenACC.cpp:1772
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition: OpenACC.cpp:3209
static LogicalResult checkVarAndAccVar(Op op)
Definition: OpenACC.cpp:317
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
Definition: OpenACC.cpp:1726
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:194
static LogicalResult checkVarAndVarType(Op op)
Definition: OpenACC.cpp:292
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition: OpenACC.cpp:2715
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition: OpenACC.cpp:973
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:1568
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:1126
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition: OpenACC.cpp:348
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:218
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:1307
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition: OpenACC.cpp:324
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition: OpenACC.cpp:2746
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:3427
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition: OpenACC.cpp:3336
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1420
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:1595
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindName, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:3326
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
Definition: OpenACC.cpp:1711
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:1374
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
Definition: OpenACC.cpp:957
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:250
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
Definition: OpenACC.cpp:1687
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition: OpenACC.cpp:411
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition: OpenACC.cpp:2184
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition: OpenACC.cpp:832
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
Definition: OpenACC.cpp:1756
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition: OpenACC.cpp:1351
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:205
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
Definition: OpenACC.cpp:988
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition: OpenACC.cpp:1668
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:180
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition: OpenACC.cpp:2311
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:234
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition: OpenACC.cpp:1606
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition: OpenACC.cpp:382
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition: OpenACC.cpp:270
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition: OpenACC.cpp:1054
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition: OpenACC.cpp:3267
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1357
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition: OpenACC.cpp:1792
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
Definition: OpenACC.cpp:937
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition: OpenACC.h:67
#define ACC_DATA_ENTRY_OPS
Definition: OpenACC.h:44
#define ACC_DATA_EXIT_OPS
Definition: OpenACC.h:52
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
Definition: Builders.h:205
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:828
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:834
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:128
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:810
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:593
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition: OpenACC.cpp:3761
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition: OpenACC.cpp:3730
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition: OpenACC.cpp:3749
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition: OpenACC.cpp:3834
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition: OpenACC.cpp:3862
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition: OpenACC.cpp:3779
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition: OpenACC.cpp:3852
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition: OpenACC.cpp:3823
mlir::Operation * getEnclosingComputeOp(mlir::Region &region)
Used to obtain the enclosing compute construct operation that contains the provided region.
Definition: OpenACC.cpp:3871
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition: OpenACC.cpp:3844
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition: OpenACC.cpp:3794
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition: OpenACC.cpp:3769
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:3816
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
Definition: OpenACC.cpp:3738
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition: OpenACC.cpp:3716
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:3808
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.