MLIR  21.0.0git
OpenACC.cpp
Go to the documentation of this file.
1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
13 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
19 #include "mlir/Support/LLVM.h"
21 #include "llvm/ADT/SmallSet.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/LogicalResult.h"
24 
25 using namespace mlir;
26 using namespace acc;
27 
28 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
31 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
32 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
33 
34 namespace {
35 
36 static bool isScalarLikeType(Type type) {
37  return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
38 }
39 
40 struct MemRefPointerLikeModel
41  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
42  MemRefType> {
43  Type getElementType(Type pointer) const {
44  return cast<MemRefType>(pointer).getElementType();
45  }
46  mlir::acc::VariableTypeCategory
47  getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
48  Type varType) const {
49  if (auto mappableTy = dyn_cast<MappableType>(varType)) {
50  return mappableTy.getTypeCategory(varPtr);
51  }
52  auto memrefTy = cast<MemRefType>(pointer);
53  if (!memrefTy.hasRank()) {
54  // This memref is unranked - aka it could have any rank, including a
55  // rank of 0 which could mean scalar. For now, return uncategorized.
56  return mlir::acc::VariableTypeCategory::uncategorized;
57  }
58 
59  if (memrefTy.getRank() == 0) {
60  if (isScalarLikeType(memrefTy.getElementType())) {
61  return mlir::acc::VariableTypeCategory::scalar;
62  }
63  // Zero-rank non-scalar - need further analysis to determine the type
64  // category. For now, return uncategorized.
65  return mlir::acc::VariableTypeCategory::uncategorized;
66  }
67 
68  // It has a rank - must be an array.
69  assert(memrefTy.getRank() > 0 && "rank expected to be positive");
70  return mlir::acc::VariableTypeCategory::array;
71  }
72 };
73 
74 struct LLVMPointerPointerLikeModel
75  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
76  LLVM::LLVMPointerType> {
77  Type getElementType(Type pointer) const { return Type(); }
78 };
79 
80 /// Helper function for any of the times we need to modify an ArrayAttr based on
81 /// a device type list. Returns a new ArrayAttr with all of the
82 /// existingDeviceTypes, plus the effective new ones (or an added none if the new
83 /// list is empty).
84 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
85  MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
86  llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
  // NOTE(review): source lines 87 and 93 are absent from this listing. Line 87
  // presumably declares the `deviceTypes` vector used below, and line 93 the
  // argument of the push_back that appends the "none" entry. Confirm against
  // the real file.
88  if (existingDeviceTypes)
89  llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
90 
  // An empty new list still contributes one (none) entry.
91  if (newDeviceTypes.empty())
92  deviceTypes.push_back(
94 
  // One DeviceTypeAttr per requested device type, appended in order.
95  for (DeviceType DT : newDeviceTypes)
96  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, DT));
97 
98  return mlir::ArrayAttr::get(context, deviceTypes);
99 }
100 
101 /// Helper function for any of the times we need to add operands that are
102 /// affected by a device type list. Returns a new ArrayAttr with all of the
103 /// existingDeviceTypes, plus the effective new ones (or an added none, if the
104 /// new list is empty). Additionally, adds the arguments to the argCollection
105 /// the correct number of times. This will also update a 'segments' array, even
106 /// if it won't be used.
/// Concretely: `arguments` is appended to `argCollection` once per entry of
/// `newDeviceTypes` (or once if that list is empty). For each copy,
/// `arguments.size()` is pushed onto `segments`.
107 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
108  MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
109  llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
110  mlir::MutableOperandRange argCollection,
111  llvm::SmallVector<int32_t> &segments) {
  // NOTE(review): source lines 112 and 120 are absent from this listing.
  // Presumably they declare `deviceTypes` and supply the "none" attribute
  // pushed below. Confirm against the real file.
113  if (existingDeviceTypes)
114  llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
115 
  // No device types given: record the arguments once under the none entry.
116  if (newDeviceTypes.empty()) {
117  argCollection.append(arguments);
118  segments.push_back(arguments.size());
119  deviceTypes.push_back(
121  }
122 
  // Otherwise duplicate the arguments (and a segment size) per device type.
123  for (DeviceType DT : newDeviceTypes) {
124  argCollection.append(arguments);
125  segments.push_back(arguments.size());
126  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, DT));
127  }
128 
129  return mlir::ArrayAttr::get(context, deviceTypes);
130 }
131 
132 /// Overload for when the 'segments' aren't needed.
/// Forwards to the segment-tracking overload with a scratch `segments` vector
/// whose contents are discarded.
133 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
134  MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
135  llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
136  mlir::MutableOperandRange argCollection) {
  // NOTE(review): source line 137 (presumably the local `segments`
  // declaration) is absent from this listing.
138  return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
139  newDeviceTypes, arguments,
140  argCollection, segments);
141 }
142 } // namespace
143 
144 //===----------------------------------------------------------------------===//
145 // OpenACC operations
146 //===----------------------------------------------------------------------===//
147 
148 void OpenACCDialect::initialize() {
149  addOperations<
150 #define GET_OP_LIST
151 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
152  >();
153  addAttributes<
154 #define GET_ATTRDEF_LIST
155 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
156  >();
157  addTypes<
158 #define GET_TYPEDEF_LIST
159 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
160  >();
161 
162  // By attaching interfaces here, we make the OpenACC dialect dependent on
163  // the other dialects. This is probably better than having dialects like LLVM
164  // and memref be dependent on OpenACC.
165  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
166  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
167  *getContext());
168 }
169 
170 //===----------------------------------------------------------------------===//
171 // device_type support helpers
172 //===----------------------------------------------------------------------===//
173 
174 static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
175  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
176  return true;
177  return false;
178 }
179 
180 static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
181  mlir::acc::DeviceType deviceType) {
182  if (!hasDeviceTypeValues(arrayAttr))
183  return false;
184 
185  for (auto attr : *arrayAttr) {
186  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
187  if (deviceTypeAttr.getValue() == deviceType)
188  return true;
189  }
190 
191  return false;
192 }
193 
195  std::optional<mlir::ArrayAttr> deviceTypes) {
  // Prints `deviceTypes` as a bracketed, comma-separated attribute list, or
  // nothing when there are no device type values.
  // NOTE(review): the signature line (source line 194) is absent from this
  // listing.
196  if (!hasDeviceTypeValues(deviceTypes))
197  return;
198 
199  p << "[";
200  llvm::interleaveComma(*deviceTypes, p,
201  [&](mlir::Attribute attr) { p << attr; });
202  p << "]";
203 }
204 
205 static std::optional<unsigned> findSegment(ArrayAttr segments,
206  mlir::acc::DeviceType deviceType) {
207  unsigned segmentIdx = 0;
208  for (auto attr : segments) {
209  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
210  if (deviceTypeAttr.getValue() == deviceType)
211  return std::make_optional(segmentIdx);
212  ++segmentIdx;
213  }
214  return std::nullopt;
215 }
216 
218 getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
220  std::optional<llvm::ArrayRef<int32_t>> segments,
221  mlir::acc::DeviceType deviceType) {
  // Returns the slice of `range` belonging to `deviceType`: skip the operands
  // of all earlier segments, then take this segment's size. The slice is
  // empty when there is no attribute or no matching entry.
  // NOTE(review): source lines 217 (return type) and 219 (presumably the
  // `range` parameter) are absent from this listing. `segments` is
  // dereferenced without a has-value check whenever a segment is found.
222  if (!arrayAttr)
223  return range.take_front(0);
224  if (auto pos = findSegment(*arrayAttr, deviceType)) {
225  int32_t nbOperandsBefore = 0;
226  for (unsigned i = 0; i < *pos; ++i)
227  nbOperandsBefore += (*segments)[i];
228  return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
229  }
230  return range.take_front(0);
231 }
232 
// Returns the wait devnum operand for `deviceType`, i.e. the first value of
// its wait segment, or a null Value when there is none.
// NOTE(review): `hasWaitDevnum->getValue()[*pos]` tests that the attribute is
// non-null, not that the BoolAttr is true (getWaitValuesWithoutDevnum reads
// BoolAttr::getValue() instead). It also dereferences `hasWaitDevnum` without
// a has-value check and calls `.front()` without checking the segment is
// non-empty. Source line 235 (presumably the `operands` parameter) is absent
// from this listing.
233 static mlir::Value
234 getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
236  std::optional<llvm::ArrayRef<int32_t>> segments,
237  std::optional<mlir::ArrayAttr> hasWaitDevnum,
238  mlir::acc::DeviceType deviceType) {
239  if (!hasDeviceTypeValues(deviceTypeAttr))
240  return {};
241  if (auto pos = findSegment(*deviceTypeAttr, deviceType))
242  if (hasWaitDevnum->getValue()[*pos])
243  return getValuesFromSegments(deviceTypeAttr, operands, segments,
244  deviceType)
245  .front();
246  return {};
247 }
248 
250 getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
252  std::optional<llvm::ArrayRef<int32_t>> segments,
253  std::optional<mlir::ArrayAttr> hasWaitDevnum,
254  mlir::acc::DeviceType deviceType) {
  // Returns the wait operands for `deviceType`, dropping the leading devnum
  // value when the matching `hasWaitDevnum` BoolAttr is true.
  // NOTE(review): source lines 249 (return type) and 251 (presumably the
  // `operands` parameter) are absent from this listing. The dyn_cast result
  // below is used without a null check.
255  auto range =
256  getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
257  if (range.empty())
258  return range;
259  if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
260  if (hasWaitDevnum && *hasWaitDevnum) {
261  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
262  if (boolAttr.getValue())
263  return range.drop_front(1); // first value is devnum
264  }
265  }
266  return range;
267 }
268 
269 template <typename Op>
270 static LogicalResult checkWaitAndAsyncConflict(Op op) {
271  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
272  ++dtypeInt) {
273  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
274 
275  // The asyncOnly attribute represent the async clause without value.
276  // Therefore the attribute and operand cannot appear at the same time.
277  if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
278  op.hasAsyncOnly(dtype))
279  return op.emitError(
280  "asyncOnly attribute cannot appear with asyncOperand");
281 
282  // The wait attribute represent the wait clause without values. Therefore
283  // the attribute and operands cannot appear at the same time.
284  if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
285  op.hasWaitOnly(dtype))
286  return op.emitError("wait attribute cannot appear with waitOperands");
287  }
288  return success();
289 }
290 
291 template <typename Op>
292 static LogicalResult checkVarAndVarType(Op op) {
293  if (!op.getVar())
294  return op.emitError("must have var operand");
295 
296  if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
297  mlir::isa<mlir::acc::MappableType>(op.getVar().getType())) {
298  // TODO: If a type implements both interfaces (mappable and pointer-like),
299  // it is unclear which semantics to apply without additional info which
300  // would need captured in the data operation. For now restrict this case
301  // unless a compelling reason to support disambiguating between the two.
302  return op.emitError("var must be mappable or pointer-like (not both)");
303  }
304 
305  if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
306  !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
307  return op.emitError("var must be mappable or pointer-like");
308 
309  if (mlir::isa<mlir::acc::MappableType>(op.getVar().getType()) &&
310  op.getVarType() != op.getVar().getType())
311  return op.emitError("varType must match when var is mappable");
312 
313  return success();
314 }
315 
316 template <typename Op>
317 static LogicalResult checkVarAndAccVar(Op op) {
318  if (op.getVar().getType() != op.getAccVar().getType())
319  return op.emitError("input and output types must match");
320 
321  return success();
322 }
323 
// Parses `varPtr(` or `var(` followed by the operand. The closing paren is
// not consumed here; parseVarPtrType consumes it after the type.
// NOTE(review): source line 325 (presumably the `var` out-parameter) is
// absent from this listing.
324 static ParseResult parseVar(mlir::OpAsmParser &parser,
326  // Either `var` or `varPtr` keyword is required.
327  if (failed(parser.parseOptionalKeyword("varPtr"))) {
328  if (failed(parser.parseKeyword("var")))
329  return failure();
330  }
331  if (failed(parser.parseLParen()))
332  return failure();
333  if (failed(parser.parseOperand(var)))
334  return failure();
335 
336  return success();
337 }
338 
340  mlir::Value var) {
  // Prints `varPtr(` for pointer-like values, `var(` otherwise, then the
  // operand. The closing paren is left to printVarPtrType.
  // NOTE(review): the signature line (source line 339) is absent from this
  // listing.
341  if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
342  p << "varPtr(";
343  else
344  p << "var(";
345  p.printOperand(var);
346 }
347 
// Parses `accPtr(%v : type)` or `accVar(%v : type)`.
// NOTE(review): source line 349 (presumably the `var` out-parameter) is
// absent from this listing.
348 static ParseResult parseAccVar(mlir::OpAsmParser &parser,
350  mlir::Type &accVarType) {
351  // Either `accVar` or `accPtr` keyword is required.
352  if (failed(parser.parseOptionalKeyword("accPtr"))) {
353  if (failed(parser.parseKeyword("accVar")))
354  return failure();
355  }
356  if (failed(parser.parseLParen()))
357  return failure();
358  if (failed(parser.parseOperand(var)))
359  return failure();
360  if (failed(parser.parseColon()))
361  return failure();
362  if (failed(parser.parseType(accVarType)))
363  return failure();
364  if (failed(parser.parseRParen()))
365  return failure();
366 
367  return success();
368 }
369 
371  mlir::Value accVar, mlir::Type accVarType) {
  // Prints `accPtr(%v : type)` for pointer-like values and
  // `accVar(%v : type)` otherwise, mirroring parseAccVar.
  // NOTE(review): the signature line (source line 370) is absent from this
  // listing.
372  if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
373  p << "accPtr(";
374  else
375  p << "accVar(";
376  p.printOperand(accVar);
377  p << " : ";
378  p.printType(accVarType);
379  p << ")";
380 }
381 
382 static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
383  mlir::Type &varPtrType,
384  mlir::TypeAttr &varTypeAttr) {
385  if (failed(parser.parseType(varPtrType)))
386  return failure();
387  if (failed(parser.parseRParen()))
388  return failure();
389 
390  if (succeeded(parser.parseOptionalKeyword("varType"))) {
391  if (failed(parser.parseLParen()))
392  return failure();
393  mlir::Type varType;
394  if (failed(parser.parseType(varType)))
395  return failure();
396  varTypeAttr = mlir::TypeAttr::get(varType);
397  if (failed(parser.parseRParen()))
398  return failure();
399  } else {
400  // Set `varType` from the element type of the type of `varPtr`.
401  if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
402  varTypeAttr = mlir::TypeAttr::get(
403  mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
404  else
405  varTypeAttr = mlir::TypeAttr::get(varPtrType);
406  }
407 
408  return success();
409 }
410 
412  mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
  // Prints the var type and the closing paren opened by printVar, then a
  // `varType(...)` suffix when it cannot be inferred from the var type.
  // NOTE(review): the signature line (source line 411) is absent from this
  // listing.
413  p.printType(varPtrType);
414  p << ")";
415 
416  // Print the `varType` only if it differs from the element type of
417  // `varPtr`'s type.
  // A pointer-like type with a null element type (e.g. the LLVM pointer
  // model) never compares equal to a real varType, so the suffix is printed.
418  mlir::Type varType = varTypeAttr.getValue();
419  mlir::Type typeToCheckAgainst =
420  mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
421  ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
422  : varPtrType;
423  if (typeToCheckAgainst != varType) {
424  p << " varType(";
425  p.printType(varType);
426  p << ")";
427  }
428 }
429 
430 //===----------------------------------------------------------------------===//
431 // DataBoundsOp
432 //===----------------------------------------------------------------------===//
433 LogicalResult acc::DataBoundsOp::verify() {
434  auto extent = getExtent();
435  auto upperbound = getUpperbound();
436  if (!extent && !upperbound)
437  return emitError("expected extent or upperbound.");
438  return success();
439 }
440 
441 //===----------------------------------------------------------------------===//
442 // PrivateOp
443 //===----------------------------------------------------------------------===//
444 LogicalResult acc::PrivateOp::verify() {
445  if (getDataClause() != acc::DataClause::acc_private)
446  return emitError(
447  "data clause associated with private operation must match its intent");
448  if (failed(checkVarAndVarType(*this)))
449  return failure();
450  return success();
451 }
452 
453 //===----------------------------------------------------------------------===//
454 // FirstprivateOp
455 //===----------------------------------------------------------------------===//
456 LogicalResult acc::FirstprivateOp::verify() {
457  if (getDataClause() != acc::DataClause::acc_firstprivate)
458  return emitError("data clause associated with firstprivate operation must "
459  "match its intent");
460  if (failed(checkVarAndVarType(*this)))
461  return failure();
462  return success();
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // ReductionOp
467 //===----------------------------------------------------------------------===//
468 LogicalResult acc::ReductionOp::verify() {
469  if (getDataClause() != acc::DataClause::acc_reduction)
470  return emitError("data clause associated with reduction operation must "
471  "match its intent");
472  if (failed(checkVarAndVarType(*this)))
473  return failure();
474  return success();
475 }
476 
477 //===----------------------------------------------------------------------===//
478 // DevicePtrOp
479 //===----------------------------------------------------------------------===//
480 LogicalResult acc::DevicePtrOp::verify() {
481  if (getDataClause() != acc::DataClause::acc_deviceptr)
482  return emitError("data clause associated with deviceptr operation must "
483  "match its intent");
484  if (failed(checkVarAndVarType(*this)))
485  return failure();
486  if (failed(checkVarAndAccVar(*this)))
487  return failure();
488  return success();
489 }
490 
491 //===----------------------------------------------------------------------===//
492 // PresentOp
493 //===----------------------------------------------------------------------===//
494 LogicalResult acc::PresentOp::verify() {
495  if (getDataClause() != acc::DataClause::acc_present)
496  return emitError(
497  "data clause associated with present operation must match its intent");
498  if (failed(checkVarAndVarType(*this)))
499  return failure();
500  if (failed(checkVarAndAccVar(*this)))
501  return failure();
502  return success();
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // CopyinOp
507 //===----------------------------------------------------------------------===//
508 LogicalResult acc::CopyinOp::verify() {
509  // Test for all clauses this operation can be decomposed from:
510  if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
511  getDataClause() != acc::DataClause::acc_copyin_readonly &&
512  getDataClause() != acc::DataClause::acc_copy &&
513  getDataClause() != acc::DataClause::acc_reduction)
514  return emitError(
515  "data clause associated with copyin operation must match its intent"
516  " or specify original clause this operation was decomposed from");
517  if (failed(checkVarAndVarType(*this)))
518  return failure();
519  if (failed(checkVarAndAccVar(*this)))
520  return failure();
521  return success();
522 }
523 
524 bool acc::CopyinOp::isCopyinReadonly() {
525  return getDataClause() == acc::DataClause::acc_copyin_readonly;
526 }
527 
528 //===----------------------------------------------------------------------===//
529 // CreateOp
530 //===----------------------------------------------------------------------===//
531 LogicalResult acc::CreateOp::verify() {
532  // Test for all clauses this operation can be decomposed from:
533  if (getDataClause() != acc::DataClause::acc_create &&
534  getDataClause() != acc::DataClause::acc_create_zero &&
535  getDataClause() != acc::DataClause::acc_copyout &&
536  getDataClause() != acc::DataClause::acc_copyout_zero)
537  return emitError(
538  "data clause associated with create operation must match its intent"
539  " or specify original clause this operation was decomposed from");
540  if (failed(checkVarAndVarType(*this)))
541  return failure();
542  if (failed(checkVarAndAccVar(*this)))
543  return failure();
544  return success();
545 }
546 
547 bool acc::CreateOp::isCreateZero() {
548  // The zero modifier is encoded in the data clause.
549  return getDataClause() == acc::DataClause::acc_create_zero ||
550  getDataClause() == acc::DataClause::acc_copyout_zero;
551 }
552 
553 //===----------------------------------------------------------------------===//
554 // NoCreateOp
555 //===----------------------------------------------------------------------===//
556 LogicalResult acc::NoCreateOp::verify() {
557  if (getDataClause() != acc::DataClause::acc_no_create)
558  return emitError("data clause associated with no_create operation must "
559  "match its intent");
560  if (failed(checkVarAndVarType(*this)))
561  return failure();
562  if (failed(checkVarAndAccVar(*this)))
563  return failure();
564  return success();
565 }
566 
567 //===----------------------------------------------------------------------===//
568 // AttachOp
569 //===----------------------------------------------------------------------===//
570 LogicalResult acc::AttachOp::verify() {
571  if (getDataClause() != acc::DataClause::acc_attach)
572  return emitError(
573  "data clause associated with attach operation must match its intent");
574  if (failed(checkVarAndVarType(*this)))
575  return failure();
576  if (failed(checkVarAndAccVar(*this)))
577  return failure();
578  return success();
579 }
580 
581 //===----------------------------------------------------------------------===//
582 // DeclareDeviceResidentOp
583 //===----------------------------------------------------------------------===//
584 
585 LogicalResult acc::DeclareDeviceResidentOp::verify() {
586  if (getDataClause() != acc::DataClause::acc_declare_device_resident)
587  return emitError("data clause associated with device_resident operation "
588  "must match its intent");
589  if (failed(checkVarAndVarType(*this)))
590  return failure();
591  if (failed(checkVarAndAccVar(*this)))
592  return failure();
593  return success();
594 }
595 
596 //===----------------------------------------------------------------------===//
597 // DeclareLinkOp
598 //===----------------------------------------------------------------------===//
599 
600 LogicalResult acc::DeclareLinkOp::verify() {
601  if (getDataClause() != acc::DataClause::acc_declare_link)
602  return emitError(
603  "data clause associated with link operation must match its intent");
604  if (failed(checkVarAndVarType(*this)))
605  return failure();
606  if (failed(checkVarAndAccVar(*this)))
607  return failure();
608  return success();
609 }
610 
611 //===----------------------------------------------------------------------===//
612 // CopyoutOp
613 //===----------------------------------------------------------------------===//
614 LogicalResult acc::CopyoutOp::verify() {
615  // Test for all clauses this operation can be decomposed from:
616  if (getDataClause() != acc::DataClause::acc_copyout &&
617  getDataClause() != acc::DataClause::acc_copyout_zero &&
618  getDataClause() != acc::DataClause::acc_copy &&
619  getDataClause() != acc::DataClause::acc_reduction)
620  return emitError(
621  "data clause associated with copyout operation must match its intent"
622  " or specify original clause this operation was decomposed from");
623  if (!getVar() || !getAccVar())
624  return emitError("must have both host and device pointers");
625  if (failed(checkVarAndVarType(*this)))
626  return failure();
627  if (failed(checkVarAndAccVar(*this)))
628  return failure();
629  return success();
630 }
631 
632 bool acc::CopyoutOp::isCopyoutZero() {
633  return getDataClause() == acc::DataClause::acc_copyout_zero;
634 }
635 
636 //===----------------------------------------------------------------------===//
637 // DeleteOp
638 //===----------------------------------------------------------------------===//
639 LogicalResult acc::DeleteOp::verify() {
640  // Test for all clauses this operation can be decomposed from:
641  if (getDataClause() != acc::DataClause::acc_delete &&
642  getDataClause() != acc::DataClause::acc_create &&
643  getDataClause() != acc::DataClause::acc_create_zero &&
644  getDataClause() != acc::DataClause::acc_copyin &&
645  getDataClause() != acc::DataClause::acc_copyin_readonly &&
646  getDataClause() != acc::DataClause::acc_present &&
647  getDataClause() != acc::DataClause::acc_no_create &&
648  getDataClause() != acc::DataClause::acc_declare_device_resident &&
649  getDataClause() != acc::DataClause::acc_declare_link)
650  return emitError(
651  "data clause associated with delete operation must match its intent"
652  " or specify original clause this operation was decomposed from");
653  if (!getAccVar())
654  return emitError("must have device pointer");
655  return success();
656 }
657 
658 //===----------------------------------------------------------------------===//
659 // DetachOp
660 //===----------------------------------------------------------------------===//
661 LogicalResult acc::DetachOp::verify() {
662  // Test for all clauses this operation can be decomposed from:
663  if (getDataClause() != acc::DataClause::acc_detach &&
664  getDataClause() != acc::DataClause::acc_attach)
665  return emitError(
666  "data clause associated with detach operation must match its intent"
667  " or specify original clause this operation was decomposed from");
668  if (!getAccVar())
669  return emitError("must have device pointer");
670  return success();
671 }
672 
673 //===----------------------------------------------------------------------===//
674 // UpdateHostOp
675 //===----------------------------------------------------------------------===//
676 LogicalResult acc::UpdateHostOp::verify() {
677  // Test for all clauses this operation can be decomposed from:
678  if (getDataClause() != acc::DataClause::acc_update_host &&
679  getDataClause() != acc::DataClause::acc_update_self)
680  return emitError(
681  "data clause associated with host operation must match its intent"
682  " or specify original clause this operation was decomposed from");
683  if (!getVar() || !getAccVar())
684  return emitError("must have both host and device pointers");
685  if (failed(checkVarAndVarType(*this)))
686  return failure();
687  if (failed(checkVarAndAccVar(*this)))
688  return failure();
689  return success();
690 }
691 
692 //===----------------------------------------------------------------------===//
693 // UpdateDeviceOp
694 //===----------------------------------------------------------------------===//
695 LogicalResult acc::UpdateDeviceOp::verify() {
696  // Test for all clauses this operation can be decomposed from:
697  if (getDataClause() != acc::DataClause::acc_update_device)
698  return emitError(
699  "data clause associated with device operation must match its intent"
700  " or specify original clause this operation was decomposed from");
701  if (failed(checkVarAndVarType(*this)))
702  return failure();
703  if (failed(checkVarAndAccVar(*this)))
704  return failure();
705  return success();
706 }
707 
708 //===----------------------------------------------------------------------===//
709 // UseDeviceOp
710 //===----------------------------------------------------------------------===//
711 LogicalResult acc::UseDeviceOp::verify() {
712  // Test for all clauses this operation can be decomposed from:
713  if (getDataClause() != acc::DataClause::acc_use_device)
714  return emitError(
715  "data clause associated with use_device operation must match its intent"
716  " or specify original clause this operation was decomposed from");
717  if (failed(checkVarAndVarType(*this)))
718  return failure();
719  if (failed(checkVarAndAccVar(*this)))
720  return failure();
721  return success();
722 }
723 
724 //===----------------------------------------------------------------------===//
725 // CacheOp
726 //===----------------------------------------------------------------------===//
727 LogicalResult acc::CacheOp::verify() {
728  // Test for all clauses this operation can be decomposed from:
729  if (getDataClause() != acc::DataClause::acc_cache &&
730  getDataClause() != acc::DataClause::acc_cache_readonly)
731  return emitError(
732  "data clause associated with cache operation must match its intent"
733  " or specify original clause this operation was decomposed from");
734  if (failed(checkVarAndVarType(*this)))
735  return failure();
736  if (failed(checkVarAndAccVar(*this)))
737  return failure();
738  return success();
739 }
740 
741 template <typename StructureOp>
742 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
743  unsigned nRegions = 1) {
744 
745  SmallVector<Region *, 2> regions;
746  for (unsigned i = 0; i < nRegions; ++i)
747  regions.push_back(state.addRegion());
748 
749  for (Region *region : regions)
750  if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
751  return failure();
752 
753  return success();
754 }
755 
756 static bool isComputeOperation(Operation *op) {
757  return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
758 }
759 
760 namespace {
761 /// Pattern that erases a region-less operation whose `ifCond` is a constant
762 /// false, and drops the `ifCond` operand when it is a constant true. Does not
763 /// match when `ifCond` is absent or not an integer constant.
764 template <typename OpTy>
765 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
 // NOTE(review): source line 766 (presumably the inherited-constructor using
 // declaration) is absent from this listing.
767 
768  LogicalResult matchAndRewrite(OpTy op,
769  PatternRewriter &rewriter) const override {
770  // Early return if there is no condition.
771  Value ifCond = op.getIfCond();
772  if (!ifCond)
773  return failure();
774 
775  IntegerAttr constAttr;
776  if (!matchPattern(ifCond, m_Constant(&constAttr)))
777  return failure();
 // Non-zero constant: keep the op and remove the condition operand.
 // Zero constant: the op never executes, so erase it.
778  if (constAttr.getInt())
779  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
780  else
781  rewriter.eraseOp(op);
782 
783  return success();
784  }
785 };
786 
787 /// Replaces the given op with the contents of the given single-block region,
788 /// using the operands of the block terminator to replace operation results.
789 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
790  Region &region, ValueRange blockArgs = {}) {
791  assert(llvm::hasSingleElement(region) && "expected single-region block");
792  Block *block = &region.front();
793  Operation *terminator = block->getTerminator();
794  ValueRange results = terminator->getOperands();
795  rewriter.inlineBlockBefore(block, op, blockArgs);
796  rewriter.replaceOp(op, results);
797  rewriter.eraseOp(terminator);
798 }
799 
800 /// Pattern that, for an operation with a region, inlines the region in place
801 /// of the operation when `ifCond` is a constant false, and drops the `ifCond`
802 /// operand when it is a constant true.
803 template <typename OpTy>
804 struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
 // NOTE(review): source line 805 (presumably the inherited-constructor using
 // declaration) is absent from this listing.
806 
807  LogicalResult matchAndRewrite(OpTy op,
808  PatternRewriter &rewriter) const override {
809  // Early return if there is no condition.
810  Value ifCond = op.getIfCond();
811  if (!ifCond)
812  return failure();
813 
814  IntegerAttr constAttr;
815  if (!matchPattern(ifCond, m_Constant(&constAttr)))
816  return failure();
 // Non-zero: remove the condition operand. Zero: replace the op with the
 // body of its region via replaceOpWithRegion.
817  if (constAttr.getInt())
818  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
819  else
820  replaceOpWithRegion(rewriter, op, op.getRegion());
821 
822  return success();
823  }
824 };
825 
826 } // namespace
827 
828 //===----------------------------------------------------------------------===//
829 // PrivateRecipeOp
830 //===----------------------------------------------------------------------===//
831 
832 static LogicalResult verifyInitLikeSingleArgRegion(
833  Operation *op, Region &region, StringRef regionType, StringRef regionName,
834  Type type, bool verifyYield, bool optional = false) {
835  if (optional && region.empty())
836  return success();
837 
838  if (region.empty())
839  return op->emitOpError() << "expects non-empty " << regionName << " region";
840  Block &firstBlock = region.front();
841  if (firstBlock.getNumArguments() < 1 ||
842  firstBlock.getArgument(0).getType() != type)
843  return op->emitOpError() << "expects " << regionName
844  << " region first "
845  "argument of the "
846  << regionType << " type";
847 
848  if (verifyYield) {
849  for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
850  if (yieldOp.getOperands().size() != 1 ||
851  yieldOp.getOperands().getTypes()[0] != type)
852  return op->emitOpError() << "expects " << regionName
853  << " region to "
854  "yield a value of the "
855  << regionType << " type";
856  }
857  }
858  return success();
859 }
860 
861 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
862  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
863  "privatization", "init", getType(),
864  /*verifyYield=*/false)))
865  return failure();
867  *this, getDestroyRegion(), "privatization", "destroy", getType(),
868  /*verifyYield=*/false, /*optional=*/true)))
869  return failure();
870  return success();
871 }
872 
873 //===----------------------------------------------------------------------===//
874 // FirstprivateRecipeOp
875 //===----------------------------------------------------------------------===//
876 
877 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
878  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
879  "privatization", "init", getType(),
880  /*verifyYield=*/false)))
881  return failure();
882 
883  if (getCopyRegion().empty())
884  return emitOpError() << "expects non-empty copy region";
885 
886  Block &firstBlock = getCopyRegion().front();
887  if (firstBlock.getNumArguments() < 2 ||
888  firstBlock.getArgument(0).getType() != getType())
889  return emitOpError() << "expects copy region with two arguments of the "
890  "privatization type";
891 
892  if (getDestroyRegion().empty())
893  return success();
894 
895  if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
896  "privatization", "destroy",
897  getType(), /*verifyYield=*/false)))
898  return failure();
899 
900  return success();
901 }
902 
903 //===----------------------------------------------------------------------===//
904 // ReductionRecipeOp
905 //===----------------------------------------------------------------------===//
906 
907 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
908  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
909  "init", getType(),
910  /*verifyYield=*/false)))
911  return failure();
912 
913  if (getCombinerRegion().empty())
914  return emitOpError() << "expects non-empty combiner region";
915 
916  Block &reductionBlock = getCombinerRegion().front();
917  if (reductionBlock.getNumArguments() < 2 ||
918  reductionBlock.getArgument(0).getType() != getType() ||
919  reductionBlock.getArgument(1).getType() != getType())
920  return emitOpError() << "expects combiner region with the first two "
921  << "arguments of the reduction type";
922 
923  for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
924  if (yieldOp.getOperands().size() != 1 ||
925  yieldOp.getOperands().getTypes()[0] != getType())
926  return emitOpError() << "expects combiner region to yield a value "
927  "of the reduction type";
928  }
929 
930  return success();
931 }
932 
933 //===----------------------------------------------------------------------===//
// Custom parser and printer for recipe symbol/operand lists (e.g. the private clause)
935 //===----------------------------------------------------------------------===//
936 
937 static ParseResult parseSymOperandList(
938  mlir::OpAsmParser &parser,
940  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
942  if (failed(parser.parseCommaSeparatedList([&]() {
943  if (parser.parseAttribute(attributes.emplace_back()) ||
944  parser.parseArrow() ||
945  parser.parseOperand(operands.emplace_back()) ||
946  parser.parseColonType(types.emplace_back()))
947  return failure();
948  return success();
949  })))
950  return failure();
951  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
952  attributes.end());
953  symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
954  return success();
955 }
956 
958  mlir::OperandRange operands,
959  mlir::TypeRange types,
960  std::optional<mlir::ArrayAttr> attributes) {
961  llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
962  p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
963  << std::get<1>(it).getType();
964  });
965 }
966 
967 //===----------------------------------------------------------------------===//
968 // ParallelOp
969 //===----------------------------------------------------------------------===//
970 
971 /// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
972 template <typename Op>
973 static LogicalResult checkDataOperands(Op op,
974  const mlir::ValueRange &operands) {
975  for (mlir::Value operand : operands)
976  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
977  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
978  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
979  operand.getDefiningOp()))
980  return op.emitError(
981  "expect data entry/exit operation or acc.getdeviceptr "
982  "as defining op");
983  return success();
984 }
985 
986 template <typename Op>
987 static LogicalResult
988 checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
989  mlir::OperandRange operands, llvm::StringRef operandName,
990  llvm::StringRef symbolName, bool checkOperandType = true) {
991  if (!operands.empty()) {
992  if (!attributes || attributes->size() != operands.size())
993  return op->emitOpError()
994  << "expected as many " << symbolName << " symbol reference as "
995  << operandName << " operands";
996  } else {
997  if (attributes)
998  return op->emitOpError()
999  << "unexpected " << symbolName << " symbol reference";
1000  return success();
1001  }
1002 
1004  for (auto args : llvm::zip(operands, *attributes)) {
1005  mlir::Value operand = std::get<0>(args);
1006 
1007  if (!set.insert(operand).second)
1008  return op->emitOpError()
1009  << operandName << " operand appears more than once";
1010 
1011  mlir::Type varType = operand.getType();
1012  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1013  auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
1014  if (!decl)
1015  return op->emitOpError()
1016  << "expected symbol reference " << symbolRef << " to point to a "
1017  << operandName << " declaration";
1018 
1019  if (checkOperandType && decl.getType() && decl.getType() != varType)
1020  return op->emitOpError() << "expected " << operandName << " (" << varType
1021  << ") to be the same type as " << operandName
1022  << " declaration (" << decl.getType() << ")";
1023  }
1024 
1025  return success();
1026 }
1027 
1028 unsigned ParallelOp::getNumDataOperands() {
1029  return getReductionOperands().size() + getPrivateOperands().size() +
1030  getFirstprivateOperands().size() + getDataClauseOperands().size();
1031 }
1032 
1033 Value ParallelOp::getDataOperand(unsigned i) {
1034  unsigned numOptional = getAsyncOperands().size();
1035  numOptional += getNumGangs().size();
1036  numOptional += getNumWorkers().size();
1037  numOptional += getVectorLength().size();
1038  numOptional += getIfCond() ? 1 : 0;
1039  numOptional += getSelfCond() ? 1 : 0;
1040  return getOperand(getWaitOperands().size() + numOptional + i);
1041 }
1042 
1043 template <typename Op>
1044 static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
1045  ArrayAttr deviceTypes,
1046  llvm::StringRef keyword) {
1047  if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1048  return op.emitOpError() << keyword << " operands count must match "
1049  << keyword << " device_type count";
1050  return success();
1051 }
1052 
1053 template <typename Op>
1055  Op op, OperandRange operands, DenseI32ArrayAttr segments,
1056  ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1057  std::size_t numOperandsInSegments = 0;
1058  std::size_t nbOfSegments = 0;
1059 
1060  if (segments) {
1061  for (auto segCount : segments.asArrayRef()) {
1062  if (maxInSegment != 0 && segCount > maxInSegment)
1063  return op.emitOpError() << keyword << " expects a maximum of "
1064  << maxInSegment << " values per segment";
1065  numOperandsInSegments += segCount;
1066  ++nbOfSegments;
1067  }
1068  }
1069 
1070  if ((numOperandsInSegments != operands.size()) ||
1071  (!deviceTypes && !operands.empty()))
1072  return op.emitOpError()
1073  << keyword << " operand count does not match count in segments";
1074  if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1075  return op.emitOpError()
1076  << keyword << " segment count does not match device_type count";
1077  return success();
1078 }
1079 
1080 LogicalResult acc::ParallelOp::verify() {
1081  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1082  *this, getPrivatizations(), getPrivateOperands(), "private",
1083  "privatizations", /*checkOperandType=*/false)))
1084  return failure();
1085  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1086  *this, getFirstprivatizations(), getFirstprivateOperands(),
1087  "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
1088  return failure();
1089  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1090  *this, getReductionRecipes(), getReductionOperands(), "reduction",
1091  "reductions", false)))
1092  return failure();
1093 
1095  *this, getNumGangs(), getNumGangsSegmentsAttr(),
1096  getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1097  return failure();
1098 
1100  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1101  getWaitOperandsDeviceTypeAttr(), "wait")))
1102  return failure();
1103 
1104  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1105  getNumWorkersDeviceTypeAttr(),
1106  "num_workers")))
1107  return failure();
1108 
1109  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1110  getVectorLengthDeviceTypeAttr(),
1111  "vector_length")))
1112  return failure();
1113 
1114  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1115  getAsyncOperandsDeviceTypeAttr(),
1116  "async")))
1117  return failure();
1118 
1119  if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
1120  return failure();
1121 
1122  return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
1123 }
1124 
1125 static mlir::Value
1126 getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
1128  mlir::acc::DeviceType deviceType) {
1129  if (!arrayAttr)
1130  return {};
1131  if (auto pos = findSegment(*arrayAttr, deviceType))
1132  return range[*pos];
1133  return {};
1134 }
1135 
1136 bool acc::ParallelOp::hasAsyncOnly() {
1137  return hasAsyncOnly(mlir::acc::DeviceType::None);
1138 }
1139 
1140 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1141  return hasDeviceType(getAsyncOnly(), deviceType);
1142 }
1143 
1144 mlir::Value acc::ParallelOp::getAsyncValue() {
1145  return getAsyncValue(mlir::acc::DeviceType::None);
1146 }
1147 
1148 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1150  getAsyncOperands(), deviceType);
1151 }
1152 
1153 mlir::Value acc::ParallelOp::getNumWorkersValue() {
1154  return getNumWorkersValue(mlir::acc::DeviceType::None);
1155 }
1156 
1158 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1159  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1160  deviceType);
1161 }
1162 
1163 mlir::Value acc::ParallelOp::getVectorLengthValue() {
1164  return getVectorLengthValue(mlir::acc::DeviceType::None);
1165 }
1166 
1168 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1169  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1170  getVectorLength(), deviceType);
1171 }
1172 
1173 mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
1174  return getNumGangsValues(mlir::acc::DeviceType::None);
1175 }
1176 
1178 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1179  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1180  getNumGangsSegments(), deviceType);
1181 }
1182 
1183 bool acc::ParallelOp::hasWaitOnly() {
1184  return hasWaitOnly(mlir::acc::DeviceType::None);
1185 }
1186 
1187 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1188  return hasDeviceType(getWaitOnly(), deviceType);
1189 }
1190 
1191 mlir::Operation::operand_range ParallelOp::getWaitValues() {
1192  return getWaitValues(mlir::acc::DeviceType::None);
1193 }
1194 
1196 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1198  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1199  getHasWaitDevnum(), deviceType);
1200 }
1201 
1202 mlir::Value ParallelOp::getWaitDevnum() {
1203  return getWaitDevnum(mlir::acc::DeviceType::None);
1204 }
1205 
1206 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1207  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1208  getWaitOperandsSegments(), getHasWaitDevnum(),
1209  deviceType);
1210 }
1211 
1212 void ParallelOp::build(mlir::OpBuilder &odsBuilder,
1213  mlir::OperationState &odsState,
1214  mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
1215  mlir::ValueRange vectorLength,
1216  mlir::ValueRange asyncOperands,
1217  mlir::ValueRange waitOperands, mlir::Value ifCond,
1218  mlir::Value selfCond, mlir::ValueRange reductionOperands,
1219  mlir::ValueRange gangPrivateOperands,
1220  mlir::ValueRange gangFirstPrivateOperands,
1221  mlir::ValueRange dataClauseOperands) {
1222 
1223  ParallelOp::build(
1224  odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
1225  /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
1226  /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
1227  /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
1228  /*numGangsDeviceType=*/nullptr, numWorkers,
1229  /*numWorkersDeviceType=*/nullptr, vectorLength,
1230  /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
1231  /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
1232  gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
1233  /*firstprivatizations=*/nullptr, dataClauseOperands,
1234  /*defaultAttr=*/nullptr, /*combined=*/nullptr);
1235 }
1236 
1237 void acc::ParallelOp::addNumWorkersOperand(
1238  MLIRContext *context, mlir::Value newValue,
1239  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1240  setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1241  context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1242  getNumWorkersMutable()));
1243 }
1244 void acc::ParallelOp::addVectorLengthOperand(
1245  MLIRContext *context, mlir::Value newValue,
1246  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1247  setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1248  context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1249  getVectorLengthMutable()));
1250 }
1251 
1252 void acc::ParallelOp::addAsyncOnly(
1253  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1254  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1255  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1256 }
1257 
1258 void acc::ParallelOp::addAsyncOperand(
1259  MLIRContext *context, mlir::Value newValue,
1260  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1261  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1262  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1263  getAsyncOperandsMutable()));
1264 }
1265 
1266 void acc::ParallelOp::addNumGangsOperands(
1267  MLIRContext *context, mlir::ValueRange newValues,
1268  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1269  llvm::SmallVector<int32_t> segments;
1270  if (getNumGangsSegments())
1271  llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1272 
1273  setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1274  context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1275  getNumGangsMutable(), segments));
1276 
1277  setNumGangsSegments(segments);
1278 }
1279 void acc::ParallelOp::addWaitOnly(
1280  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1281  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1282  effectiveDeviceTypes));
1283 }
1284 void acc::ParallelOp::addWaitOperands(
1285  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
1286  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1287 
1288  llvm::SmallVector<int32_t> segments;
1289  if (getWaitOperandsSegments())
1290  llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1291 
1292  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1293  context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1294  getWaitOperandsMutable(), segments));
1295  setWaitOperandsSegments(segments);
1296 
1298  if (getHasWaitDevnumAttr())
1299  llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1300  hasDevnums.insert(
1301  hasDevnums.end(),
1302  std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
1303  mlir::BoolAttr::get(context, hasDevnum));
1304  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
1305 }
1306 
1307 static ParseResult parseNumGangs(
1308  mlir::OpAsmParser &parser,
1310  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1311  mlir::DenseI32ArrayAttr &segments) {
1314 
1315  do {
1316  if (failed(parser.parseLBrace()))
1317  return failure();
1318 
1319  int32_t crtOperandsSize = operands.size();
1320  if (failed(parser.parseCommaSeparatedList(
1322  if (parser.parseOperand(operands.emplace_back()) ||
1323  parser.parseColonType(types.emplace_back()))
1324  return failure();
1325  return success();
1326  })))
1327  return failure();
1328  seg.push_back(operands.size() - crtOperandsSize);
1329 
1330  if (failed(parser.parseRBrace()))
1331  return failure();
1332 
1333  if (succeeded(parser.parseOptionalLSquare())) {
1334  if (parser.parseAttribute(attributes.emplace_back()) ||
1335  parser.parseRSquare())
1336  return failure();
1337  } else {
1338  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1340  }
1341  } while (succeeded(parser.parseOptionalComma()));
1342 
1343  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1344  attributes.end());
1345  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1346  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1347 
1348  return success();
1349 }
1350 
1352  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1353  if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
1354  p << " [" << attr << "]";
1355 }
1356 
1358  mlir::OperandRange operands, mlir::TypeRange types,
1359  std::optional<mlir::ArrayAttr> deviceTypes,
1360  std::optional<mlir::DenseI32ArrayAttr> segments) {
1361  unsigned opIdx = 0;
1362  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1363  p << "{";
1364  llvm::interleaveComma(
1365  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1366  p << operands[opIdx] << " : " << operands[opIdx].getType();
1367  ++opIdx;
1368  });
1369  p << "}";
1370  printSingleDeviceType(p, it.value());
1371  });
1372 }
1373 
1375  mlir::OpAsmParser &parser,
1377  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1378  mlir::DenseI32ArrayAttr &segments) {
1381 
1382  do {
1383  if (failed(parser.parseLBrace()))
1384  return failure();
1385 
1386  int32_t crtOperandsSize = operands.size();
1387 
1388  if (failed(parser.parseCommaSeparatedList(
1390  if (parser.parseOperand(operands.emplace_back()) ||
1391  parser.parseColonType(types.emplace_back()))
1392  return failure();
1393  return success();
1394  })))
1395  return failure();
1396 
1397  seg.push_back(operands.size() - crtOperandsSize);
1398 
1399  if (failed(parser.parseRBrace()))
1400  return failure();
1401 
1402  if (succeeded(parser.parseOptionalLSquare())) {
1403  if (parser.parseAttribute(attributes.emplace_back()) ||
1404  parser.parseRSquare())
1405  return failure();
1406  } else {
1407  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1409  }
1410  } while (succeeded(parser.parseOptionalComma()));
1411 
1412  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1413  attributes.end());
1414  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1415  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1416 
1417  return success();
1418 }
1419 
1422  mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1423  std::optional<mlir::DenseI32ArrayAttr> segments) {
1424  unsigned opIdx = 0;
1425  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1426  p << "{";
1427  llvm::interleaveComma(
1428  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1429  p << operands[opIdx] << " : " << operands[opIdx].getType();
1430  ++opIdx;
1431  });
1432  p << "}";
1433  printSingleDeviceType(p, it.value());
1434  });
1435 }
1436 
1437 static ParseResult parseWaitClause(
1438  mlir::OpAsmParser &parser,
1440  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1441  mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
1442  mlir::ArrayAttr &keywordOnly) {
1443  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
1445 
1446  bool needCommaBeforeOperands = false;
1447 
1448  // Keyword only
1449  if (failed(parser.parseOptionalLParen())) {
1450  keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1452  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1453  return success();
1454  }
1455 
1456  // Parse keyword only attributes
1457  if (succeeded(parser.parseOptionalLSquare())) {
1458  if (failed(parser.parseCommaSeparatedList([&]() {
1459  if (parser.parseAttribute(keywordAttrs.emplace_back()))
1460  return failure();
1461  return success();
1462  })))
1463  return failure();
1464  if (parser.parseRSquare())
1465  return failure();
1466  needCommaBeforeOperands = true;
1467  }
1468 
1469  if (needCommaBeforeOperands && failed(parser.parseComma()))
1470  return failure();
1471 
1472  do {
1473  if (failed(parser.parseLBrace()))
1474  return failure();
1475 
1476  int32_t crtOperandsSize = operands.size();
1477 
1478  if (succeeded(parser.parseOptionalKeyword("devnum"))) {
1479  if (failed(parser.parseColon()))
1480  return failure();
1481  devnum.push_back(BoolAttr::get(parser.getContext(), true));
1482  } else {
1483  devnum.push_back(BoolAttr::get(parser.getContext(), false));
1484  }
1485 
1486  if (failed(parser.parseCommaSeparatedList(
1488  if (parser.parseOperand(operands.emplace_back()) ||
1489  parser.parseColonType(types.emplace_back()))
1490  return failure();
1491  return success();
1492  })))
1493  return failure();
1494 
1495  seg.push_back(operands.size() - crtOperandsSize);
1496 
1497  if (failed(parser.parseRBrace()))
1498  return failure();
1499 
1500  if (succeeded(parser.parseOptionalLSquare())) {
1501  if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
1502  parser.parseRSquare())
1503  return failure();
1504  } else {
1505  deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1507  }
1508  } while (succeeded(parser.parseOptionalComma()));
1509 
1510  if (failed(parser.parseRParen()))
1511  return failure();
1512 
1513  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
1514  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1515  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1516  hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
1517 
1518  return success();
1519 }
1520 
1521 static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
1522  if (!hasDeviceTypeValues(attrs))
1523  return false;
1524  if (attrs->size() != 1)
1525  return false;
1526  if (auto deviceTypeAttr =
1527  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1528  return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
1529  return false;
1530 }
1531 
1533  mlir::OperandRange operands, mlir::TypeRange types,
1534  std::optional<mlir::ArrayAttr> deviceTypes,
1535  std::optional<mlir::DenseI32ArrayAttr> segments,
1536  std::optional<mlir::ArrayAttr> hasDevNum,
1537  std::optional<mlir::ArrayAttr> keywordOnly) {
1538 
1539  if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
1540  return;
1541 
1542  p << "(";
1543 
1544  printDeviceTypes(p, keywordOnly);
1545  if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
1546  p << ", ";
1547 
1548  if (hasDeviceTypeValues(deviceTypes)) {
1549  unsigned opIdx = 0;
1550  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1551  p << "{";
1552  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1553  if (boolAttr && boolAttr.getValue())
1554  p << "devnum: ";
1555  llvm::interleaveComma(
1556  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1557  p << operands[opIdx] << " : " << operands[opIdx].getType();
1558  ++opIdx;
1559  });
1560  p << "}";
1561  printSingleDeviceType(p, it.value());
1562  });
1563  }
1564 
1565  p << ")";
1566 }
1567 
1568 static ParseResult parseDeviceTypeOperands(
1569  mlir::OpAsmParser &parser,
1571  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
1573  if (failed(parser.parseCommaSeparatedList([&]() {
1574  if (parser.parseOperand(operands.emplace_back()) ||
1575  parser.parseColonType(types.emplace_back()))
1576  return failure();
1577  if (succeeded(parser.parseOptionalLSquare())) {
1578  if (parser.parseAttribute(attributes.emplace_back()) ||
1579  parser.parseRSquare())
1580  return failure();
1581  } else {
1582  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1583  parser.getContext(), mlir::acc::DeviceType::None));
1584  }
1585  return success();
1586  })))
1587  return failure();
1588  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1589  attributes.end());
1590  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1591  return success();
1592 }
1593 
1594 static void
1596  mlir::OperandRange operands, mlir::TypeRange types,
1597  std::optional<mlir::ArrayAttr> deviceTypes) {
1598  if (!hasDeviceTypeValues(deviceTypes))
1599  return;
1600  llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
1601  p << std::get<1>(it) << " : " << std::get<1>(it).getType();
1602  printSingleDeviceType(p, std::get<0>(it));
1603  });
1604 }
1605 
1607  mlir::OpAsmParser &parser,
1609  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1610  mlir::ArrayAttr &keywordOnlyDeviceType) {
1611 
1612  llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
1613  bool needCommaBeforeOperands = false;
1614 
1615  if (failed(parser.parseOptionalLParen())) {
1616  // Keyword only
1617  keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1619  keywordOnlyDeviceType =
1620  ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
1621  return success();
1622  }
1623 
1624  // Parse keyword only attributes
1625  if (succeeded(parser.parseOptionalLSquare())) {
1626  // Parse keyword only attributes
1627  if (failed(parser.parseCommaSeparatedList([&]() {
1628  if (parser.parseAttribute(
1629  keywordOnlyDeviceTypeAttributes.emplace_back()))
1630  return failure();
1631  return success();
1632  })))
1633  return failure();
1634  if (parser.parseRSquare())
1635  return failure();
1636  needCommaBeforeOperands = true;
1637  }
1638 
1639  if (needCommaBeforeOperands && failed(parser.parseComma()))
1640  return failure();
1641 
1643  if (failed(parser.parseCommaSeparatedList([&]() {
1644  if (parser.parseOperand(operands.emplace_back()) ||
1645  parser.parseColonType(types.emplace_back()))
1646  return failure();
1647  if (succeeded(parser.parseOptionalLSquare())) {
1648  if (parser.parseAttribute(attributes.emplace_back()) ||
1649  parser.parseRSquare())
1650  return failure();
1651  } else {
1652  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1653  parser.getContext(), mlir::acc::DeviceType::None));
1654  }
1655  return success();
1656  })))
1657  return failure();
1658 
1659  if (failed(parser.parseRParen()))
1660  return failure();
1661 
1662  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1663  attributes.end());
1664  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1665  return success();
1666 }
1667 
1670  mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1671  std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1672 
1673  if (operands.begin() == operands.end() &&
1674  hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
1675  return;
1676  }
1677 
1678  p << "(";
1679  printDeviceTypes(p, keywordOnlyDeviceTypes);
1680  if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
1681  hasDeviceTypeValues(deviceTypes))
1682  p << ", ";
1683  printDeviceTypeOperands(p, op, operands, types, deviceTypes);
1684  p << ")";
1685 }
1686 
1687 static ParseResult parseOperandWithKeywordOnly(
1688  mlir::OpAsmParser &parser,
1689  std::optional<OpAsmParser::UnresolvedOperand> &operand,
1690  mlir::Type &operandType, mlir::UnitAttr &attr) {
1691  // Keyword only
1692  if (failed(parser.parseOptionalLParen())) {
1693  attr = mlir::UnitAttr::get(parser.getContext());
1694  return success();
1695  }
1696 
1698  if (failed(parser.parseOperand(op)))
1699  return failure();
1700  operand = op;
1701  if (failed(parser.parseColon()))
1702  return failure();
1703  if (failed(parser.parseType(operandType)))
1704  return failure();
1705  if (failed(parser.parseRParen()))
1706  return failure();
1707 
1708  return success();
1709 }
1710 
1712  mlir::Operation *op,
1713  std::optional<mlir::Value> operand,
1714  mlir::Type operandType,
1715  mlir::UnitAttr attr) {
1716  if (attr)
1717  return;
1718 
1719  p << "(";
1720  p.printOperand(*operand);
1721  p << " : ";
1722  p.printType(operandType);
1723  p << ")";
1724 }
1725 
1726 static ParseResult parseOperandsWithKeywordOnly(
1727  mlir::OpAsmParser &parser,
1729  llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
1730  // Keyword only
1731  if (failed(parser.parseOptionalLParen())) {
1732  attr = mlir::UnitAttr::get(parser.getContext());
1733  return success();
1734  }
1735 
1736  if (failed(parser.parseCommaSeparatedList([&]() {
1737  if (parser.parseOperand(operands.emplace_back()))
1738  return failure();
1739  return success();
1740  })))
1741  return failure();
1742  if (failed(parser.parseColon()))
1743  return failure();
1744  if (failed(parser.parseCommaSeparatedList([&]() {
1745  if (parser.parseType(types.emplace_back()))
1746  return failure();
1747  return success();
1748  })))
1749  return failure();
1750  if (failed(parser.parseRParen()))
1751  return failure();
1752 
1753  return success();
1754 }
1755 
1757  mlir::Operation *op,
1758  mlir::OperandRange operands,
1759  mlir::TypeRange types,
1760  mlir::UnitAttr attr) {
1761  if (attr)
1762  return;
1763 
1764  p << "(";
1765  llvm::interleaveComma(operands, p, [&](auto it) { p << it; });
1766  p << " : ";
1767  llvm::interleaveComma(types, p, [&](auto it) { p << it; });
1768  p << ")";
1769 }
1770 
1771 static ParseResult
1773  mlir::acc::CombinedConstructsTypeAttr &attr) {
1774  if (succeeded(parser.parseOptionalKeyword("kernels"))) {
1776  parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1777  } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
1779  parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1780  } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
1782  parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1783  } else {
1784  parser.emitError(parser.getCurrentLocation(),
1785  "expected compute construct name");
1786  return failure();
1787  }
1788  return success();
1789 }
1790 
1791 static void
1793  mlir::acc::CombinedConstructsTypeAttr attr) {
1794  if (attr) {
1795  switch (attr.getValue()) {
1796  case mlir::acc::CombinedConstructsType::KernelsLoop:
1797  p << "kernels";
1798  break;
1799  case mlir::acc::CombinedConstructsType::ParallelLoop:
1800  p << "parallel";
1801  break;
1802  case mlir::acc::CombinedConstructsType::SerialLoop:
1803  p << "serial";
1804  break;
1805  };
1806  }
1807 }
1808 
1809 //===----------------------------------------------------------------------===//
1810 // SerialOp
1811 //===----------------------------------------------------------------------===//
1812 
1813 unsigned SerialOp::getNumDataOperands() {
1814  return getReductionOperands().size() + getPrivateOperands().size() +
1815  getFirstprivateOperands().size() + getDataClauseOperands().size();
1816 }
1817 
1818 Value SerialOp::getDataOperand(unsigned i) {
1819  unsigned numOptional = getAsyncOperands().size();
1820  numOptional += getIfCond() ? 1 : 0;
1821  numOptional += getSelfCond() ? 1 : 0;
1822  return getOperand(getWaitOperands().size() + numOptional + i);
1823 }
1824 
1825 bool acc::SerialOp::hasAsyncOnly() {
1826  return hasAsyncOnly(mlir::acc::DeviceType::None);
1827 }
1828 
1829 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1830  return hasDeviceType(getAsyncOnly(), deviceType);
1831 }
1832 
1833 mlir::Value acc::SerialOp::getAsyncValue() {
1834  return getAsyncValue(mlir::acc::DeviceType::None);
1835 }
1836 
1837 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1839  getAsyncOperands(), deviceType);
1840 }
1841 
1842 bool acc::SerialOp::hasWaitOnly() {
1843  return hasWaitOnly(mlir::acc::DeviceType::None);
1844 }
1845 
1846 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1847  return hasDeviceType(getWaitOnly(), deviceType);
1848 }
1849 
1850 mlir::Operation::operand_range SerialOp::getWaitValues() {
1851  return getWaitValues(mlir::acc::DeviceType::None);
1852 }
1853 
1855 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1857  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1858  getHasWaitDevnum(), deviceType);
1859 }
1860 
1861 mlir::Value SerialOp::getWaitDevnum() {
1862  return getWaitDevnum(mlir::acc::DeviceType::None);
1863 }
1864 
1865 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1866  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1867  getWaitOperandsSegments(), getHasWaitDevnum(),
1868  deviceType);
1869 }
1870 
1871 LogicalResult acc::SerialOp::verify() {
1872  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1873  *this, getPrivatizations(), getPrivateOperands(), "private",
1874  "privatizations", /*checkOperandType=*/false)))
1875  return failure();
1876  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1877  *this, getFirstprivatizations(), getFirstprivateOperands(),
1878  "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
1879  return failure();
1880  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1881  *this, getReductionRecipes(), getReductionOperands(), "reduction",
1882  "reductions", false)))
1883  return failure();
1884 
1886  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1887  getWaitOperandsDeviceTypeAttr(), "wait")))
1888  return failure();
1889 
1890  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1891  getAsyncOperandsDeviceTypeAttr(),
1892  "async")))
1893  return failure();
1894 
1895  if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
1896  return failure();
1897 
1898  return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
1899 }
1900 
1901 void acc::SerialOp::addAsyncOnly(
1902  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1903  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1904  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1905 }
1906 
1907 void acc::SerialOp::addAsyncOperand(
1908  MLIRContext *context, mlir::Value newValue,
1909  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1910  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1911  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1912  getAsyncOperandsMutable()));
1913 }
1914 
1915 void acc::SerialOp::addWaitOnly(
1916  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1917  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1918  effectiveDeviceTypes));
1919 }
1920 void acc::SerialOp::addWaitOperands(
1921  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
1922  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1923 
1924  llvm::SmallVector<int32_t> segments;
1925  if (getWaitOperandsSegments())
1926  llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1927 
1928  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1929  context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1930  getWaitOperandsMutable(), segments));
1931  setWaitOperandsSegments(segments);
1932 
1934  if (getHasWaitDevnumAttr())
1935  llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1936  hasDevnums.insert(
1937  hasDevnums.end(),
1938  std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
1939  mlir::BoolAttr::get(context, hasDevnum));
1940  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
1941 }
1942 
1943 //===----------------------------------------------------------------------===//
1944 // KernelsOp
1945 //===----------------------------------------------------------------------===//
1946 
1947 unsigned KernelsOp::getNumDataOperands() {
1948  return getDataClauseOperands().size();
1949 }
1950 
1951 Value KernelsOp::getDataOperand(unsigned i) {
1952  unsigned numOptional = getAsyncOperands().size();
1953  numOptional += getWaitOperands().size();
1954  numOptional += getNumGangs().size();
1955  numOptional += getNumWorkers().size();
1956  numOptional += getVectorLength().size();
1957  numOptional += getIfCond() ? 1 : 0;
1958  numOptional += getSelfCond() ? 1 : 0;
1959  return getOperand(numOptional + i);
1960 }
1961 
1962 bool acc::KernelsOp::hasAsyncOnly() {
1963  return hasAsyncOnly(mlir::acc::DeviceType::None);
1964 }
1965 
1966 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1967  return hasDeviceType(getAsyncOnly(), deviceType);
1968 }
1969 
1970 mlir::Value acc::KernelsOp::getAsyncValue() {
1971  return getAsyncValue(mlir::acc::DeviceType::None);
1972 }
1973 
1974 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1976  getAsyncOperands(), deviceType);
1977 }
1978 
1979 mlir::Value acc::KernelsOp::getNumWorkersValue() {
1980  return getNumWorkersValue(mlir::acc::DeviceType::None);
1981 }
1982 
1984 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1985  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1986  deviceType);
1987 }
1988 
1989 mlir::Value acc::KernelsOp::getVectorLengthValue() {
1990  return getVectorLengthValue(mlir::acc::DeviceType::None);
1991 }
1992 
1994 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1995  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1996  getVectorLength(), deviceType);
1997 }
1998 
1999 mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
2000  return getNumGangsValues(mlir::acc::DeviceType::None);
2001 }
2002 
2004 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2005  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2006  getNumGangsSegments(), deviceType);
2007 }
2008 
2009 bool acc::KernelsOp::hasWaitOnly() {
2010  return hasWaitOnly(mlir::acc::DeviceType::None);
2011 }
2012 
2013 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2014  return hasDeviceType(getWaitOnly(), deviceType);
2015 }
2016 
2017 mlir::Operation::operand_range KernelsOp::getWaitValues() {
2018  return getWaitValues(mlir::acc::DeviceType::None);
2019 }
2020 
2022 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2024  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2025  getHasWaitDevnum(), deviceType);
2026 }
2027 
2028 mlir::Value KernelsOp::getWaitDevnum() {
2029  return getWaitDevnum(mlir::acc::DeviceType::None);
2030 }
2031 
2032 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2033  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2034  getWaitOperandsSegments(), getHasWaitDevnum(),
2035  deviceType);
2036 }
2037 
2038 LogicalResult acc::KernelsOp::verify() {
2040  *this, getNumGangs(), getNumGangsSegmentsAttr(),
2041  getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
2042  return failure();
2043 
2045  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2046  getWaitOperandsDeviceTypeAttr(), "wait")))
2047  return failure();
2048 
2049  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
2050  getNumWorkersDeviceTypeAttr(),
2051  "num_workers")))
2052  return failure();
2053 
2054  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
2055  getVectorLengthDeviceTypeAttr(),
2056  "vector_length")))
2057  return failure();
2058 
2059  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
2060  getAsyncOperandsDeviceTypeAttr(),
2061  "async")))
2062  return failure();
2063 
2064  if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
2065  return failure();
2066 
2067  return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
2068 }
2069 
2070 void acc::KernelsOp::addNumWorkersOperand(
2071  MLIRContext *context, mlir::Value newValue,
2072  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2073  setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2074  context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2075  getNumWorkersMutable()));
2076 }
2077 
2078 void acc::KernelsOp::addVectorLengthOperand(
2079  MLIRContext *context, mlir::Value newValue,
2080  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2081  setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2082  context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2083  getVectorLengthMutable()));
2084 }
2085 void acc::KernelsOp::addAsyncOnly(
2086  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2087  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2088  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2089 }
2090 
2091 void acc::KernelsOp::addAsyncOperand(
2092  MLIRContext *context, mlir::Value newValue,
2093  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2094  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2095  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2096  getAsyncOperandsMutable()));
2097 }
2098 
2099 void acc::KernelsOp::addNumGangsOperands(
2100  MLIRContext *context, mlir::ValueRange newValues,
2101  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2102  llvm::SmallVector<int32_t> segments;
2103  if (getNumGangsSegmentsAttr())
2104  llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2105 
2106  setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2107  context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2108  getNumGangsMutable(), segments));
2109 
2110  setNumGangsSegments(segments);
2111 }
2112 
2113 void acc::KernelsOp::addWaitOnly(
2114  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2115  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2116  effectiveDeviceTypes));
2117 }
2118 void acc::KernelsOp::addWaitOperands(
2119  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2120  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2121 
2122  llvm::SmallVector<int32_t> segments;
2123  if (getWaitOperandsSegments())
2124  llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2125 
2126  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2127  context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2128  getWaitOperandsMutable(), segments));
2129  setWaitOperandsSegments(segments);
2130 
2132  if (getHasWaitDevnumAttr())
2133  llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2134  hasDevnums.insert(
2135  hasDevnums.end(),
2136  std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2137  mlir::BoolAttr::get(context, hasDevnum));
2138  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2139 }
2140 
2141 //===----------------------------------------------------------------------===//
2142 // HostDataOp
2143 //===----------------------------------------------------------------------===//
2144 
2145 LogicalResult acc::HostDataOp::verify() {
2146  if (getDataClauseOperands().empty())
2147  return emitError("at least one operand must appear on the host_data "
2148  "operation");
2149 
2150  for (mlir::Value operand : getDataClauseOperands())
2151  if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2152  return emitError("expect data entry operation as defining op");
2153  return success();
2154 }
2155 
2156 void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2157  MLIRContext *context) {
2158  results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2159 }
2160 
2161 //===----------------------------------------------------------------------===//
2162 // LoopOp
2163 //===----------------------------------------------------------------------===//
2164 
2165 static ParseResult parseGangValue(
2166  OpAsmParser &parser, llvm::StringRef keyword,
2169  llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
2170  bool &needCommaBetweenValues, bool &newValue) {
2171  if (succeeded(parser.parseOptionalKeyword(keyword))) {
2172  if (parser.parseEqual())
2173  return failure();
2174  if (parser.parseOperand(operands.emplace_back()) ||
2175  parser.parseColonType(types.emplace_back()))
2176  return failure();
2177  attributes.push_back(gangArgType);
2178  needCommaBetweenValues = true;
2179  newValue = true;
2180  }
2181  return success();
2182 }
2183 
2184 static ParseResult parseGangClause(
2185  OpAsmParser &parser,
2187  llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
2188  mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
2189  mlir::ArrayAttr &gangOnlyDeviceType) {
2190  llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
2191  llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
2192  llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
2194  bool needCommaBetweenValues = false;
2195  bool needCommaBeforeOperands = false;
2196 
2197  if (failed(parser.parseOptionalLParen())) {
2198  // Gang only keyword
2199  gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2201  gangOnlyDeviceType =
2202  ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
2203  return success();
2204  }
2205 
2206  // Parse gang only attributes
2207  if (succeeded(parser.parseOptionalLSquare())) {
2208  // Parse gang only attributes
2209  if (failed(parser.parseCommaSeparatedList([&]() {
2210  if (parser.parseAttribute(
2211  gangOnlyDeviceTypeAttributes.emplace_back()))
2212  return failure();
2213  return success();
2214  })))
2215  return failure();
2216  if (parser.parseRSquare())
2217  return failure();
2218  needCommaBeforeOperands = true;
2219  }
2220 
2221  auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2222  mlir::acc::GangArgType::Num);
2223  auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2224  mlir::acc::GangArgType::Dim);
2225  auto argStatic = mlir::acc::GangArgTypeAttr::get(
2226  parser.getContext(), mlir::acc::GangArgType::Static);
2227 
2228  do {
2229  if (needCommaBeforeOperands) {
2230  needCommaBeforeOperands = false;
2231  continue;
2232  }
2233 
2234  if (failed(parser.parseLBrace()))
2235  return failure();
2236 
2237  int32_t crtOperandsSize = gangOperands.size();
2238  while (true) {
2239  bool newValue = false;
2240  bool needValue = false;
2241  if (needCommaBetweenValues) {
2242  if (succeeded(parser.parseOptionalComma()))
2243  needValue = true; // expect a new value after comma.
2244  else
2245  break;
2246  }
2247 
2248  if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
2249  gangOperands, gangOperandsType,
2250  gangArgTypeAttributes, argNum,
2251  needCommaBetweenValues, newValue)))
2252  return failure();
2253  if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
2254  gangOperands, gangOperandsType,
2255  gangArgTypeAttributes, argDim,
2256  needCommaBetweenValues, newValue)))
2257  return failure();
2258  if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
2259  gangOperands, gangOperandsType,
2260  gangArgTypeAttributes, argStatic,
2261  needCommaBetweenValues, newValue)))
2262  return failure();
2263 
2264  if (!newValue && needValue) {
2265  parser.emitError(parser.getCurrentLocation(),
2266  "new value expected after comma");
2267  return failure();
2268  }
2269 
2270  if (!newValue)
2271  break;
2272  }
2273 
2274  if (gangOperands.empty())
2275  return parser.emitError(
2276  parser.getCurrentLocation(),
2277  "expect at least one of num, dim or static values");
2278 
2279  if (failed(parser.parseRBrace()))
2280  return failure();
2281 
2282  if (succeeded(parser.parseOptionalLSquare())) {
2283  if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
2284  parser.parseRSquare())
2285  return failure();
2286  } else {
2287  deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2289  }
2290 
2291  seg.push_back(gangOperands.size() - crtOperandsSize);
2292 
2293  } while (succeeded(parser.parseOptionalComma()));
2294 
2295  if (failed(parser.parseRParen()))
2296  return failure();
2297 
2298  llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
2299  gangArgTypeAttributes.end());
2300  gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
2301  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
2302 
2304  gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
2305  gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
2306 
2307  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2308  return success();
2309 }
2310 
2312  mlir::OperandRange operands, mlir::TypeRange types,
2313  std::optional<mlir::ArrayAttr> gangArgTypes,
2314  std::optional<mlir::ArrayAttr> deviceTypes,
2315  std::optional<mlir::DenseI32ArrayAttr> segments,
2316  std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
2317 
2318  if (operands.begin() == operands.end() &&
2319  hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
2320  return;
2321  }
2322 
2323  p << "(";
2324 
2325  printDeviceTypes(p, gangOnlyDeviceTypes);
2326 
2327  if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
2328  hasDeviceTypeValues(deviceTypes))
2329  p << ", ";
2330 
2331  if (hasDeviceTypeValues(deviceTypes)) {
2332  unsigned opIdx = 0;
2333  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2334  p << "{";
2335  llvm::interleaveComma(
2336  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2337  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2338  (*gangArgTypes)[opIdx]);
2339  if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
2340  p << LoopOp::getGangNumKeyword();
2341  else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
2342  p << LoopOp::getGangDimKeyword();
2343  else if (gangArgTypeAttr.getValue() ==
2344  mlir::acc::GangArgType::Static)
2345  p << LoopOp::getGangStaticKeyword();
2346  p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
2347  ++opIdx;
2348  });
2349  p << "}";
2350  printSingleDeviceType(p, it.value());
2351  });
2352  }
2353  p << ")";
2354 }
2355 
2357  std::optional<mlir::ArrayAttr> segments,
2358  llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2359  if (!segments)
2360  return false;
2361  for (auto attr : *segments) {
2362  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2363  if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2364  return true;
2365  }
2366  return false;
2367 }
2368 
2369 /// Check for duplicates in the DeviceType array attribute.
2370 LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
2371  llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2372  if (!deviceTypes)
2373  return success();
2374  for (auto attr : deviceTypes) {
2375  auto deviceTypeAttr =
2376  mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
2377  if (!deviceTypeAttr)
2378  return failure();
2379  if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
2380  return failure();
2381  }
2382  return success();
2383 }
2384 
2385 LogicalResult acc::LoopOp::verify() {
2386  if (getUpperbound().size() != getStep().size())
2387  return emitError() << "number of upperbounds expected to be the same as "
2388  "number of steps";
2389 
2390  if (getUpperbound().size() != getLowerbound().size())
2391  return emitError() << "number of upperbounds expected to be the same as "
2392  "number of lowerbounds";
2393 
2394  if (!getUpperbound().empty() && getInclusiveUpperbound() &&
2395  (getUpperbound().size() != getInclusiveUpperbound()->size()))
2396  return emitError() << "inclusiveUpperbound size is expected to be the same"
2397  << " as upperbound size";
2398 
2399  // Check collapse
2400  if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
2401  return emitOpError() << "collapse device_type attr must be define when"
2402  << " collapse attr is present";
2403 
2404  if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
2405  getCollapseAttr().getValue().size() !=
2406  getCollapseDeviceTypeAttr().getValue().size())
2407  return emitOpError() << "collapse attribute count must match collapse"
2408  << " device_type count";
2409  if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
2410  return emitOpError()
2411  << "duplicate device_type found in collapseDeviceType attribute";
2412 
2413  // Check gang
2414  if (!getGangOperands().empty()) {
2415  if (!getGangOperandsArgType())
2416  return emitOpError() << "gangOperandsArgType attribute must be defined"
2417  << " when gang operands are present";
2418 
2419  if (getGangOperands().size() !=
2420  getGangOperandsArgTypeAttr().getValue().size())
2421  return emitOpError() << "gangOperandsArgType attribute count must match"
2422  << " gangOperands count";
2423  }
2424  if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
2425  return emitOpError() << "duplicate device_type found in gang attribute";
2426 
2428  *this, getGangOperands(), getGangOperandsSegmentsAttr(),
2429  getGangOperandsDeviceTypeAttr(), "gang")))
2430  return failure();
2431 
2432  // Check worker
2433  if (failed(checkDeviceTypes(getWorkerAttr())))
2434  return emitOpError() << "duplicate device_type found in worker attribute";
2435  if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
2436  return emitOpError() << "duplicate device_type found in "
2437  "workerNumOperandsDeviceType attribute";
2438  if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
2439  getWorkerNumOperandsDeviceTypeAttr(),
2440  "worker")))
2441  return failure();
2442 
2443  // Check vector
2444  if (failed(checkDeviceTypes(getVectorAttr())))
2445  return emitOpError() << "duplicate device_type found in vector attribute";
2446  if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
2447  return emitOpError() << "duplicate device_type found in "
2448  "vectorOperandsDeviceType attribute";
2449  if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
2450  getVectorOperandsDeviceTypeAttr(),
2451  "vector")))
2452  return failure();
2453 
2455  *this, getTileOperands(), getTileOperandsSegmentsAttr(),
2456  getTileOperandsDeviceTypeAttr(), "tile")))
2457  return failure();
2458 
2459  // auto, independent and seq attribute are mutually exclusive.
2460  llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
2461  if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
2462  hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
2463  hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
2464  return emitError() << "only one of \"" << acc::LoopOp::getAutoAttrStrName()
2465  << "\", " << getIndependentAttrName() << ", "
2466  << getSeqAttrName()
2467  << " can be present at the same time";
2468  }
2469 
2470  // Gang, worker and vector are incompatible with seq.
2471  if (getSeqAttr()) {
2472  for (auto attr : getSeqAttr()) {
2473  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2474  if (hasVector(deviceTypeAttr.getValue()) ||
2475  getVectorValue(deviceTypeAttr.getValue()) ||
2476  hasWorker(deviceTypeAttr.getValue()) ||
2477  getWorkerValue(deviceTypeAttr.getValue()) ||
2478  hasGang(deviceTypeAttr.getValue()) ||
2479  getGangValue(mlir::acc::GangArgType::Num,
2480  deviceTypeAttr.getValue()) ||
2481  getGangValue(mlir::acc::GangArgType::Dim,
2482  deviceTypeAttr.getValue()) ||
2483  getGangValue(mlir::acc::GangArgType::Static,
2484  deviceTypeAttr.getValue()))
2485  return emitError()
2486  << "gang, worker or vector cannot appear with the seq attr";
2487  }
2488  }
2489 
2490  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
2491  *this, getPrivatizations(), getPrivateOperands(), "private",
2492  "privatizations", false)))
2493  return failure();
2494 
2495  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
2496  *this, getReductionRecipes(), getReductionOperands(), "reduction",
2497  "reductions", false)))
2498  return failure();
2499 
2500  if (getCombined().has_value() &&
2501  (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
2502  getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
2503  getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
2504  return emitError("unexpected combined constructs attribute");
2505  }
2506 
2507  // Check non-empty body().
2508  if (getRegion().empty())
2509  return emitError("expected non-empty body.");
2510 
2511  // When it is container-like - it is expected to hold a loop-like operation.
2512  if (isContainerLike()) {
2513  // Obtain the maximum collapse count - we use this to check that there
2514  // are enough loops contained.
2515  uint64_t collapseCount = getCollapseValue().value_or(1);
2516  if (getCollapseAttr()) {
2517  for (auto collapseEntry : getCollapseAttr()) {
2518  auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
2519  if (intAttr.getValue().getZExtValue() > collapseCount)
2520  collapseCount = intAttr.getValue().getZExtValue();
2521  }
2522  }
2523 
2524  // We want to check that we find enough loop-like operations inside.
2525  // PreOrder walk allows us to walk in a breadth-first manner at each nesting
2526  // level.
2527  mlir::Operation *expectedParent = this->getOperation();
2528  bool foundSibling = false;
2529  getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
2530  if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
2531  // This effectively checks that we are not looking at a sibling loop.
2532  if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
2533  expectedParent) {
2534  foundSibling = true;
2535  return mlir::WalkResult::interrupt();
2536  }
2537 
2538  collapseCount--;
2539  expectedParent = op;
2540  }
2541  // We found enough contained loops.
2542  if (collapseCount == 0)
2543  return mlir::WalkResult::interrupt();
2544  return mlir::WalkResult::advance();
2545  });
2546 
2547  if (foundSibling)
2548  return emitError("found sibling loops inside container-like acc.loop");
2549  if (collapseCount != 0)
2550  return emitError("failed to find enough loop-like operations inside "
2551  "container-like acc.loop");
2552  }
2553 
2554  return success();
2555 }
2556 
2557 unsigned LoopOp::getNumDataOperands() {
2558  return getReductionOperands().size() + getPrivateOperands().size();
2559 }
2560 
2561 Value LoopOp::getDataOperand(unsigned i) {
2562  unsigned numOptional =
2563  getLowerbound().size() + getUpperbound().size() + getStep().size();
2564  numOptional += getGangOperands().size();
2565  numOptional += getVectorOperands().size();
2566  numOptional += getWorkerNumOperands().size();
2567  numOptional += getTileOperands().size();
2568  numOptional += getCacheOperands().size();
2569  return getOperand(numOptional + i);
2570 }
2571 
2572 bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
2573 
2574 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
2575  return hasDeviceType(getAuto_(), deviceType);
2576 }
2577 
2578 bool LoopOp::hasIndependent() {
2579  return hasIndependent(mlir::acc::DeviceType::None);
2580 }
2581 
2582 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
2583  return hasDeviceType(getIndependent(), deviceType);
2584 }
2585 
2586 bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
2587 
2588 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
2589  return hasDeviceType(getSeq(), deviceType);
2590 }
2591 
2592 mlir::Value LoopOp::getVectorValue() {
2593  return getVectorValue(mlir::acc::DeviceType::None);
2594 }
2595 
2596 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
2597  return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
2598  getVectorOperands(), deviceType);
2599 }
2600 
2601 bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
2602 
2603 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
2604  return hasDeviceType(getVector(), deviceType);
2605 }
2606 
2607 mlir::Value LoopOp::getWorkerValue() {
2608  return getWorkerValue(mlir::acc::DeviceType::None);
2609 }
2610 
2611 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
2612  return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
2613  getWorkerNumOperands(), deviceType);
2614 }
2615 
2616 bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2617 
2618 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2619  return hasDeviceType(getWorker(), deviceType);
2620 }
2621 
2622 mlir::Operation::operand_range LoopOp::getTileValues() {
2623  return getTileValues(mlir::acc::DeviceType::None);
2624 }
2625 
2627 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
2628  return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
2629  getTileOperandsSegments(), deviceType);
2630 }
2631 
2632 std::optional<int64_t> LoopOp::getCollapseValue() {
2633  return getCollapseValue(mlir::acc::DeviceType::None);
2634 }
2635 
2636 std::optional<int64_t>
2637 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2638  if (!getCollapseAttr())
2639  return std::nullopt;
2640  if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2641  auto intAttr =
2642  mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2643  return intAttr.getValue().getZExtValue();
2644  }
2645  return std::nullopt;
2646 }
2647 
2648 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2649  return getGangValue(gangArgType, mlir::acc::DeviceType::None);
2650 }
2651 
/// Returns the gang operand of kind `gangArgType` recorded for `deviceType`.
/// Returns a null Value in any of these cases: there are no gang operands,
/// no segment exists for `deviceType`, or no operand in that segment has the
/// requested kind.
mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
                                 mlir::acc::DeviceType deviceType) {
  if (getGangOperands().empty())
    return {};
  if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
    // Sum the sizes of all earlier segments to find where this device type's
    // operands start. `values` is this device type's slice of the gang
    // operands.
    int32_t nbOperandsBefore = 0;
    for (unsigned i = 0; i < *pos; ++i)
      nbOperandsBefore += (*getGangOperandsSegments())[i];
        getGangOperands()
            .drop_front(nbOperandsBefore)
            .take_front((*getGangOperandsSegments())[*pos]);

    // The gang arg-type array is indexed in parallel with the flat gang
    // operand list, so start at the same offset.
    int32_t argTypeIdx = nbOperandsBefore;
    for (auto value : values) {
      // NOTE(review): the dyn_cast result is used without a null check; a
      // checked `cast` would state the invariant explicitly.
      auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
          (*getGangOperandsArgType())[argTypeIdx]);
      if (gangArgTypeAttr.getValue() == gangArgType)
        return value;
      ++argTypeIdx;
    }
  }
  return {};
}
2676 
2677 bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2678 
2679 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2680  return hasDeviceType(getGang(), deviceType);
2681 }
2682 
2683 llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
2684  return {&getRegion()};
2685 }
2686 
/// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
/// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
/// `(` ssa-id-and-type-list `)`
/// region
///
/// The `control(...)` prefix is optional. When it is absent, only the region
/// is parsed, with no induction variables. When it is present, the lower
/// bound, upper bound and step lists must each contain exactly one operand
/// per induction variable. The parsed induction variables become the region's
/// entry block arguments.
ParseResult
    SmallVectorImpl<Type> &lowerboundType,
    SmallVectorImpl<Type> &upperboundType,
    SmallVectorImpl<Type> &stepType) {

  SmallVector<OpAsmParser::Argument> inductionVars;
  if (succeeded(
          parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
    // Induction variables are parsed with their types; the bound and step
    // lists then require exactly inductionVars.size() operands each.
    if (parser.parseLParen() ||
        parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
                                 /*allowType=*/true) ||
        parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
        parser.parseOperandList(lowerbound, inductionVars.size(),
        parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
        parser.parseKeyword("to") || parser.parseLParen() ||
        parser.parseOperandList(upperbound, inductionVars.size(),
        parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
        parser.parseKeyword("step") || parser.parseLParen() ||
        parser.parseOperandList(step, inductionVars.size(),
        parser.parseColonTypeList(stepType) || parser.parseRParen())
      return failure();
  }
  return parser.parseRegion(region, inductionVars);
}
2722 
    ValueRange lowerbound, TypeRange lowerboundType,
    ValueRange upperbound, TypeRange upperboundType,
    ValueRange steps, TypeRange stepType) {
  // Prints the form accepted by parseLoopControl. The `control(...)` prefix
  // is emitted only when the region's entry block has arguments (the
  // induction variables). The region is then printed without entry block
  // arguments, because they were already printed inside `control(...)`.
  ValueRange regionArgs = region.front().getArguments();
  if (!regionArgs.empty()) {
    p << acc::LoopOp::getControlKeyword() << "(";
    llvm::interleaveComma(regionArgs, p,
                          [&p](Value v) { p << v << " : " << v.getType(); });
    // NOTE(review): `") "` followed by `" step ("` emits two spaces before
    // `step`. This looks unintended but is harmless to the parser.
    p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
      << upperbound << " : " << upperboundType << ") " << " step (" << steps
      << " : " << stepType << ") ";
  }
  p.printRegion(region, /*printEntryBlockArgs=*/false);
}
2738 
2739 void acc::LoopOp::addSeq(MLIRContext *context,
2740  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2741  setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
2742  effectiveDeviceTypes));
2743 }
2744 
2745 void acc::LoopOp::addIndependent(
2746  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2747  setIndependentAttr(addDeviceTypeAffectedOperandHelper(
2748  context, getIndependentAttr(), effectiveDeviceTypes));
2749 }
2750 
2751 void acc::LoopOp::addAuto(MLIRContext *context,
2752  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2753  setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
2754  effectiveDeviceTypes));
2755 }
2756 
/// Appends a `collapse` entry with `value` for each device type in
/// `effectiveDeviceTypes`, or a single `none` entry when that list is empty.
/// Existing entries are preserved. The collapse value array and the
/// collapse device-type array are rebuilt together so they stay parallel.
/// `value` must be 64 bits wide (asserted).
void acc::LoopOp::setCollapseForDeviceTypes(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
    llvm::APInt value) {
  llvm::SmallVector<mlir::Attribute> newDeviceTypes;

  // The two arrays are either both present or both absent.
  assert((getCollapseAttr() == nullptr) ==
         (getCollapseDeviceTypeAttr() == nullptr));
  assert(value.getBitWidth() == 64);

  // Copy the existing (value, device type) pairs; zip_equal requires the two
  // arrays to have the same length.
  if (getCollapseAttr()) {
    for (const auto &existing :
         llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
      newValues.push_back(std::get<0>(existing));
      newDeviceTypes.push_back(std::get<1>(existing));
    }
  }

  if (effectiveDeviceTypes.empty()) {
    // An empty effective device-type list means no device_type clause has
    // been applied yet, so the value is recorded for 'none'.
    newValues.push_back(
        mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
    newDeviceTypes.push_back(
  } else {
    for (DeviceType DT : effectiveDeviceTypes) {
      newValues.push_back(
          mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
      newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, DT));
    }
  }

  setCollapseAttr(ArrayAttr::get(context, newValues));
  setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
}
2793 
2794 void acc::LoopOp::setTileForDeviceTypes(
2795  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
2796  ValueRange values) {
2797  llvm::SmallVector<int32_t> segments;
2798  if (getTileOperandsSegments())
2799  llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
2800 
2801  setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2802  context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
2803  getTileOperandsMutable(), segments));
2804 
2805  setTileOperandsSegments(segments);
2806 }
2807 
2808 void acc::LoopOp::addVectorOperand(
2809  MLIRContext *context, mlir::Value newValue,
2810  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2811  setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2812  context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
2813  newValue, getVectorOperandsMutable()));
2814 }
2815 
2816 void acc::LoopOp::addEmptyVector(
2817  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2818  setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
2819  effectiveDeviceTypes));
2820 }
2821 
2822 void acc::LoopOp::addWorkerNumOperand(
2823  MLIRContext *context, mlir::Value newValue,
2824  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2825  setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2826  context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
2827  newValue, getWorkerNumOperandsMutable()));
2828 }
2829 
2830 void acc::LoopOp::addEmptyWorker(
2831  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2832  setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
2833  effectiveDeviceTypes));
2834 }
2835 
2836 void acc::LoopOp::addEmptyGang(
2837  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2838  setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
2839  effectiveDeviceTypes));
2840 }
2841 
/// Adds gang operands `values`, whose kinds are given by `argTypes`, for
/// `effectiveDeviceTypes`. Three pieces of state are updated: the gang
/// device-type array, the per-device-type segment sizes, and the flat gang
/// arg-type array.
/// NOTE(review): getGangValue indexes the arg-type array in parallel with the
/// gang operands, so this presumably requires argTypes.size() ==
/// values.size(). TODO confirm with callers.
void acc::LoopOp::addGangOperands(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
    llvm::ArrayRef<GangArgType> argTypes, mlir::ValueRange values) {
  llvm::SmallVector<int32_t> segments;
  if (std::optional<ArrayRef<int32_t>> existingSegments =
          getGangOperandsSegments())
    llvm::copy(*existingSegments, std::back_inserter(segments));

  // Remember the segment count so we can tell how many segments the helper
  // appends.
  unsigned beforeCount = segments.size();

  setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
      getGangOperandsMutable(), segments));

  setGangOperandsSegments(segments);

  // This is a bit of extra work to make sure we update the 'types' correctly by
  // adding to the types collection the correct number of times. We could
  // potentially add something similar to the
  // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
  // excessive for a one-off case.
  unsigned numAdded = segments.size() - beforeCount;

  if (numAdded > 0) {
    if (getGangOperandsArgTypeAttr())
      llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));

    // Each newly added segment receives its own copy of `argTypes`.
    for (auto i : llvm::index_range(0u, numAdded)) {
      llvm::transform(argTypes, std::back_inserter(gangTypes),
                      [=](mlir::acc::GangArgType gangTy) {
                        return mlir::acc::GangArgTypeAttr::get(context, gangTy);
                      });
      (void)i;
    }

    setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
  }
}
2881 
2882 //===----------------------------------------------------------------------===//
2883 // DataOp
2884 //===----------------------------------------------------------------------===//
2885 
2886 LogicalResult acc::DataOp::verify() {
2887  // 2.6.5. Data Construct restriction
2888  // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
2889  // attach, or default clause must appear on a data construct.
2890  if (getOperands().empty() && !getDefaultAttr())
2891  return emitError("at least one operand or the default attribute "
2892  "must appear on the data operation");
2893 
2894  for (mlir::Value operand : getDataClauseOperands())
2895  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2896  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2897  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2898  operand.getDefiningOp()))
2899  return emitError("expect data entry/exit operation or acc.getdeviceptr "
2900  "as defining op");
2901 
2902  if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this)))
2903  return failure();
2904 
2905  return success();
2906 }
2907 
2908 unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
2909 
2910 Value DataOp::getDataOperand(unsigned i) {
2911  unsigned numOptional = getIfCond() ? 1 : 0;
2912  numOptional += getAsyncOperands().size() ? 1 : 0;
2913  numOptional += getWaitOperands().size();
2914  return getOperand(numOptional + i);
2915 }
2916 
2917 bool acc::DataOp::hasAsyncOnly() {
2918  return hasAsyncOnly(mlir::acc::DeviceType::None);
2919 }
2920 
2921 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2922  return hasDeviceType(getAsyncOnly(), deviceType);
2923 }
2924 
2925 mlir::Value DataOp::getAsyncValue() {
2926  return getAsyncValue(mlir::acc::DeviceType::None);
2927 }
2928 
/// Returns the async operand associated with `deviceType`, selected from the
/// async operands by a device-type segment lookup.
mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
                                     getAsyncOperands(), deviceType);
}
2933 
2934 bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
2935 
2936 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2937  return hasDeviceType(getWaitOnly(), deviceType);
2938 }
2939 
2940 mlir::Operation::operand_range DataOp::getWaitValues() {
2941  return getWaitValues(mlir::acc::DeviceType::None);
2942 }
2943 
DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  // Selects the wait operands recorded for `deviceType`. The has-devnum flags
  // are also passed, presumably so a leading devnum operand in the segment
  // can be excluded. TODO confirm against the helper.
      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
      getHasWaitDevnum(), deviceType);
}
2950 
2951 mlir::Value DataOp::getWaitDevnum() {
2952  return getWaitDevnum(mlir::acc::DeviceType::None);
2953 }
2954 
2955 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2956  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2957  getWaitOperandsSegments(), getHasWaitDevnum(),
2958  deviceType);
2959 }
2960 
2961 void acc::DataOp::addAsyncOnly(
2962  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2963  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2964  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2965 }
2966 
2967 void acc::DataOp::addAsyncOperand(
2968  MLIRContext *context, mlir::Value newValue,
2969  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2970  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2971  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2972  getAsyncOperandsMutable()));
2973 }
2974 
2975 void acc::DataOp::addWaitOnly(MLIRContext *context,
2976  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2977  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2978  effectiveDeviceTypes));
2979 }
2980 
/// Adds `newValues` as wait operands for `effectiveDeviceTypes`. Updates the
/// wait device-type array, the segment sizes and the has-devnum flags.
/// `hasDevnum` is recorded once per affected device type.
void acc::DataOp::addWaitOperands(
    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {

  llvm::SmallVector<int32_t> segments;
  if (getWaitOperandsSegments())
    llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));

  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
      getWaitOperandsMutable(), segments));
  setWaitOperandsSegments(segments);

  // Append one has-devnum flag per effective device type, and at least one
  // even when effectiveDeviceTypes is empty.
  if (getHasWaitDevnumAttr())
    llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
  hasDevnums.insert(
      hasDevnums.end(),
      std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
      mlir::BoolAttr::get(context, hasDevnum));
  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
}
3003 
3004 //===----------------------------------------------------------------------===//
3005 // ExitDataOp
3006 //===----------------------------------------------------------------------===//
3007 
3008 LogicalResult acc::ExitDataOp::verify() {
3009  // 2.6.6. Data Exit Directive restriction
3010  // At least one copyout, delete, or detach clause must appear on an exit data
3011  // directive.
3012  if (getDataClauseOperands().empty())
3013  return emitError("at least one operand must be present in dataOperands on "
3014  "the exit data operation");
3015 
3016  // The async attribute represent the async clause without value. Therefore the
3017  // attribute and operand cannot appear at the same time.
3018  if (getAsyncOperand() && getAsync())
3019  return emitError("async attribute cannot appear with asyncOperand");
3020 
3021  // The wait attribute represent the wait clause without values. Therefore the
3022  // attribute and operands cannot appear at the same time.
3023  if (!getWaitOperands().empty() && getWait())
3024  return emitError("wait attribute cannot appear with waitOperands");
3025 
3026  if (getWaitDevnum() && getWaitOperands().empty())
3027  return emitError("wait_devnum cannot appear without waitOperands");
3028 
3029  return success();
3030 }
3031 
3032 unsigned ExitDataOp::getNumDataOperands() {
3033  return getDataClauseOperands().size();
3034 }
3035 
3036 Value ExitDataOp::getDataOperand(unsigned i) {
3037  unsigned numOptional = getIfCond() ? 1 : 0;
3038  numOptional += getAsyncOperand() ? 1 : 0;
3039  numOptional += getWaitDevnum() ? 1 : 0;
3040  return getOperand(getWaitOperands().size() + numOptional + i);
3041 }
3042 
3043 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3044  MLIRContext *context) {
3045  results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
3046 }
3047 
3048 //===----------------------------------------------------------------------===//
3049 // EnterDataOp
3050 //===----------------------------------------------------------------------===//
3051 
3052 LogicalResult acc::EnterDataOp::verify() {
3053  // 2.6.6. Data Enter Directive restriction
3054  // At least one copyin, create, or attach clause must appear on an enter data
3055  // directive.
3056  if (getDataClauseOperands().empty())
3057  return emitError("at least one operand must be present in dataOperands on "
3058  "the enter data operation");
3059 
3060  // The async attribute represent the async clause without value. Therefore the
3061  // attribute and operand cannot appear at the same time.
3062  if (getAsyncOperand() && getAsync())
3063  return emitError("async attribute cannot appear with asyncOperand");
3064 
3065  // The wait attribute represent the wait clause without values. Therefore the
3066  // attribute and operands cannot appear at the same time.
3067  if (!getWaitOperands().empty() && getWait())
3068  return emitError("wait attribute cannot appear with waitOperands");
3069 
3070  if (getWaitDevnum() && getWaitOperands().empty())
3071  return emitError("wait_devnum cannot appear without waitOperands");
3072 
3073  for (mlir::Value operand : getDataClauseOperands())
3074  if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3075  operand.getDefiningOp()))
3076  return emitError("expect data entry operation as defining op");
3077 
3078  return success();
3079 }
3080 
3081 unsigned EnterDataOp::getNumDataOperands() {
3082  return getDataClauseOperands().size();
3083 }
3084 
3085 Value EnterDataOp::getDataOperand(unsigned i) {
3086  unsigned numOptional = getIfCond() ? 1 : 0;
3087  numOptional += getAsyncOperand() ? 1 : 0;
3088  numOptional += getWaitDevnum() ? 1 : 0;
3089  return getOperand(getWaitOperands().size() + numOptional + i);
3090 }
3091 
3092 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3093  MLIRContext *context) {
3094  results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
3095 }
3096 
3097 //===----------------------------------------------------------------------===//
3098 // AtomicReadOp
3099 //===----------------------------------------------------------------------===//
3100 
3101 LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
3102 
3103 //===----------------------------------------------------------------------===//
3104 // AtomicWriteOp
3105 //===----------------------------------------------------------------------===//
3106 
3107 LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
3108 
3109 //===----------------------------------------------------------------------===//
3110 // AtomicUpdateOp
3111 //===----------------------------------------------------------------------===//
3112 
3113 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3114  PatternRewriter &rewriter) {
3115  if (op.isNoOp()) {
3116  rewriter.eraseOp(op);
3117  return success();
3118  }
3119 
3120  if (Value writeVal = op.getWriteOpVal()) {
3121  rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
3122  return success();
3123  }
3124 
3125  return failure();
3126 }
3127 
3128 LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
3129 
3130 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3131 
3132 //===----------------------------------------------------------------------===//
3133 // AtomicCaptureOp
3134 //===----------------------------------------------------------------------===//
3135 
3136 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3137  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3138  return op;
3139  return dyn_cast<AtomicReadOp>(getSecondOp());
3140 }
3141 
3142 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3143  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3144  return op;
3145  return dyn_cast<AtomicWriteOp>(getSecondOp());
3146 }
3147 
3148 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3149  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3150  return op;
3151  return dyn_cast<AtomicUpdateOp>(getSecondOp());
3152 }
3153 
3154 LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
3155 
3156 //===----------------------------------------------------------------------===//
3157 // DeclareEnterOp
3158 //===----------------------------------------------------------------------===//
3159 
/// Shared verifier for the acc.declare* ops. It emits an error if `operands`
/// is empty while `requireAtLeastOneOperand` is set, or if any operand is not
/// produced by one of the listed declare data entry ops or acc.getdeviceptr.
/// It also asserts that each defining op exposes a var and a data clause.
/// NOTE(review): `mlir::isa` asserts when getDefiningOp() is null (block
/// argument operand); `isa_and_nonnull` would turn that into a diagnostic.
template <typename Op>
static LogicalResult
                     bool requireAtLeastOneOperand = true) {
  if (operands.empty() && requireAtLeastOneOperand)
    return emitError(
        op->getLoc(),
        "at least one operand must appear on the declare operation");

  for (mlir::Value operand : operands) {
    if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
                   acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
                   acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
            operand.getDefiningOp()))
      return op.emitError(
          "expect valid declare data entry operation or acc.getdeviceptr "
          "as defining op");

    // The remaining checks are debug-only invariants; the (void) casts
    // silence unused-variable warnings when asserts are compiled out.
    mlir::Value var{getVar(operand.getDefiningOp())};
    assert(var && "declare operands can only be data entry operations which "
                  "must have var");
    (void)var;
    std::optional<mlir::acc::DataClause> dataClauseOptional{
        getDataClause(operand.getDefiningOp())};
    assert(dataClauseOptional.has_value() &&
           "declare operands can only be data entry operations which must have "
           "dataClause");
    (void)dataClauseOptional;
  }

  return success();
}
3192 
3193 LogicalResult acc::DeclareEnterOp::verify() {
3194  return checkDeclareOperands(*this, this->getDataClauseOperands());
3195 }
3196 
3197 //===----------------------------------------------------------------------===//
3198 // DeclareExitOp
3199 //===----------------------------------------------------------------------===//
3200 
3201 LogicalResult acc::DeclareExitOp::verify() {
3202  if (getToken())
3203  return checkDeclareOperands(*this, this->getDataClauseOperands(),
3204  /*requireAtLeastOneOperand=*/false);
3205  return checkDeclareOperands(*this, this->getDataClauseOperands());
3206 }
3207 
3208 //===----------------------------------------------------------------------===//
3209 // DeclareOp
3210 //===----------------------------------------------------------------------===//
3211 
3212 LogicalResult acc::DeclareOp::verify() {
3213  return checkDeclareOperands(*this, this->getDataClauseOperands());
3214 }
3215 
3216 //===----------------------------------------------------------------------===//
3217 // RoutineOp
3218 //===----------------------------------------------------------------------===//
3219 
3220 static unsigned getParallelismForDeviceType(acc::RoutineOp op,
3221  acc::DeviceType dtype) {
3222  unsigned parallelism = 0;
3223  parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
3224  parallelism += op.hasWorker(dtype) ? 1 : 0;
3225  parallelism += op.hasVector(dtype) ? 1 : 0;
3226  parallelism += op.hasSeq(dtype) ? 1 : 0;
3227  return parallelism;
3228 }
3229 
/// Verifies that at most one of `gang`, `worker`, `vector`, `seq` applies to
/// the default (DeviceType::None) case and to each individual device type.
/// It also rejects a device-specific clause combined with a default one.
LogicalResult acc::RoutineOp::verify() {
  unsigned baseParallelism =

  if (baseParallelism > 1)
    return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
                          "be present at the same time";

  // NOTE(review): if getMaxEnumValForDeviceType() returns the largest
  // enumerant value (as its name suggests), `!=` stops before visiting that
  // last device type, and `<=` may be intended. Confirm.
  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
       ++dtypeInt) {
    auto dtype = static_cast<acc::DeviceType>(dtypeInt);
    if (dtype == acc::DeviceType::None)
      continue;
    unsigned parallelism = getParallelismForDeviceType(*this, dtype);

    // Reject two clauses for this device type, or one clause for it when a
    // default clause is already present.
    if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
      return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
                            "be present at the same time";
  }

  return success();
}
3252 
3253 static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
3254  mlir::ArrayAttr &deviceTypes) {
3255  llvm::SmallVector<mlir::Attribute> bindNameAttrs;
3256  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
3257 
3258  if (failed(parser.parseCommaSeparatedList([&]() {
3259  if (parser.parseAttribute(bindNameAttrs.emplace_back()))
3260  return failure();
3261  if (failed(parser.parseOptionalLSquare())) {
3262  deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3263  parser.getContext(), mlir::acc::DeviceType::None));
3264  } else {
3265  if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
3266  parser.parseRSquare())
3267  return failure();
3268  }
3269  return success();
3270  })))
3271  return failure();
3272 
3273  bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
3274  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
3275 
3276  return success();
3277 }
3278 
                          std::optional<mlir::ArrayAttr> bindName,
                          std::optional<mlir::ArrayAttr> deviceTypes) {
  // Inverse of parseBindName: prints each bind name followed by its device
  // type via printSingleDeviceType. Both optionals are dereferenced
  // unconditionally, so presumably they are both set when this is called.
  // llvm::zip stops at the shorter array; parseBindName keeps the two arrays
  // the same length.
  llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
                        [&](const auto &pair) {
                          p << std::get<0>(pair);
                          printSingleDeviceType(p, std::get<1>(pair));
                        });
}
3288 
/// Parses the acc.routine `gang` clause. Accepted forms:
///   gang                                    (keyword only)
///   gang( [dt, ...] , dim: N [dt], ... )    (bracketed list optional)
///   gang( dim: N [dt], ... )
/// A `dim` entry without a `[dt]` suffix is recorded for DeviceType::None.
/// NOTE(review): after a bracketed list a comma and at least one `dim:`
/// entry are required. Confirm this matches what printRoutineGangClause emits
/// when `gang` has device types but there are no gang dims.
static ParseResult parseRoutineGangClause(OpAsmParser &parser,
                                          mlir::ArrayAttr &gang,
                                          mlir::ArrayAttr &gangDim,
                                          mlir::ArrayAttr &gangDimDeviceTypes) {

  llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
      gangDimDeviceTypeAttrs;
  bool needCommaBeforeOperands = false;

  // Gang keyword only
  if (failed(parser.parseOptionalLParen())) {
    gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
    gang = ArrayAttr::get(parser.getContext(), gangAttrs);
    return success();
  }

  // Parse keyword only attributes
  if (succeeded(parser.parseOptionalLSquare())) {
    if (failed(parser.parseCommaSeparatedList([&]() {
          if (parser.parseAttribute(gangAttrs.emplace_back()))
            return failure();
          return success();
        })))
      return failure();
    if (parser.parseRSquare())
      return failure();
    needCommaBeforeOperands = true;
  }

  if (needCommaBeforeOperands && failed(parser.parseComma()))
    return failure();

  // Each `dim: <attr>` entry may carry an optional `[device_type]` suffix.
  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
            parser.parseColon() ||
            parser.parseAttribute(gangDimAttrs.emplace_back()))
          return failure();
        if (succeeded(parser.parseOptionalLSquare())) {
          if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
              parser.parseRSquare())
            return failure();
        } else {
          gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
              parser.getContext(), mlir::acc::DeviceType::None));
        }
        return success();
      })))
    return failure();

  if (failed(parser.parseRParen()))
    return failure();

  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
  gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
  gangDimDeviceTypes =
      ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);

  return success();
}
3349 
                                   std::optional<mlir::ArrayAttr> gang,
                                   std::optional<mlir::ArrayAttr> gangDim,
                                   std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {

  // A plain `gang` (only the DeviceType::None entry and no dims) prints
  // nothing after the keyword.
  if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
      gang->size() == 1) {
    // NOTE(review): the dyn_cast result is used without a null check.
    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
      return;
  }

  p << "(";

  printDeviceTypes(p, gang);

  if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
    p << ", ";

  // Each dim prints as `dim: <value>` followed by its device type.
  if (hasDeviceTypeValues(gangDimDeviceTypes))
    llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
                          [&](const auto &pair) {
                            p << acc::RoutineOp::getGangDimKeyword() << ": ";
                            p << std::get<0>(pair);
                            printSingleDeviceType(p, std::get<1>(pair));
                          });

  p << ")";
}
3379 
/// Parses an optional `([dt, ...])` suffix into `deviceTypes`. With no `(`,
/// a single DeviceType::None entry is recorded.
/// NOTE(review): if `(` is present but not followed by `[`, the closing `)`
/// is never consumed and `deviceTypes` ends up empty. Confirm whether `()`
/// should be accepted here.
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
                                            mlir::ArrayAttr &deviceTypes) {
  // Keyword only
  if (failed(parser.parseOptionalLParen())) {
    attributes.push_back(mlir::acc::DeviceTypeAttr::get(
    deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
    return success();
  }

  // Parse device type attributes
  if (succeeded(parser.parseOptionalLSquare())) {
    if (failed(parser.parseCommaSeparatedList([&]() {
          if (parser.parseAttribute(attributes.emplace_back()))
            return failure();
          return success();
        })))
      return failure();
    if (parser.parseRSquare() || parser.parseRParen())
      return failure();
  }
  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
  return success();
}
3405 
/// Prints `deviceTypes` as `([dt, ...])`. Nothing is printed when the array
/// has no values or holds only the single DeviceType::None entry, which
/// mirrors the keyword-only form accepted by parseDeviceTypeArrayAttr.
static void
                         std::optional<mlir::ArrayAttr> deviceTypes) {

  if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
    // NOTE(review): the dyn_cast result is used without a null check.
    auto deviceTypeAttr =
        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
      return;
  }

  if (!hasDeviceTypeValues(deviceTypes))
    return;

  p << "([";
  llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
    auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
    p << dTypeAttr;
  });
  p << "])";
}
3427 
3428 bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
3429 
3430 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
3431  return hasDeviceType(getWorker(), deviceType);
3432 }
3433 
3434 bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
3435 
3436 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
3437  return hasDeviceType(getVector(), deviceType);
3438 }
3439 
3440 bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
3441 
3442 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
3443  return hasDeviceType(getSeq(), deviceType);
3444 }
3445 
3446 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
3447  return getBindNameValue(mlir::acc::DeviceType::None);
3448 }
3449 
3450 std::optional<llvm::StringRef>
3451 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
3452  if (!hasDeviceTypeValues(getBindNameDeviceType()))
3453  return std::nullopt;
3454  if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
3455  auto attr = (*getBindName())[*pos];
3456  auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
3457  return stringAttr.getValue();
3458  }
3459  return std::nullopt;
3460 }
3461 
3462 bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
3463 
3464 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
3465  return hasDeviceType(getGang(), deviceType);
3466 }
3467 
3468 std::optional<int64_t> RoutineOp::getGangDimValue() {
3469  return getGangDimValue(mlir::acc::DeviceType::None);
3470 }
3471 
3472 std::optional<int64_t>
3473 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
3474  if (!hasDeviceTypeValues(getGangDimDeviceType()))
3475  return std::nullopt;
3476  if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
3477  auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
3478  return intAttr.getInt();
3479  }
3480  return std::nullopt;
3481 }
3482 
3483 //===----------------------------------------------------------------------===//
3484 // InitOp
3485 //===----------------------------------------------------------------------===//
3486 
3487 LogicalResult acc::InitOp::verify() {
3488  Operation *currOp = *this;
3489  while ((currOp = currOp->getParentOp()))
3490  if (isComputeOperation(currOp))
3491  return emitOpError("cannot be nested in a compute operation");
3492  return success();
3493 }
3494 
/// Appends `deviceType` to the op's device-type array, creating the array if
/// it does not exist yet. Existing entries are kept; no deduplication is
/// done.
void acc::InitOp::addDeviceType(MLIRContext *context,
                                mlir::acc::DeviceType deviceType) {
  if (getDeviceTypesAttr())
    llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));

  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
  setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
}
3504 
3505 //===----------------------------------------------------------------------===//
3506 // ShutdownOp
3507 //===----------------------------------------------------------------------===//
3508 
3509 LogicalResult acc::ShutdownOp::verify() {
3510  Operation *currOp = *this;
3511  while ((currOp = currOp->getParentOp()))
3512  if (isComputeOperation(currOp))
3513  return emitOpError("cannot be nested in a compute operation");
3514  return success();
3515 }
3516 
/// Appends `deviceType` to the op's device-type array, creating the array if
/// it does not exist yet. Existing entries are kept; no deduplication is
/// done.
void acc::ShutdownOp::addDeviceType(MLIRContext *context,
                                    mlir::acc::DeviceType deviceType) {
  if (getDeviceTypesAttr())
    llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));

  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
  setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
}
3526 
3527 //===----------------------------------------------------------------------===//
3528 // SetOp
3529 //===----------------------------------------------------------------------===//
3530 
3531 LogicalResult acc::SetOp::verify() {
3532  Operation *currOp = *this;
3533  while ((currOp = currOp->getParentOp()))
3534  if (isComputeOperation(currOp))
3535  return emitOpError("cannot be nested in a compute operation");
3536  if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
3537  return emitOpError("at least one default_async, device_num, or device_type "
3538  "operand must appear");
3539  return success();
3540 }
3541 
3542 //===----------------------------------------------------------------------===//
3543 // UpdateOp
3544 //===----------------------------------------------------------------------===//
3545 
3546 LogicalResult acc::UpdateOp::verify() {
3547  // At least one of host or device should have a value.
3548  if (getDataClauseOperands().empty())
3549  return emitError("at least one value must be present in dataOperands");
3550 
3551  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
3552  getAsyncOperandsDeviceTypeAttr(),
3553  "async")))
3554  return failure();
3555 
3557  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
3558  getWaitOperandsDeviceTypeAttr(), "wait")))
3559  return failure();
3560 
3561  if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this)))
3562  return failure();
3563 
3564  for (mlir::Value operand : getDataClauseOperands())
3565  if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
3566  operand.getDefiningOp()))
3567  return emitError("expect data entry/exit operation or acc.getdeviceptr "
3568  "as defining op");
3569 
3570  return success();
3571 }
3572 
3573 unsigned UpdateOp::getNumDataOperands() {
3574  return getDataClauseOperands().size();
3575 }
3576 
3577 Value UpdateOp::getDataOperand(unsigned i) {
3578  unsigned numOptional = getAsyncOperands().size();
3579  numOptional += getIfCond() ? 1 : 0;
3580  return getOperand(getWaitOperands().size() + numOptional + i);
3581 }
3582 
3583 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
3584  MLIRContext *context) {
3585  results.add<RemoveConstantIfCondition<UpdateOp>>(context);
3586 }
3587 
3588 bool UpdateOp::hasAsyncOnly() {
3589  return hasAsyncOnly(mlir::acc::DeviceType::None);
3590 }
3591 
3592 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3593  return hasDeviceType(getAsyncOnly(), deviceType);
3594 }
3595 
3596 mlir::Value UpdateOp::getAsyncValue() {
3597  return getAsyncValue(mlir::acc::DeviceType::None);
3598 }
3599 
3600 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3602  return {};
3603 
3604  if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
3605  return getAsyncOperands()[*pos];
3606 
3607  return {};
3608 }
3609 
3610 bool UpdateOp::hasWaitOnly() {
3611  return hasWaitOnly(mlir::acc::DeviceType::None);
3612 }
3613 
3614 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3615  return hasDeviceType(getWaitOnly(), deviceType);
3616 }
3617 
3618 mlir::Operation::operand_range UpdateOp::getWaitValues() {
3619  return getWaitValues(mlir::acc::DeviceType::None);
3620 }
3621 
3623 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3625  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3626  getHasWaitDevnum(), deviceType);
3627 }
3628 
3629 mlir::Value UpdateOp::getWaitDevnum() {
3630  return getWaitDevnum(mlir::acc::DeviceType::None);
3631 }
3632 
3633 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3634  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3635  getWaitOperandsSegments(), getHasWaitDevnum(),
3636  deviceType);
3637 }
3638 
3639 //===----------------------------------------------------------------------===//
3640 // WaitOp
3641 //===----------------------------------------------------------------------===//
3642 
3643 LogicalResult acc::WaitOp::verify() {
3644  // The async attribute represent the async clause without value. Therefore the
3645  // attribute and operand cannot appear at the same time.
3646  if (getAsyncOperand() && getAsync())
3647  return emitError("async attribute cannot appear with asyncOperand");
3648 
3649  if (getWaitDevnum() && getWaitOperands().empty())
3650  return emitError("wait_devnum cannot appear without waitOperands");
3651 
3652  return success();
3653 }
3654 
3655 #define GET_OP_CLASSES
3656 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
3657 
3658 #define GET_ATTRDEF_CLASSES
3659 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
3660 
3661 #define GET_TYPEDEF_CLASSES
3662 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
3663 
3664 //===----------------------------------------------------------------------===//
3665 // acc dialect utilities
3666 //===----------------------------------------------------------------------===//
3667 
3670  auto varPtr{llvm::TypeSwitch<mlir::Operation *,
3672  accDataClauseOp)
3673  .Case<ACC_DATA_ENTRY_OPS>(
3674  [&](auto entry) { return entry.getVarPtr(); })
3675  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3676  [&](auto exit) { return exit.getVarPtr(); })
3677  .Default([&](mlir::Operation *) {
3679  })};
3680  return varPtr;
3681 }
3682 
3684  auto varPtr{
3686  .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
3687  .Default([&](mlir::Operation *) { return mlir::Value(); })};
3688  return varPtr;
3689 }
3690 
3692  auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
3693  .Case<ACC_DATA_ENTRY_OPS>(
3694  [&](auto entry) { return entry.getVarType(); })
3695  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3696  [&](auto exit) { return exit.getVarType(); })
3697  .Default([&](mlir::Operation *) { return mlir::Type(); })};
3698  return varType;
3699 }
3700 
3703  auto accPtr{llvm::TypeSwitch<mlir::Operation *,
3705  accDataClauseOp)
3706  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
3707  [&](auto dataClause) { return dataClause.getAccPtr(); })
3708  .Default([&](mlir::Operation *) {
3710  })};
3711  return accPtr;
3712 }
3713 
3715  auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
3717  [&](auto dataClause) { return dataClause.getAccVar(); })
3718  .Default([&](mlir::Operation *) { return mlir::Value(); })};
3719  return accPtr;
3720 }
3721 
3723  auto varPtrPtr{
3725  .Case<ACC_DATA_ENTRY_OPS>(
3726  [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
3727  .Default([&](mlir::Operation *) { return mlir::Value(); })};
3728  return varPtrPtr;
3729 }
3730 
3735  accDataClauseOp)
3736  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3738  dataClause.getBounds().begin(), dataClause.getBounds().end());
3739  })
3740  .Default([&](mlir::Operation *) {
3742  })};
3743  return bounds;
3744 }
3745 
3749  accDataClauseOp)
3750  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3752  dataClause.getAsyncOperands().begin(),
3753  dataClause.getAsyncOperands().end());
3754  })
3755  .Default([&](mlir::Operation *) {
3757  });
3758 }
3759 
3760 mlir::ArrayAttr
3763  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3764  return dataClause.getAsyncOperandsDeviceTypeAttr();
3765  })
3766  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
3767 }
3768 
3769 mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
3772  [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
3773  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
3774 }
3775 
3776 std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
3777  auto name{
3779  .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
3780  .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
3781  return {};
3782  })};
3783  return name;
3784 }
3785 
3786 std::optional<mlir::acc::DataClause>
3788  auto dataClause{
3790  accDataEntryOp)
3791  .Case<ACC_DATA_ENTRY_OPS>(
3792  [&](auto entry) { return entry.getDataClause(); })
3793  .Default([&](mlir::Operation *) { return std::nullopt; })};
3794  return dataClause;
3795 }
3796 
3798  auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
3799  .Case<ACC_DATA_ENTRY_OPS>(
3800  [&](auto entry) { return entry.getImplicit(); })
3801  .Default([&](mlir::Operation *) { return false; })};
3802  return implicit;
3803 }
3804 
3806  auto dataOperands{
3809  [&](auto entry) { return entry.getDataClauseOperands(); })
3810  .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
3811  return dataOperands;
3812 }
3813 
3816  auto dataOperands{
3819  [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
3820  .Default([&](mlir::Operation *) { return nullptr; })};
3821  return dataOperands;
3822 }
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:114
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2253
@ None
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition: OpenACC.cpp:3350
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition: OpenACC.cpp:742
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition: OpenACC.cpp:2356
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition: OpenACC.cpp:1044
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition: OpenACC.cpp:2370
static bool isComputeOperation(Operation *op)
Definition: OpenACC.cpp:756
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition: OpenACC.cpp:1521
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition: OpenACC.cpp:370
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:3253
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition: OpenACC.cpp:339
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition: OpenACC.cpp:1532
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition: OpenACC.cpp:1437
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition: OpenACC.cpp:174
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:3407
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition: OpenACC.cpp:2165
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition: OpenACC.cpp:1772
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition: OpenACC.cpp:3162
static LogicalResult checkVarAndAccVar(Op op)
Definition: OpenACC.cpp:317
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
Definition: OpenACC.cpp:1726
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:194
static LogicalResult checkVarAndVarType(Op op)
Definition: OpenACC.cpp:292
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition: OpenACC.cpp:2692
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition: OpenACC.cpp:973
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:1568
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:1126
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition: OpenACC.cpp:348
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:218
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:1307
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition: OpenACC.cpp:324
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition: OpenACC.cpp:2723
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:3380
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition: OpenACC.cpp:3289
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1420
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:1595
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindName, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:3279
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
Definition: OpenACC.cpp:1711
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:1374
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
Definition: OpenACC.cpp:957
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:250
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
Definition: OpenACC.cpp:1687
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition: OpenACC.cpp:411
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition: OpenACC.cpp:2184
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition: OpenACC.cpp:832
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
Definition: OpenACC.cpp:1756
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition: OpenACC.cpp:1351
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:205
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
Definition: OpenACC.cpp:988
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition: OpenACC.cpp:1668
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:180
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition: OpenACC.cpp:2311
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:234
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition: OpenACC.cpp:1606
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition: OpenACC.cpp:382
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition: OpenACC.cpp:270
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition: OpenACC.cpp:1054
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition: OpenACC.cpp:3220
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1357
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition: OpenACC.cpp:1792
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
Definition: OpenACC.cpp:937
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition: OpenACC.h:67
#define ACC_DATA_ENTRY_OPS
Definition: OpenACC.h:44
#define ACC_DATA_EXIT_OPS
Definition: OpenACC.h:52
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
Definition: Builders.h:204
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:828
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:834
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:128
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:811
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition: OpenACC.cpp:3714
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition: OpenACC.cpp:3683
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition: OpenACC.cpp:3702
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition: OpenACC.cpp:3787
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition: OpenACC.cpp:3815
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition: OpenACC.cpp:3732
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition: OpenACC.cpp:3805
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition: OpenACC.cpp:3776
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition: OpenACC.cpp:3797
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition: OpenACC.cpp:3747
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition: OpenACC.cpp:3722
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:3769
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
Definition: OpenACC.cpp:3691
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition: OpenACC.cpp:3669
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:3761
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:474
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.