MLIR  22.0.0git
OpenACC.cpp
Go to the documentation of this file.
1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
13 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
19 #include "mlir/Support/LLVM.h"
21 #include "llvm/ADT/SmallSet.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/LogicalResult.h"
24 #include <variant>
25 
26 using namespace mlir;
27 using namespace acc;
28 
29 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
30 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
31 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
32 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
33 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
34 
35 namespace {
36 
37 static bool isScalarLikeType(Type type) {
38  return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
39 }
40 
41 struct MemRefPointerLikeModel
42  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
43  MemRefType> {
44  Type getElementType(Type pointer) const {
45  return cast<MemRefType>(pointer).getElementType();
46  }
47  mlir::acc::VariableTypeCategory
48  getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
49  Type varType) const {
50  if (auto mappableTy = dyn_cast<MappableType>(varType)) {
51  return mappableTy.getTypeCategory(varPtr);
52  }
53  auto memrefTy = cast<MemRefType>(pointer);
54  if (!memrefTy.hasRank()) {
55  // This memref is unranked - aka it could have any rank, including a
56  // rank of 0 which could mean scalar. For now, return uncategorized.
57  return mlir::acc::VariableTypeCategory::uncategorized;
58  }
59 
60  if (memrefTy.getRank() == 0) {
61  if (isScalarLikeType(memrefTy.getElementType())) {
62  return mlir::acc::VariableTypeCategory::scalar;
63  }
64  // Zero-rank non-scalar - need further analysis to determine the type
65  // category. For now, return uncategorized.
66  return mlir::acc::VariableTypeCategory::uncategorized;
67  }
68 
69  // It has a rank - must be an array.
70  assert(memrefTy.getRank() > 0 && "rank expected to be positive");
71  return mlir::acc::VariableTypeCategory::array;
72  }
73 };
74 
75 struct LLVMPointerPointerLikeModel
76  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
77  LLVM::LLVMPointerType> {
78  Type getElementType(Type pointer) const { return Type(); }
79 };
80 
81 /// Helper function for any of the times we need to modify an ArrayAttr based on
82 /// a device type list. Returns a new ArrayAttr with all of the
83 /// existingDeviceTypes, plus the effective new ones(or an added none if hte new
84 /// list is empty).
85 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
86  MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
87  llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
89  if (existingDeviceTypes)
90  llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
91 
92  if (newDeviceTypes.empty())
93  deviceTypes.push_back(
95 
96  for (DeviceType DT : newDeviceTypes)
97  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, DT));
98 
99  return mlir::ArrayAttr::get(context, deviceTypes);
100 }
101 
102 /// Helper function for any of the times we need to add operands that are
103 /// affected by a device type list. Returns a new ArrayAttr with all of the
104 /// existingDeviceTypes, plus the effective new ones (or an added none, if the
105 /// new list is empty). Additionally, adds the arguments to the argCollection
106 /// the correct number of times. This will also update a 'segments' array, even
107 /// if it won't be used.
108 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
109  MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
110  llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
111  mlir::MutableOperandRange argCollection,
112  llvm::SmallVector<int32_t> &segments) {
114  if (existingDeviceTypes)
115  llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
116 
117  if (newDeviceTypes.empty()) {
118  argCollection.append(arguments);
119  segments.push_back(arguments.size());
120  deviceTypes.push_back(
122  }
123 
124  for (DeviceType DT : newDeviceTypes) {
125  argCollection.append(arguments);
126  segments.push_back(arguments.size());
127  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, DT));
128  }
129 
130  return mlir::ArrayAttr::get(context, deviceTypes);
131 }
132 
133 /// Overload for when the 'segments' aren't needed.
134 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
135  MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
136  llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
137  mlir::MutableOperandRange argCollection) {
139  return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
140  newDeviceTypes, arguments,
141  argCollection, segments);
142 }
143 } // namespace
144 
145 //===----------------------------------------------------------------------===//
146 // OpenACC operations
147 //===----------------------------------------------------------------------===//
148 
149 void OpenACCDialect::initialize() {
150  addOperations<
151 #define GET_OP_LIST
152 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
153  >();
154  addAttributes<
155 #define GET_ATTRDEF_LIST
156 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
157  >();
158  addTypes<
159 #define GET_TYPEDEF_LIST
160 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
161  >();
162 
163  // By attaching interfaces here, we make the OpenACC dialect dependent on
164  // the other dialects. This is probably better than having dialects like LLVM
165  // and memref be dependent on OpenACC.
166  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
167  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
168  *getContext());
169 }
170 
171 //===----------------------------------------------------------------------===//
172 // device_type support helpers
173 //===----------------------------------------------------------------------===//
174 
175 static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
176  return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
177 }
178 
179 static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
180  mlir::acc::DeviceType deviceType) {
181  if (!hasDeviceTypeValues(arrayAttr))
182  return false;
183 
184  for (auto attr : *arrayAttr) {
185  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
186  if (deviceTypeAttr.getValue() == deviceType)
187  return true;
188  }
189 
190  return false;
191 }
192 
194  std::optional<mlir::ArrayAttr> deviceTypes) {
195  if (!hasDeviceTypeValues(deviceTypes))
196  return;
197 
198  p << "[";
199  llvm::interleaveComma(*deviceTypes, p,
200  [&](mlir::Attribute attr) { p << attr; });
201  p << "]";
202 }
203 
204 static std::optional<unsigned> findSegment(ArrayAttr segments,
205  mlir::acc::DeviceType deviceType) {
206  unsigned segmentIdx = 0;
207  for (auto attr : segments) {
208  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
209  if (deviceTypeAttr.getValue() == deviceType)
210  return std::make_optional(segmentIdx);
211  ++segmentIdx;
212  }
213  return std::nullopt;
214 }
215 
217 getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
219  std::optional<llvm::ArrayRef<int32_t>> segments,
220  mlir::acc::DeviceType deviceType) {
221  if (!arrayAttr)
222  return range.take_front(0);
223  if (auto pos = findSegment(*arrayAttr, deviceType)) {
224  int32_t nbOperandsBefore = 0;
225  for (unsigned i = 0; i < *pos; ++i)
226  nbOperandsBefore += (*segments)[i];
227  return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
228  }
229  return range.take_front(0);
230 }
231 
232 static mlir::Value
233 getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
235  std::optional<llvm::ArrayRef<int32_t>> segments,
236  std::optional<mlir::ArrayAttr> hasWaitDevnum,
237  mlir::acc::DeviceType deviceType) {
238  if (!hasDeviceTypeValues(deviceTypeAttr))
239  return {};
240  if (auto pos = findSegment(*deviceTypeAttr, deviceType))
241  if (hasWaitDevnum->getValue()[*pos])
242  return getValuesFromSegments(deviceTypeAttr, operands, segments,
243  deviceType)
244  .front();
245  return {};
246 }
247 
249 getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
251  std::optional<llvm::ArrayRef<int32_t>> segments,
252  std::optional<mlir::ArrayAttr> hasWaitDevnum,
253  mlir::acc::DeviceType deviceType) {
254  auto range =
255  getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
256  if (range.empty())
257  return range;
258  if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
259  if (hasWaitDevnum && *hasWaitDevnum) {
260  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
261  if (boolAttr.getValue())
262  return range.drop_front(1); // first value is devnum
263  }
264  }
265  return range;
266 }
267 
268 template <typename Op>
269 static LogicalResult checkWaitAndAsyncConflict(Op op) {
270  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
271  ++dtypeInt) {
272  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
273 
274  // The asyncOnly attribute represent the async clause without value.
275  // Therefore the attribute and operand cannot appear at the same time.
276  if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
277  op.hasAsyncOnly(dtype))
278  return op.emitError(
279  "asyncOnly attribute cannot appear with asyncOperand");
280 
281  // The wait attribute represent the wait clause without values. Therefore
282  // the attribute and operands cannot appear at the same time.
283  if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
284  op.hasWaitOnly(dtype))
285  return op.emitError("wait attribute cannot appear with waitOperands");
286  }
287  return success();
288 }
289 
290 template <typename Op>
291 static LogicalResult checkVarAndVarType(Op op) {
292  if (!op.getVar())
293  return op.emitError("must have var operand");
294 
295  // A variable must have a type that is either pointer-like or mappable.
296  if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
297  !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
298  return op.emitError("var must be mappable or pointer-like");
299 
300  // When it is a pointer-like type, the varType must capture the target type.
301  if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
302  op.getVarType() == op.getVar().getType())
303  return op.emitError("varType must capture the element type of var");
304 
305  return success();
306 }
307 
308 template <typename Op>
309 static LogicalResult checkVarAndAccVar(Op op) {
310  if (op.getVar().getType() != op.getAccVar().getType())
311  return op.emitError("input and output types must match");
312 
313  return success();
314 }
315 
316 template <typename Op>
317 static LogicalResult checkNoModifier(Op op) {
318  if (op.getModifiers() != acc::DataClauseModifier::none)
319  return op.emitError("no data clause modifiers are allowed");
320  return success();
321 }
322 
323 template <typename Op>
324 static LogicalResult
325 checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
326  if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
327  return op.emitError(
328  "invalid data clause modifiers: " +
329  acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
330 
331  return success();
332 }
333 
334 static ParseResult parseVar(mlir::OpAsmParser &parser,
336  // Either `var` or `varPtr` keyword is required.
337  if (failed(parser.parseOptionalKeyword("varPtr"))) {
338  if (failed(parser.parseKeyword("var")))
339  return failure();
340  }
341  if (failed(parser.parseLParen()))
342  return failure();
343  if (failed(parser.parseOperand(var)))
344  return failure();
345 
346  return success();
347 }
348 
350  mlir::Value var) {
351  if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
352  p << "varPtr(";
353  else
354  p << "var(";
355  p.printOperand(var);
356 }
357 
358 static ParseResult parseAccVar(mlir::OpAsmParser &parser,
360  mlir::Type &accVarType) {
361  // Either `accVar` or `accPtr` keyword is required.
362  if (failed(parser.parseOptionalKeyword("accPtr"))) {
363  if (failed(parser.parseKeyword("accVar")))
364  return failure();
365  }
366  if (failed(parser.parseLParen()))
367  return failure();
368  if (failed(parser.parseOperand(var)))
369  return failure();
370  if (failed(parser.parseColon()))
371  return failure();
372  if (failed(parser.parseType(accVarType)))
373  return failure();
374  if (failed(parser.parseRParen()))
375  return failure();
376 
377  return success();
378 }
379 
381  mlir::Value accVar, mlir::Type accVarType) {
382  if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
383  p << "accPtr(";
384  else
385  p << "accVar(";
386  p.printOperand(accVar);
387  p << " : ";
388  p.printType(accVarType);
389  p << ")";
390 }
391 
392 static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
393  mlir::Type &varPtrType,
394  mlir::TypeAttr &varTypeAttr) {
395  if (failed(parser.parseType(varPtrType)))
396  return failure();
397  if (failed(parser.parseRParen()))
398  return failure();
399 
400  if (succeeded(parser.parseOptionalKeyword("varType"))) {
401  if (failed(parser.parseLParen()))
402  return failure();
403  mlir::Type varType;
404  if (failed(parser.parseType(varType)))
405  return failure();
406  varTypeAttr = mlir::TypeAttr::get(varType);
407  if (failed(parser.parseRParen()))
408  return failure();
409  } else {
410  // Set `varType` from the element type of the type of `varPtr`.
411  if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
412  varTypeAttr = mlir::TypeAttr::get(
413  mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
414  else
415  varTypeAttr = mlir::TypeAttr::get(varPtrType);
416  }
417 
418  return success();
419 }
420 
422  mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
423  p.printType(varPtrType);
424  p << ")";
425 
426  // Print the `varType` only if it differs from the element type of
427  // `varPtr`'s type.
428  mlir::Type varType = varTypeAttr.getValue();
429  mlir::Type typeToCheckAgainst =
430  mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
431  ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
432  : varPtrType;
433  if (typeToCheckAgainst != varType) {
434  p << " varType(";
435  p.printType(varType);
436  p << ")";
437  }
438 }
439 
440 //===----------------------------------------------------------------------===//
441 // DataBoundsOp
442 //===----------------------------------------------------------------------===//
443 LogicalResult acc::DataBoundsOp::verify() {
444  auto extent = getExtent();
445  auto upperbound = getUpperbound();
446  if (!extent && !upperbound)
447  return emitError("expected extent or upperbound.");
448  return success();
449 }
450 
451 //===----------------------------------------------------------------------===//
452 // PrivateOp
453 //===----------------------------------------------------------------------===//
454 LogicalResult acc::PrivateOp::verify() {
455  if (getDataClause() != acc::DataClause::acc_private)
456  return emitError(
457  "data clause associated with private operation must match its intent");
458  if (failed(checkVarAndVarType(*this)))
459  return failure();
460  if (failed(checkNoModifier(*this)))
461  return failure();
462  return success();
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // FirstprivateOp
467 //===----------------------------------------------------------------------===//
468 LogicalResult acc::FirstprivateOp::verify() {
469  if (getDataClause() != acc::DataClause::acc_firstprivate)
470  return emitError("data clause associated with firstprivate operation must "
471  "match its intent");
472  if (failed(checkVarAndVarType(*this)))
473  return failure();
474  if (failed(checkNoModifier(*this)))
475  return failure();
476  return success();
477 }
478 
479 //===----------------------------------------------------------------------===//
480 // ReductionOp
481 //===----------------------------------------------------------------------===//
482 LogicalResult acc::ReductionOp::verify() {
483  if (getDataClause() != acc::DataClause::acc_reduction)
484  return emitError("data clause associated with reduction operation must "
485  "match its intent");
486  if (failed(checkVarAndVarType(*this)))
487  return failure();
488  if (failed(checkNoModifier(*this)))
489  return failure();
490  return success();
491 }
492 
493 //===----------------------------------------------------------------------===//
494 // DevicePtrOp
495 //===----------------------------------------------------------------------===//
496 LogicalResult acc::DevicePtrOp::verify() {
497  if (getDataClause() != acc::DataClause::acc_deviceptr)
498  return emitError("data clause associated with deviceptr operation must "
499  "match its intent");
500  if (failed(checkVarAndVarType(*this)))
501  return failure();
502  if (failed(checkVarAndAccVar(*this)))
503  return failure();
504  if (failed(checkNoModifier(*this)))
505  return failure();
506  return success();
507 }
508 
509 //===----------------------------------------------------------------------===//
510 // PresentOp
511 //===----------------------------------------------------------------------===//
512 LogicalResult acc::PresentOp::verify() {
513  if (getDataClause() != acc::DataClause::acc_present)
514  return emitError(
515  "data clause associated with present operation must match its intent");
516  if (failed(checkVarAndVarType(*this)))
517  return failure();
518  if (failed(checkVarAndAccVar(*this)))
519  return failure();
520  if (failed(checkNoModifier(*this)))
521  return failure();
522  return success();
523 }
524 
525 //===----------------------------------------------------------------------===//
526 // CopyinOp
527 //===----------------------------------------------------------------------===//
528 LogicalResult acc::CopyinOp::verify() {
529  // Test for all clauses this operation can be decomposed from:
530  if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
531  getDataClause() != acc::DataClause::acc_copyin_readonly &&
532  getDataClause() != acc::DataClause::acc_copy &&
533  getDataClause() != acc::DataClause::acc_reduction)
534  return emitError(
535  "data clause associated with copyin operation must match its intent"
536  " or specify original clause this operation was decomposed from");
537  if (failed(checkVarAndVarType(*this)))
538  return failure();
539  if (failed(checkVarAndAccVar(*this)))
540  return failure();
541  if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly |
542  acc::DataClauseModifier::always |
543  acc::DataClauseModifier::capture)))
544  return failure();
545  return success();
546 }
547 
548 bool acc::CopyinOp::isCopyinReadonly() {
549  return getDataClause() == acc::DataClause::acc_copyin_readonly ||
550  acc::bitEnumContainsAny(getModifiers(),
551  acc::DataClauseModifier::readonly);
552 }
553 
554 //===----------------------------------------------------------------------===//
555 // CreateOp
556 //===----------------------------------------------------------------------===//
557 LogicalResult acc::CreateOp::verify() {
558  // Test for all clauses this operation can be decomposed from:
559  if (getDataClause() != acc::DataClause::acc_create &&
560  getDataClause() != acc::DataClause::acc_create_zero &&
561  getDataClause() != acc::DataClause::acc_copyout &&
562  getDataClause() != acc::DataClause::acc_copyout_zero)
563  return emitError(
564  "data clause associated with create operation must match its intent"
565  " or specify original clause this operation was decomposed from");
566  if (failed(checkVarAndVarType(*this)))
567  return failure();
568  if (failed(checkVarAndAccVar(*this)))
569  return failure();
570  // this op is the entry part of copyout, so it also needs to allow all
571  // modifiers allowed on copyout.
572  if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
573  acc::DataClauseModifier::always |
574  acc::DataClauseModifier::capture)))
575  return failure();
576  return success();
577 }
578 
579 bool acc::CreateOp::isCreateZero() {
580  // The zero modifier is encoded in the data clause.
581  return getDataClause() == acc::DataClause::acc_create_zero ||
582  getDataClause() == acc::DataClause::acc_copyout_zero ||
583  acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
584 }
585 
586 //===----------------------------------------------------------------------===//
587 // NoCreateOp
588 //===----------------------------------------------------------------------===//
589 LogicalResult acc::NoCreateOp::verify() {
590  if (getDataClause() != acc::DataClause::acc_no_create)
591  return emitError("data clause associated with no_create operation must "
592  "match its intent");
593  if (failed(checkVarAndVarType(*this)))
594  return failure();
595  if (failed(checkVarAndAccVar(*this)))
596  return failure();
597  if (failed(checkNoModifier(*this)))
598  return failure();
599  return success();
600 }
601 
602 //===----------------------------------------------------------------------===//
603 // AttachOp
604 //===----------------------------------------------------------------------===//
605 LogicalResult acc::AttachOp::verify() {
606  if (getDataClause() != acc::DataClause::acc_attach)
607  return emitError(
608  "data clause associated with attach operation must match its intent");
609  if (failed(checkVarAndVarType(*this)))
610  return failure();
611  if (failed(checkVarAndAccVar(*this)))
612  return failure();
613  if (failed(checkNoModifier(*this)))
614  return failure();
615  return success();
616 }
617 
618 //===----------------------------------------------------------------------===//
619 // DeclareDeviceResidentOp
620 //===----------------------------------------------------------------------===//
621 
622 LogicalResult acc::DeclareDeviceResidentOp::verify() {
623  if (getDataClause() != acc::DataClause::acc_declare_device_resident)
624  return emitError("data clause associated with device_resident operation "
625  "must match its intent");
626  if (failed(checkVarAndVarType(*this)))
627  return failure();
628  if (failed(checkVarAndAccVar(*this)))
629  return failure();
630  if (failed(checkNoModifier(*this)))
631  return failure();
632  return success();
633 }
634 
635 //===----------------------------------------------------------------------===//
636 // DeclareLinkOp
637 //===----------------------------------------------------------------------===//
638 
639 LogicalResult acc::DeclareLinkOp::verify() {
640  if (getDataClause() != acc::DataClause::acc_declare_link)
641  return emitError(
642  "data clause associated with link operation must match its intent");
643  if (failed(checkVarAndVarType(*this)))
644  return failure();
645  if (failed(checkVarAndAccVar(*this)))
646  return failure();
647  if (failed(checkNoModifier(*this)))
648  return failure();
649  return success();
650 }
651 
652 //===----------------------------------------------------------------------===//
653 // CopyoutOp
654 //===----------------------------------------------------------------------===//
655 LogicalResult acc::CopyoutOp::verify() {
656  // Test for all clauses this operation can be decomposed from:
657  if (getDataClause() != acc::DataClause::acc_copyout &&
658  getDataClause() != acc::DataClause::acc_copyout_zero &&
659  getDataClause() != acc::DataClause::acc_copy &&
660  getDataClause() != acc::DataClause::acc_reduction)
661  return emitError(
662  "data clause associated with copyout operation must match its intent"
663  " or specify original clause this operation was decomposed from");
664  if (!getVar() || !getAccVar())
665  return emitError("must have both host and device pointers");
666  if (failed(checkVarAndVarType(*this)))
667  return failure();
668  if (failed(checkVarAndAccVar(*this)))
669  return failure();
670  if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
671  acc::DataClauseModifier::always |
672  acc::DataClauseModifier::capture)))
673  return failure();
674  return success();
675 }
676 
677 bool acc::CopyoutOp::isCopyoutZero() {
678  return getDataClause() == acc::DataClause::acc_copyout_zero ||
679  acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
680 }
681 
682 //===----------------------------------------------------------------------===//
683 // DeleteOp
684 //===----------------------------------------------------------------------===//
685 LogicalResult acc::DeleteOp::verify() {
686  // Test for all clauses this operation can be decomposed from:
687  if (getDataClause() != acc::DataClause::acc_delete &&
688  getDataClause() != acc::DataClause::acc_create &&
689  getDataClause() != acc::DataClause::acc_create_zero &&
690  getDataClause() != acc::DataClause::acc_copyin &&
691  getDataClause() != acc::DataClause::acc_copyin_readonly &&
692  getDataClause() != acc::DataClause::acc_present &&
693  getDataClause() != acc::DataClause::acc_no_create &&
694  getDataClause() != acc::DataClause::acc_declare_device_resident &&
695  getDataClause() != acc::DataClause::acc_declare_link)
696  return emitError(
697  "data clause associated with delete operation must match its intent"
698  " or specify original clause this operation was decomposed from");
699  if (!getAccVar())
700  return emitError("must have device pointer");
701  // This op is the exit part of copyin and create - thus allow all modifiers
702  // allowed on either case.
703  if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
704  acc::DataClauseModifier::readonly |
705  acc::DataClauseModifier::always |
706  acc::DataClauseModifier::capture)))
707  return failure();
708  return success();
709 }
710 
711 //===----------------------------------------------------------------------===//
712 // DetachOp
713 //===----------------------------------------------------------------------===//
714 LogicalResult acc::DetachOp::verify() {
715  // Test for all clauses this operation can be decomposed from:
716  if (getDataClause() != acc::DataClause::acc_detach &&
717  getDataClause() != acc::DataClause::acc_attach)
718  return emitError(
719  "data clause associated with detach operation must match its intent"
720  " or specify original clause this operation was decomposed from");
721  if (!getAccVar())
722  return emitError("must have device pointer");
723  if (failed(checkNoModifier(*this)))
724  return failure();
725  return success();
726 }
727 
728 //===----------------------------------------------------------------------===//
729 // HostOp
730 //===----------------------------------------------------------------------===//
731 LogicalResult acc::UpdateHostOp::verify() {
732  // Test for all clauses this operation can be decomposed from:
733  if (getDataClause() != acc::DataClause::acc_update_host &&
734  getDataClause() != acc::DataClause::acc_update_self)
735  return emitError(
736  "data clause associated with host operation must match its intent"
737  " or specify original clause this operation was decomposed from");
738  if (!getVar() || !getAccVar())
739  return emitError("must have both host and device pointers");
740  if (failed(checkVarAndVarType(*this)))
741  return failure();
742  if (failed(checkVarAndAccVar(*this)))
743  return failure();
744  if (failed(checkNoModifier(*this)))
745  return failure();
746  return success();
747 }
748 
749 //===----------------------------------------------------------------------===//
750 // DeviceOp
751 //===----------------------------------------------------------------------===//
752 LogicalResult acc::UpdateDeviceOp::verify() {
753  // Test for all clauses this operation can be decomposed from:
754  if (getDataClause() != acc::DataClause::acc_update_device)
755  return emitError(
756  "data clause associated with device operation must match its intent"
757  " or specify original clause this operation was decomposed from");
758  if (failed(checkVarAndVarType(*this)))
759  return failure();
760  if (failed(checkVarAndAccVar(*this)))
761  return failure();
762  if (failed(checkNoModifier(*this)))
763  return failure();
764  return success();
765 }
766 
767 //===----------------------------------------------------------------------===//
768 // UseDeviceOp
769 //===----------------------------------------------------------------------===//
770 LogicalResult acc::UseDeviceOp::verify() {
771  // Test for all clauses this operation can be decomposed from:
772  if (getDataClause() != acc::DataClause::acc_use_device)
773  return emitError(
774  "data clause associated with use_device operation must match its intent"
775  " or specify original clause this operation was decomposed from");
776  if (failed(checkVarAndVarType(*this)))
777  return failure();
778  if (failed(checkVarAndAccVar(*this)))
779  return failure();
780  if (failed(checkNoModifier(*this)))
781  return failure();
782  return success();
783 }
784 
785 //===----------------------------------------------------------------------===//
786 // CacheOp
787 //===----------------------------------------------------------------------===//
788 LogicalResult acc::CacheOp::verify() {
789  // Test for all clauses this operation can be decomposed from:
790  if (getDataClause() != acc::DataClause::acc_cache &&
791  getDataClause() != acc::DataClause::acc_cache_readonly)
792  return emitError(
793  "data clause associated with cache operation must match its intent"
794  " or specify original clause this operation was decomposed from");
795  if (failed(checkVarAndVarType(*this)))
796  return failure();
797  if (failed(checkVarAndAccVar(*this)))
798  return failure();
799  if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly)))
800  return failure();
801  return success();
802 }
803 
804 bool acc::CacheOp::isCacheReadonly() {
805  return getDataClause() == acc::DataClause::acc_cache_readonly ||
806  acc::bitEnumContainsAny(getModifiers(),
807  acc::DataClauseModifier::readonly);
808 }
809 
810 template <typename StructureOp>
811 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
812  unsigned nRegions = 1) {
813 
814  SmallVector<Region *, 2> regions;
815  for (unsigned i = 0; i < nRegions; ++i)
816  regions.push_back(state.addRegion());
817 
818  for (Region *region : regions)
819  if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
820  return failure();
821 
822  return success();
823 }
824 
825 static bool isComputeOperation(Operation *op) {
826  return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
827 }
828 
829 namespace {
830 /// Pattern to remove operation without region that have constant false `ifCond`
831 /// and remove the condition from the operation if the `ifCond` is a true
832 /// constant.
833 template <typename OpTy>
834 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
836 
837  LogicalResult matchAndRewrite(OpTy op,
838  PatternRewriter &rewriter) const override {
839  // Early return if there is no condition.
840  Value ifCond = op.getIfCond();
841  if (!ifCond)
842  return failure();
843 
844  IntegerAttr constAttr;
845  if (!matchPattern(ifCond, m_Constant(&constAttr)))
846  return failure();
847  if (constAttr.getInt())
848  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
849  else
850  rewriter.eraseOp(op);
851 
852  return success();
853  }
854 };
855 
856 /// Replaces the given op with the contents of the given single-block region,
857 /// using the operands of the block terminator to replace operation results.
858 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
859  Region &region, ValueRange blockArgs = {}) {
860  assert(region.hasOneBlock() && "expected single-block region");
861  Block *block = &region.front();
862  Operation *terminator = block->getTerminator();
863  ValueRange results = terminator->getOperands();
864  rewriter.inlineBlockBefore(block, op, blockArgs);
865  rewriter.replaceOp(op, results);
866  rewriter.eraseOp(terminator);
867 }
868 
869 /// Pattern to remove operation with region that have constant false `ifCond`
870 /// and remove the condition from the operation if the `ifCond` is constant
871 /// true.
872 template <typename OpTy>
873 struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
875 
876  LogicalResult matchAndRewrite(OpTy op,
877  PatternRewriter &rewriter) const override {
878  // Early return if there is no condition.
879  Value ifCond = op.getIfCond();
880  if (!ifCond)
881  return failure();
882 
883  IntegerAttr constAttr;
884  if (!matchPattern(ifCond, m_Constant(&constAttr)))
885  return failure();
886  if (constAttr.getInt())
887  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
888  else
889  replaceOpWithRegion(rewriter, op, op.getRegion());
890 
891  return success();
892  }
893 };
894 
895 } // namespace
896 
897 //===----------------------------------------------------------------------===//
898 // PrivateRecipeOp
899 //===----------------------------------------------------------------------===//
900 
901 static LogicalResult verifyInitLikeSingleArgRegion(
902  Operation *op, Region &region, StringRef regionType, StringRef regionName,
903  Type type, bool verifyYield, bool optional = false) {
904  if (optional && region.empty())
905  return success();
906 
907  if (region.empty())
908  return op->emitOpError() << "expects non-empty " << regionName << " region";
909  Block &firstBlock = region.front();
910  if (firstBlock.getNumArguments() < 1 ||
911  firstBlock.getArgument(0).getType() != type)
912  return op->emitOpError() << "expects " << regionName
913  << " region first "
914  "argument of the "
915  << regionType << " type";
916 
917  if (verifyYield) {
918  for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
919  if (yieldOp.getOperands().size() != 1 ||
920  yieldOp.getOperands().getTypes()[0] != type)
921  return op->emitOpError() << "expects " << regionName
922  << " region to "
923  "yield a value of the "
924  << regionType << " type";
925  }
926  }
927  return success();
928 }
929 
930 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
931  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
932  "privatization", "init", getType(),
933  /*verifyYield=*/false)))
934  return failure();
936  *this, getDestroyRegion(), "privatization", "destroy", getType(),
937  /*verifyYield=*/false, /*optional=*/true)))
938  return failure();
939  return success();
940 }
941 
942 //===----------------------------------------------------------------------===//
943 // FirstprivateRecipeOp
944 //===----------------------------------------------------------------------===//
945 
946 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
947  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
948  "privatization", "init", getType(),
949  /*verifyYield=*/false)))
950  return failure();
951 
952  if (getCopyRegion().empty())
953  return emitOpError() << "expects non-empty copy region";
954 
955  Block &firstBlock = getCopyRegion().front();
956  if (firstBlock.getNumArguments() < 2 ||
957  firstBlock.getArgument(0).getType() != getType())
958  return emitOpError() << "expects copy region with two arguments of the "
959  "privatization type";
960 
961  if (getDestroyRegion().empty())
962  return success();
963 
964  if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
965  "privatization", "destroy",
966  getType(), /*verifyYield=*/false)))
967  return failure();
968 
969  return success();
970 }
971 
972 //===----------------------------------------------------------------------===//
973 // ReductionRecipeOp
974 //===----------------------------------------------------------------------===//
975 
976 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
977  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
978  "init", getType(),
979  /*verifyYield=*/false)))
980  return failure();
981 
982  if (getCombinerRegion().empty())
983  return emitOpError() << "expects non-empty combiner region";
984 
985  Block &reductionBlock = getCombinerRegion().front();
986  if (reductionBlock.getNumArguments() < 2 ||
987  reductionBlock.getArgument(0).getType() != getType() ||
988  reductionBlock.getArgument(1).getType() != getType())
989  return emitOpError() << "expects combiner region with the first two "
990  << "arguments of the reduction type";
991 
992  for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
993  if (yieldOp.getOperands().size() != 1 ||
994  yieldOp.getOperands().getTypes()[0] != getType())
995  return emitOpError() << "expects combiner region to yield a value "
996  "of the reduction type";
997  }
998 
999  return success();
1000 }
1001 
1002 //===----------------------------------------------------------------------===//
1003 // Custom parser and printer verifier for private clause
1004 //===----------------------------------------------------------------------===//
1005 
1006 static ParseResult parseSymOperandList(
1007  mlir::OpAsmParser &parser,
1009  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
1011  if (failed(parser.parseCommaSeparatedList([&]() {
1012  if (parser.parseAttribute(attributes.emplace_back()) ||
1013  parser.parseArrow() ||
1014  parser.parseOperand(operands.emplace_back()) ||
1015  parser.parseColonType(types.emplace_back()))
1016  return failure();
1017  return success();
1018  })))
1019  return failure();
1020  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1021  attributes.end());
1022  symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
1023  return success();
1024 }
1025 
1027  mlir::OperandRange operands,
1028  mlir::TypeRange types,
1029  std::optional<mlir::ArrayAttr> attributes) {
1030  llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
1031  p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
1032  << std::get<1>(it).getType();
1033  });
1034 }
1035 
1036 //===----------------------------------------------------------------------===//
1037 // ParallelOp
1038 //===----------------------------------------------------------------------===//
1039 
1040 /// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
1041 template <typename Op>
1042 static LogicalResult checkDataOperands(Op op,
1043  const mlir::ValueRange &operands) {
1044  for (mlir::Value operand : operands)
1045  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1046  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1047  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1048  operand.getDefiningOp()))
1049  return op.emitError(
1050  "expect data entry/exit operation or acc.getdeviceptr "
1051  "as defining op");
1052  return success();
1053 }
1054 
1055 template <typename Op>
1056 static LogicalResult
1057 checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
1058  mlir::OperandRange operands, llvm::StringRef operandName,
1059  llvm::StringRef symbolName, bool checkOperandType = true) {
1060  if (!operands.empty()) {
1061  if (!attributes || attributes->size() != operands.size())
1062  return op->emitOpError()
1063  << "expected as many " << symbolName << " symbol reference as "
1064  << operandName << " operands";
1065  } else {
1066  if (attributes)
1067  return op->emitOpError()
1068  << "unexpected " << symbolName << " symbol reference";
1069  return success();
1070  }
1071 
1073  for (auto args : llvm::zip(operands, *attributes)) {
1074  mlir::Value operand = std::get<0>(args);
1075 
1076  if (!set.insert(operand).second)
1077  return op->emitOpError()
1078  << operandName << " operand appears more than once";
1079 
1080  mlir::Type varType = operand.getType();
1081  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1082  auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
1083  if (!decl)
1084  return op->emitOpError()
1085  << "expected symbol reference " << symbolRef << " to point to a "
1086  << operandName << " declaration";
1087 
1088  if (checkOperandType && decl.getType() && decl.getType() != varType)
1089  return op->emitOpError() << "expected " << operandName << " (" << varType
1090  << ") to be the same type as " << operandName
1091  << " declaration (" << decl.getType() << ")";
1092  }
1093 
1094  return success();
1095 }
1096 
1097 unsigned ParallelOp::getNumDataOperands() {
1098  return getReductionOperands().size() + getPrivateOperands().size() +
1099  getFirstprivateOperands().size() + getDataClauseOperands().size();
1100 }
1101 
1102 Value ParallelOp::getDataOperand(unsigned i) {
1103  unsigned numOptional = getAsyncOperands().size();
1104  numOptional += getNumGangs().size();
1105  numOptional += getNumWorkers().size();
1106  numOptional += getVectorLength().size();
1107  numOptional += getIfCond() ? 1 : 0;
1108  numOptional += getSelfCond() ? 1 : 0;
1109  return getOperand(getWaitOperands().size() + numOptional + i);
1110 }
1111 
1112 template <typename Op>
1113 static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
1114  ArrayAttr deviceTypes,
1115  llvm::StringRef keyword) {
1116  if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1117  return op.emitOpError() << keyword << " operands count must match "
1118  << keyword << " device_type count";
1119  return success();
1120 }
1121 
1122 template <typename Op>
1124  Op op, OperandRange operands, DenseI32ArrayAttr segments,
1125  ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1126  std::size_t numOperandsInSegments = 0;
1127  std::size_t nbOfSegments = 0;
1128 
1129  if (segments) {
1130  for (auto segCount : segments.asArrayRef()) {
1131  if (maxInSegment != 0 && segCount > maxInSegment)
1132  return op.emitOpError() << keyword << " expects a maximum of "
1133  << maxInSegment << " values per segment";
1134  numOperandsInSegments += segCount;
1135  ++nbOfSegments;
1136  }
1137  }
1138 
1139  if ((numOperandsInSegments != operands.size()) ||
1140  (!deviceTypes && !operands.empty()))
1141  return op.emitOpError()
1142  << keyword << " operand count does not match count in segments";
1143  if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1144  return op.emitOpError()
1145  << keyword << " segment count does not match device_type count";
1146  return success();
1147 }
1148 
1149 LogicalResult acc::ParallelOp::verify() {
1150  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1151  *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
1152  "privatizations", /*checkOperandType=*/false)))
1153  return failure();
1154  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1155  *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
1156  "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
1157  return failure();
1158  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1159  *this, getReductionRecipes(), getReductionOperands(), "reduction",
1160  "reductions", false)))
1161  return failure();
1162 
1164  *this, getNumGangs(), getNumGangsSegmentsAttr(),
1165  getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1166  return failure();
1167 
1169  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1170  getWaitOperandsDeviceTypeAttr(), "wait")))
1171  return failure();
1172 
1173  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1174  getNumWorkersDeviceTypeAttr(),
1175  "num_workers")))
1176  return failure();
1177 
1178  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1179  getVectorLengthDeviceTypeAttr(),
1180  "vector_length")))
1181  return failure();
1182 
1184  getAsyncOperandsDeviceTypeAttr(),
1185  "async")))
1186  return failure();
1187 
1188  if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
1189  return failure();
1190 
1191  return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
1192 }
1193 
1194 static mlir::Value
1195 getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
1197  mlir::acc::DeviceType deviceType) {
1198  if (!arrayAttr)
1199  return {};
1200  if (auto pos = findSegment(*arrayAttr, deviceType))
1201  return range[*pos];
1202  return {};
1203 }
1204 
1205 bool acc::ParallelOp::hasAsyncOnly() {
1206  return hasAsyncOnly(mlir::acc::DeviceType::None);
1207 }
1208 
1209 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1210  return hasDeviceType(getAsyncOnly(), deviceType);
1211 }
1212 
1213 mlir::Value acc::ParallelOp::getAsyncValue() {
1214  return getAsyncValue(mlir::acc::DeviceType::None);
1215 }
1216 
1217 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1219  getAsyncOperands(), deviceType);
1220 }
1221 
1222 mlir::Value acc::ParallelOp::getNumWorkersValue() {
1223  return getNumWorkersValue(mlir::acc::DeviceType::None);
1224 }
1225 
1227 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1228  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1229  deviceType);
1230 }
1231 
1232 mlir::Value acc::ParallelOp::getVectorLengthValue() {
1233  return getVectorLengthValue(mlir::acc::DeviceType::None);
1234 }
1235 
1237 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1238  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1239  getVectorLength(), deviceType);
1240 }
1241 
1242 mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
1243  return getNumGangsValues(mlir::acc::DeviceType::None);
1244 }
1245 
1247 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1248  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1249  getNumGangsSegments(), deviceType);
1250 }
1251 
1252 bool acc::ParallelOp::hasWaitOnly() {
1253  return hasWaitOnly(mlir::acc::DeviceType::None);
1254 }
1255 
1256 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1257  return hasDeviceType(getWaitOnly(), deviceType);
1258 }
1259 
1260 mlir::Operation::operand_range ParallelOp::getWaitValues() {
1261  return getWaitValues(mlir::acc::DeviceType::None);
1262 }
1263 
1265 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1267  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1268  getHasWaitDevnum(), deviceType);
1269 }
1270 
1271 mlir::Value ParallelOp::getWaitDevnum() {
1272  return getWaitDevnum(mlir::acc::DeviceType::None);
1273 }
1274 
1275 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1276  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1277  getWaitOperandsSegments(), getHasWaitDevnum(),
1278  deviceType);
1279 }
1280 
1281 void ParallelOp::build(mlir::OpBuilder &odsBuilder,
1282  mlir::OperationState &odsState,
1283  mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
1284  mlir::ValueRange vectorLength,
1285  mlir::ValueRange asyncOperands,
1286  mlir::ValueRange waitOperands, mlir::Value ifCond,
1287  mlir::Value selfCond, mlir::ValueRange reductionOperands,
1288  mlir::ValueRange gangPrivateOperands,
1289  mlir::ValueRange gangFirstPrivateOperands,
1290  mlir::ValueRange dataClauseOperands) {
1291 
1292  ParallelOp::build(
1293  odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
1294  /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
1295  /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
1296  /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
1297  /*numGangsDeviceType=*/nullptr, numWorkers,
1298  /*numWorkersDeviceType=*/nullptr, vectorLength,
1299  /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
1300  /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
1301  gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
1302  /*firstprivatizations=*/nullptr, dataClauseOperands,
1303  /*defaultAttr=*/nullptr, /*combined=*/nullptr);
1304 }
1305 
1306 void acc::ParallelOp::addNumWorkersOperand(
1307  MLIRContext *context, mlir::Value newValue,
1308  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1309  setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1310  context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1311  getNumWorkersMutable()));
1312 }
1313 void acc::ParallelOp::addVectorLengthOperand(
1314  MLIRContext *context, mlir::Value newValue,
1315  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1316  setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1317  context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1318  getVectorLengthMutable()));
1319 }
1320 
1321 void acc::ParallelOp::addAsyncOnly(
1322  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1323  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1324  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1325 }
1326 
1327 void acc::ParallelOp::addAsyncOperand(
1328  MLIRContext *context, mlir::Value newValue,
1329  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1330  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1331  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1332  getAsyncOperandsMutable()));
1333 }
1334 
1335 void acc::ParallelOp::addNumGangsOperands(
1336  MLIRContext *context, mlir::ValueRange newValues,
1337  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1338  llvm::SmallVector<int32_t> segments;
1339  if (getNumGangsSegments())
1340  llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1341 
1342  setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1343  context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1344  getNumGangsMutable(), segments));
1345 
1346  setNumGangsSegments(segments);
1347 }
1348 void acc::ParallelOp::addWaitOnly(
1349  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1350  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1351  effectiveDeviceTypes));
1352 }
1353 void acc::ParallelOp::addWaitOperands(
1354  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
1355  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1356 
1357  llvm::SmallVector<int32_t> segments;
1358  if (getWaitOperandsSegments())
1359  llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1360 
1361  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1362  context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1363  getWaitOperandsMutable(), segments));
1364  setWaitOperandsSegments(segments);
1365 
1367  if (getHasWaitDevnumAttr())
1368  llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1369  hasDevnums.insert(
1370  hasDevnums.end(),
1371  std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
1372  mlir::BoolAttr::get(context, hasDevnum));
1373  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
1374 }
1375 
1376 void acc::ParallelOp::addPrivatization(MLIRContext *context,
1377  mlir::acc::PrivateOp op,
1378  mlir::acc::PrivateRecipeOp recipe) {
1379  getPrivateOperandsMutable().append(op.getResult());
1380 
1382 
1383  if (getPrivatizationRecipesAttr())
1384  llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
1385 
1386  recipes.push_back(
1387  mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1388  setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1389 }
1390 
1391 void acc::ParallelOp::addFirstPrivatization(
1392  MLIRContext *context, mlir::acc::FirstprivateOp op,
1393  mlir::acc::FirstprivateRecipeOp recipe) {
1394  getFirstprivateOperandsMutable().append(op.getResult());
1395 
1397 
1398  if (getFirstprivatizationRecipesAttr())
1399  llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
1400 
1401  recipes.push_back(
1402  mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1403  setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1404 }
1405 
1406 void acc::ParallelOp::addReduction(MLIRContext *context,
1407  mlir::acc::ReductionOp op,
1408  mlir::acc::ReductionRecipeOp recipe) {
1409  getReductionOperandsMutable().append(op.getResult());
1410 
1412 
1413  if (getReductionRecipesAttr())
1414  llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
1415 
1416  recipes.push_back(
1417  mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1418  setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1419 }
1420 
1421 static ParseResult parseNumGangs(
1422  mlir::OpAsmParser &parser,
1424  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1425  mlir::DenseI32ArrayAttr &segments) {
1428 
1429  do {
1430  if (failed(parser.parseLBrace()))
1431  return failure();
1432 
1433  int32_t crtOperandsSize = operands.size();
1434  if (failed(parser.parseCommaSeparatedList(
1436  if (parser.parseOperand(operands.emplace_back()) ||
1437  parser.parseColonType(types.emplace_back()))
1438  return failure();
1439  return success();
1440  })))
1441  return failure();
1442  seg.push_back(operands.size() - crtOperandsSize);
1443 
1444  if (failed(parser.parseRBrace()))
1445  return failure();
1446 
1447  if (succeeded(parser.parseOptionalLSquare())) {
1448  if (parser.parseAttribute(attributes.emplace_back()) ||
1449  parser.parseRSquare())
1450  return failure();
1451  } else {
1452  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1454  }
1455  } while (succeeded(parser.parseOptionalComma()));
1456 
1457  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1458  attributes.end());
1459  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1460  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1461 
1462  return success();
1463 }
1464 
1466  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1467  if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
1468  p << " [" << attr << "]";
1469 }
1470 
1472  mlir::OperandRange operands, mlir::TypeRange types,
1473  std::optional<mlir::ArrayAttr> deviceTypes,
1474  std::optional<mlir::DenseI32ArrayAttr> segments) {
1475  unsigned opIdx = 0;
1476  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1477  p << "{";
1478  llvm::interleaveComma(
1479  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1480  p << operands[opIdx] << " : " << operands[opIdx].getType();
1481  ++opIdx;
1482  });
1483  p << "}";
1484  printSingleDeviceType(p, it.value());
1485  });
1486 }
1487 
1489  mlir::OpAsmParser &parser,
1491  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1492  mlir::DenseI32ArrayAttr &segments) {
1495 
1496  do {
1497  if (failed(parser.parseLBrace()))
1498  return failure();
1499 
1500  int32_t crtOperandsSize = operands.size();
1501 
1502  if (failed(parser.parseCommaSeparatedList(
1504  if (parser.parseOperand(operands.emplace_back()) ||
1505  parser.parseColonType(types.emplace_back()))
1506  return failure();
1507  return success();
1508  })))
1509  return failure();
1510 
1511  seg.push_back(operands.size() - crtOperandsSize);
1512 
1513  if (failed(parser.parseRBrace()))
1514  return failure();
1515 
1516  if (succeeded(parser.parseOptionalLSquare())) {
1517  if (parser.parseAttribute(attributes.emplace_back()) ||
1518  parser.parseRSquare())
1519  return failure();
1520  } else {
1521  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1523  }
1524  } while (succeeded(parser.parseOptionalComma()));
1525 
1526  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1527  attributes.end());
1528  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1529  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1530 
1531  return success();
1532 }
1533 
1536  mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1537  std::optional<mlir::DenseI32ArrayAttr> segments) {
1538  unsigned opIdx = 0;
1539  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1540  p << "{";
1541  llvm::interleaveComma(
1542  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1543  p << operands[opIdx] << " : " << operands[opIdx].getType();
1544  ++opIdx;
1545  });
1546  p << "}";
1547  printSingleDeviceType(p, it.value());
1548  });
1549 }
1550 
1551 static ParseResult parseWaitClause(
1552  mlir::OpAsmParser &parser,
1554  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1555  mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
1556  mlir::ArrayAttr &keywordOnly) {
1557  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
1559 
1560  bool needCommaBeforeOperands = false;
1561 
1562  // Keyword only
1563  if (failed(parser.parseOptionalLParen())) {
1564  keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1566  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1567  return success();
1568  }
1569 
1570  // Parse keyword only attributes
1571  if (succeeded(parser.parseOptionalLSquare())) {
1572  if (failed(parser.parseCommaSeparatedList([&]() {
1573  if (parser.parseAttribute(keywordAttrs.emplace_back()))
1574  return failure();
1575  return success();
1576  })))
1577  return failure();
1578  if (parser.parseRSquare())
1579  return failure();
1580  needCommaBeforeOperands = true;
1581  }
1582 
1583  if (needCommaBeforeOperands && failed(parser.parseComma()))
1584  return failure();
1585 
1586  do {
1587  if (failed(parser.parseLBrace()))
1588  return failure();
1589 
1590  int32_t crtOperandsSize = operands.size();
1591 
1592  if (succeeded(parser.parseOptionalKeyword("devnum"))) {
1593  if (failed(parser.parseColon()))
1594  return failure();
1595  devnum.push_back(BoolAttr::get(parser.getContext(), true));
1596  } else {
1597  devnum.push_back(BoolAttr::get(parser.getContext(), false));
1598  }
1599 
1600  if (failed(parser.parseCommaSeparatedList(
1602  if (parser.parseOperand(operands.emplace_back()) ||
1603  parser.parseColonType(types.emplace_back()))
1604  return failure();
1605  return success();
1606  })))
1607  return failure();
1608 
1609  seg.push_back(operands.size() - crtOperandsSize);
1610 
1611  if (failed(parser.parseRBrace()))
1612  return failure();
1613 
1614  if (succeeded(parser.parseOptionalLSquare())) {
1615  if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
1616  parser.parseRSquare())
1617  return failure();
1618  } else {
1619  deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1621  }
1622  } while (succeeded(parser.parseOptionalComma()));
1623 
1624  if (failed(parser.parseRParen()))
1625  return failure();
1626 
1627  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
1628  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1629  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1630  hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
1631 
1632  return success();
1633 }
1634 
1635 static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
1636  if (!hasDeviceTypeValues(attrs))
1637  return false;
1638  if (attrs->size() != 1)
1639  return false;
1640  if (auto deviceTypeAttr =
1641  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1642  return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
1643  return false;
1644 }
1645 
1647  mlir::OperandRange operands, mlir::TypeRange types,
1648  std::optional<mlir::ArrayAttr> deviceTypes,
1649  std::optional<mlir::DenseI32ArrayAttr> segments,
1650  std::optional<mlir::ArrayAttr> hasDevNum,
1651  std::optional<mlir::ArrayAttr> keywordOnly) {
1652 
1653  if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
1654  return;
1655 
1656  p << "(";
1657 
1658  printDeviceTypes(p, keywordOnly);
1659  if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
1660  p << ", ";
1661 
1662  if (hasDeviceTypeValues(deviceTypes)) {
1663  unsigned opIdx = 0;
1664  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1665  p << "{";
1666  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1667  if (boolAttr && boolAttr.getValue())
1668  p << "devnum: ";
1669  llvm::interleaveComma(
1670  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1671  p << operands[opIdx] << " : " << operands[opIdx].getType();
1672  ++opIdx;
1673  });
1674  p << "}";
1675  printSingleDeviceType(p, it.value());
1676  });
1677  }
1678 
1679  p << ")";
1680 }
1681 
1682 static ParseResult parseDeviceTypeOperands(
1683  mlir::OpAsmParser &parser,
1685  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
1687  if (failed(parser.parseCommaSeparatedList([&]() {
1688  if (parser.parseOperand(operands.emplace_back()) ||
1689  parser.parseColonType(types.emplace_back()))
1690  return failure();
1691  if (succeeded(parser.parseOptionalLSquare())) {
1692  if (parser.parseAttribute(attributes.emplace_back()) ||
1693  parser.parseRSquare())
1694  return failure();
1695  } else {
1696  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1697  parser.getContext(), mlir::acc::DeviceType::None));
1698  }
1699  return success();
1700  })))
1701  return failure();
1702  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1703  attributes.end());
1704  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1705  return success();
1706 }
1707 
1708 static void
1710  mlir::OperandRange operands, mlir::TypeRange types,
1711  std::optional<mlir::ArrayAttr> deviceTypes) {
1712  if (!hasDeviceTypeValues(deviceTypes))
1713  return;
1714  llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
1715  p << std::get<1>(it) << " : " << std::get<1>(it).getType();
1716  printSingleDeviceType(p, std::get<0>(it));
1717  });
1718 }
1719 
1721  mlir::OpAsmParser &parser,
1723  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1724  mlir::ArrayAttr &keywordOnlyDeviceType) {
1725 
1726  llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
1727  bool needCommaBeforeOperands = false;
1728 
1729  if (failed(parser.parseOptionalLParen())) {
1730  // Keyword only
1731  keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1733  keywordOnlyDeviceType =
1734  ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
1735  return success();
1736  }
1737 
1738  // Parse keyword only attributes
1739  if (succeeded(parser.parseOptionalLSquare())) {
1740  // Parse keyword only attributes
1741  if (failed(parser.parseCommaSeparatedList([&]() {
1742  if (parser.parseAttribute(
1743  keywordOnlyDeviceTypeAttributes.emplace_back()))
1744  return failure();
1745  return success();
1746  })))
1747  return failure();
1748  if (parser.parseRSquare())
1749  return failure();
1750  needCommaBeforeOperands = true;
1751  }
1752 
1753  if (needCommaBeforeOperands && failed(parser.parseComma()))
1754  return failure();
1755 
1757  if (failed(parser.parseCommaSeparatedList([&]() {
1758  if (parser.parseOperand(operands.emplace_back()) ||
1759  parser.parseColonType(types.emplace_back()))
1760  return failure();
1761  if (succeeded(parser.parseOptionalLSquare())) {
1762  if (parser.parseAttribute(attributes.emplace_back()) ||
1763  parser.parseRSquare())
1764  return failure();
1765  } else {
1766  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1767  parser.getContext(), mlir::acc::DeviceType::None));
1768  }
1769  return success();
1770  })))
1771  return failure();
1772 
1773  if (failed(parser.parseRParen()))
1774  return failure();
1775 
1776  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1777  attributes.end());
1778  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1779  return success();
1780 }
1781 
1784  mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1785  std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1786 
1787  if (operands.begin() == operands.end() &&
1788  hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
1789  return;
1790  }
1791 
1792  p << "(";
1793  printDeviceTypes(p, keywordOnlyDeviceTypes);
1794  if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
1795  hasDeviceTypeValues(deviceTypes))
1796  p << ", ";
1797  printDeviceTypeOperands(p, op, operands, types, deviceTypes);
1798  p << ")";
1799 }
1800 
1801 static ParseResult parseOperandWithKeywordOnly(
1802  mlir::OpAsmParser &parser,
1803  std::optional<OpAsmParser::UnresolvedOperand> &operand,
1804  mlir::Type &operandType, mlir::UnitAttr &attr) {
1805  // Keyword only
1806  if (failed(parser.parseOptionalLParen())) {
1807  attr = mlir::UnitAttr::get(parser.getContext());
1808  return success();
1809  }
1810 
1812  if (failed(parser.parseOperand(op)))
1813  return failure();
1814  operand = op;
1815  if (failed(parser.parseColon()))
1816  return failure();
1817  if (failed(parser.parseType(operandType)))
1818  return failure();
1819  if (failed(parser.parseRParen()))
1820  return failure();
1821 
1822  return success();
1823 }
1824 
1826  mlir::Operation *op,
1827  std::optional<mlir::Value> operand,
1828  mlir::Type operandType,
1829  mlir::UnitAttr attr) {
1830  if (attr)
1831  return;
1832 
1833  p << "(";
1834  p.printOperand(*operand);
1835  p << " : ";
1836  p.printType(operandType);
1837  p << ")";
1838 }
1839 
1840 static ParseResult parseOperandsWithKeywordOnly(
1841  mlir::OpAsmParser &parser,
1843  llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
1844  // Keyword only
1845  if (failed(parser.parseOptionalLParen())) {
1846  attr = mlir::UnitAttr::get(parser.getContext());
1847  return success();
1848  }
1849 
1850  if (failed(parser.parseCommaSeparatedList([&]() {
1851  if (parser.parseOperand(operands.emplace_back()))
1852  return failure();
1853  return success();
1854  })))
1855  return failure();
1856  if (failed(parser.parseColon()))
1857  return failure();
1858  if (failed(parser.parseCommaSeparatedList([&]() {
1859  if (parser.parseType(types.emplace_back()))
1860  return failure();
1861  return success();
1862  })))
1863  return failure();
1864  if (failed(parser.parseRParen()))
1865  return failure();
1866 
1867  return success();
1868 }
1869 
1871  mlir::Operation *op,
1872  mlir::OperandRange operands,
1873  mlir::TypeRange types,
1874  mlir::UnitAttr attr) {
1875  if (attr)
1876  return;
1877 
1878  p << "(";
1879  llvm::interleaveComma(operands, p, [&](auto it) { p << it; });
1880  p << " : ";
1881  llvm::interleaveComma(types, p, [&](auto it) { p << it; });
1882  p << ")";
1883 }
1884 
1885 static ParseResult
1887  mlir::acc::CombinedConstructsTypeAttr &attr) {
1888  if (succeeded(parser.parseOptionalKeyword("kernels"))) {
1890  parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1891  } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
1893  parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1894  } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
1896  parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1897  } else {
1898  parser.emitError(parser.getCurrentLocation(),
1899  "expected compute construct name");
1900  return failure();
1901  }
1902  return success();
1903 }
1904 
1905 static void
1907  mlir::acc::CombinedConstructsTypeAttr attr) {
1908  if (attr) {
1909  switch (attr.getValue()) {
1910  case mlir::acc::CombinedConstructsType::KernelsLoop:
1911  p << "kernels";
1912  break;
1913  case mlir::acc::CombinedConstructsType::ParallelLoop:
1914  p << "parallel";
1915  break;
1916  case mlir::acc::CombinedConstructsType::SerialLoop:
1917  p << "serial";
1918  break;
1919  };
1920  }
1921 }
1922 
1923 //===----------------------------------------------------------------------===//
1924 // SerialOp
1925 //===----------------------------------------------------------------------===//
1926 
1927 unsigned SerialOp::getNumDataOperands() {
1928  return getReductionOperands().size() + getPrivateOperands().size() +
1929  getFirstprivateOperands().size() + getDataClauseOperands().size();
1930 }
1931 
1932 Value SerialOp::getDataOperand(unsigned i) {
1933  unsigned numOptional = getAsyncOperands().size();
1934  numOptional += getIfCond() ? 1 : 0;
1935  numOptional += getSelfCond() ? 1 : 0;
1936  return getOperand(getWaitOperands().size() + numOptional + i);
1937 }
1938 
1939 bool acc::SerialOp::hasAsyncOnly() {
1940  return hasAsyncOnly(mlir::acc::DeviceType::None);
1941 }
1942 
1943 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1944  return hasDeviceType(getAsyncOnly(), deviceType);
1945 }
1946 
1947 mlir::Value acc::SerialOp::getAsyncValue() {
1948  return getAsyncValue(mlir::acc::DeviceType::None);
1949 }
1950 
1951 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1953  getAsyncOperands(), deviceType);
1954 }
1955 
1956 bool acc::SerialOp::hasWaitOnly() {
1957  return hasWaitOnly(mlir::acc::DeviceType::None);
1958 }
1959 
1960 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1961  return hasDeviceType(getWaitOnly(), deviceType);
1962 }
1963 
1964 mlir::Operation::operand_range SerialOp::getWaitValues() {
1965  return getWaitValues(mlir::acc::DeviceType::None);
1966 }
1967 
1969 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1971  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1972  getHasWaitDevnum(), deviceType);
1973 }
1974 
1975 mlir::Value SerialOp::getWaitDevnum() {
1976  return getWaitDevnum(mlir::acc::DeviceType::None);
1977 }
1978 
1979 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1980  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1981  getWaitOperandsSegments(), getHasWaitDevnum(),
1982  deviceType);
1983 }
1984 
1985 LogicalResult acc::SerialOp::verify() {
1986  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1987  *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
1988  "privatizations", /*checkOperandType=*/false)))
1989  return failure();
1990  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1991  *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
1992  "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
1993  return failure();
1994  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1995  *this, getReductionRecipes(), getReductionOperands(), "reduction",
1996  "reductions", false)))
1997  return failure();
1998 
2000  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2001  getWaitOperandsDeviceTypeAttr(), "wait")))
2002  return failure();
2003 
2005  getAsyncOperandsDeviceTypeAttr(),
2006  "async")))
2007  return failure();
2008 
2009  if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
2010  return failure();
2011 
2012  return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
2013 }
2014 
2015 void acc::SerialOp::addAsyncOnly(
2016  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2017  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2018  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2019 }
2020 
2021 void acc::SerialOp::addAsyncOperand(
2022  MLIRContext *context, mlir::Value newValue,
2023  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2024  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2025  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2026  getAsyncOperandsMutable()));
2027 }
2028 
2029 void acc::SerialOp::addWaitOnly(
2030  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2031  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2032  effectiveDeviceTypes));
2033 }
2034 void acc::SerialOp::addWaitOperands(
2035  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2036  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2037 
2038  llvm::SmallVector<int32_t> segments;
2039  if (getWaitOperandsSegments())
2040  llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2041 
2042  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2043  context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2044  getWaitOperandsMutable(), segments));
2045  setWaitOperandsSegments(segments);
2046 
2048  if (getHasWaitDevnumAttr())
2049  llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2050  hasDevnums.insert(
2051  hasDevnums.end(),
2052  std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2053  mlir::BoolAttr::get(context, hasDevnum));
2054  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2055 }
2056 
2057 void acc::SerialOp::addPrivatization(MLIRContext *context,
2058  mlir::acc::PrivateOp op,
2059  mlir::acc::PrivateRecipeOp recipe) {
2060  getPrivateOperandsMutable().append(op.getResult());
2061 
2063 
2064  if (getPrivatizationRecipesAttr())
2065  llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
2066 
2067  recipes.push_back(
2068  mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2069  setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2070 }
2071 
2072 void acc::SerialOp::addFirstPrivatization(
2073  MLIRContext *context, mlir::acc::FirstprivateOp op,
2074  mlir::acc::FirstprivateRecipeOp recipe) {
2075  getFirstprivateOperandsMutable().append(op.getResult());
2076 
2078 
2079  if (getFirstprivatizationRecipesAttr())
2080  llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
2081 
2082  recipes.push_back(
2083  mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2084  setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2085 }
2086 
2087 void acc::SerialOp::addReduction(MLIRContext *context,
2088  mlir::acc::ReductionOp op,
2089  mlir::acc::ReductionRecipeOp recipe) {
2090  getReductionOperandsMutable().append(op.getResult());
2091 
2093 
2094  if (getReductionRecipesAttr())
2095  llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
2096 
2097  recipes.push_back(
2098  mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2099  setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2100 }
2101 
2102 //===----------------------------------------------------------------------===//
2103 // KernelsOp
2104 //===----------------------------------------------------------------------===//
2105 
2106 unsigned KernelsOp::getNumDataOperands() {
2107  return getDataClauseOperands().size();
2108 }
2109 
2110 Value KernelsOp::getDataOperand(unsigned i) {
2111  unsigned numOptional = getAsyncOperands().size();
2112  numOptional += getWaitOperands().size();
2113  numOptional += getNumGangs().size();
2114  numOptional += getNumWorkers().size();
2115  numOptional += getVectorLength().size();
2116  numOptional += getIfCond() ? 1 : 0;
2117  numOptional += getSelfCond() ? 1 : 0;
2118  return getOperand(numOptional + i);
2119 }
2120 
2121 bool acc::KernelsOp::hasAsyncOnly() {
2122  return hasAsyncOnly(mlir::acc::DeviceType::None);
2123 }
2124 
2125 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2126  return hasDeviceType(getAsyncOnly(), deviceType);
2127 }
2128 
2129 mlir::Value acc::KernelsOp::getAsyncValue() {
2130  return getAsyncValue(mlir::acc::DeviceType::None);
2131 }
2132 
2133 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2135  getAsyncOperands(), deviceType);
2136 }
2137 
2138 mlir::Value acc::KernelsOp::getNumWorkersValue() {
2139  return getNumWorkersValue(mlir::acc::DeviceType::None);
2140 }
2141 
2143 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2144  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
2145  deviceType);
2146 }
2147 
2148 mlir::Value acc::KernelsOp::getVectorLengthValue() {
2149  return getVectorLengthValue(mlir::acc::DeviceType::None);
2150 }
2151 
2153 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2154  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
2155  getVectorLength(), deviceType);
2156 }
2157 
2158 mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
2159  return getNumGangsValues(mlir::acc::DeviceType::None);
2160 }
2161 
2163 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2164  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2165  getNumGangsSegments(), deviceType);
2166 }
2167 
2168 bool acc::KernelsOp::hasWaitOnly() {
2169  return hasWaitOnly(mlir::acc::DeviceType::None);
2170 }
2171 
2172 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2173  return hasDeviceType(getWaitOnly(), deviceType);
2174 }
2175 
2176 mlir::Operation::operand_range KernelsOp::getWaitValues() {
2177  return getWaitValues(mlir::acc::DeviceType::None);
2178 }
2179 
2181 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2183  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2184  getHasWaitDevnum(), deviceType);
2185 }
2186 
2187 mlir::Value KernelsOp::getWaitDevnum() {
2188  return getWaitDevnum(mlir::acc::DeviceType::None);
2189 }
2190 
2191 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2192  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2193  getWaitOperandsSegments(), getHasWaitDevnum(),
2194  deviceType);
2195 }
2196 
2197 LogicalResult acc::KernelsOp::verify() {
2199  *this, getNumGangs(), getNumGangsSegmentsAttr(),
2200  getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
2201  return failure();
2202 
2204  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2205  getWaitOperandsDeviceTypeAttr(), "wait")))
2206  return failure();
2207 
2208  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
2209  getNumWorkersDeviceTypeAttr(),
2210  "num_workers")))
2211  return failure();
2212 
2213  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
2214  getVectorLengthDeviceTypeAttr(),
2215  "vector_length")))
2216  return failure();
2217 
2219  getAsyncOperandsDeviceTypeAttr(),
2220  "async")))
2221  return failure();
2222 
2223  if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
2224  return failure();
2225 
2226  return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
2227 }
2228 
2229 void acc::KernelsOp::addNumWorkersOperand(
2230  MLIRContext *context, mlir::Value newValue,
2231  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2232  setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2233  context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2234  getNumWorkersMutable()));
2235 }
2236 
2237 void acc::KernelsOp::addVectorLengthOperand(
2238  MLIRContext *context, mlir::Value newValue,
2239  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2240  setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2241  context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2242  getVectorLengthMutable()));
2243 }
2244 void acc::KernelsOp::addAsyncOnly(
2245  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2246  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2247  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2248 }
2249 
2250 void acc::KernelsOp::addAsyncOperand(
2251  MLIRContext *context, mlir::Value newValue,
2252  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2253  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2254  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2255  getAsyncOperandsMutable()));
2256 }
2257 
2258 void acc::KernelsOp::addNumGangsOperands(
2259  MLIRContext *context, mlir::ValueRange newValues,
2260  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2261  llvm::SmallVector<int32_t> segments;
2262  if (getNumGangsSegmentsAttr())
2263  llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2264 
2265  setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2266  context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2267  getNumGangsMutable(), segments));
2268 
2269  setNumGangsSegments(segments);
2270 }
2271 
2272 void acc::KernelsOp::addWaitOnly(
2273  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2274  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2275  effectiveDeviceTypes));
2276 }
2277 void acc::KernelsOp::addWaitOperands(
2278  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2279  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2280 
2281  llvm::SmallVector<int32_t> segments;
2282  if (getWaitOperandsSegments())
2283  llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2284 
2285  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2286  context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2287  getWaitOperandsMutable(), segments));
2288  setWaitOperandsSegments(segments);
2289 
2291  if (getHasWaitDevnumAttr())
2292  llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2293  hasDevnums.insert(
2294  hasDevnums.end(),
2295  std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2296  mlir::BoolAttr::get(context, hasDevnum));
2297  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2298 }
2299 
2300 //===----------------------------------------------------------------------===//
2301 // HostDataOp
2302 //===----------------------------------------------------------------------===//
2303 
2304 LogicalResult acc::HostDataOp::verify() {
2305  if (getDataClauseOperands().empty())
2306  return emitError("at least one operand must appear on the host_data "
2307  "operation");
2308 
2309  for (mlir::Value operand : getDataClauseOperands())
2310  if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2311  return emitError("expect data entry operation as defining op");
2312  return success();
2313 }
2314 
2315 void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2316  MLIRContext *context) {
2317  results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2318 }
2319 
2320 //===----------------------------------------------------------------------===//
2321 // LoopOp
2322 //===----------------------------------------------------------------------===//
2323 
2324 static ParseResult parseGangValue(
2325  OpAsmParser &parser, llvm::StringRef keyword,
2328  llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
2329  bool &needCommaBetweenValues, bool &newValue) {
2330  if (succeeded(parser.parseOptionalKeyword(keyword))) {
2331  if (parser.parseEqual())
2332  return failure();
2333  if (parser.parseOperand(operands.emplace_back()) ||
2334  parser.parseColonType(types.emplace_back()))
2335  return failure();
2336  attributes.push_back(gangArgType);
2337  needCommaBetweenValues = true;
2338  newValue = true;
2339  }
2340  return success();
2341 }
2342 
2343 static ParseResult parseGangClause(
2344  OpAsmParser &parser,
2346  llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
2347  mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
2348  mlir::ArrayAttr &gangOnlyDeviceType) {
2349  llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
2350  llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
2351  llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
2353  bool needCommaBetweenValues = false;
2354  bool needCommaBeforeOperands = false;
2355 
2356  if (failed(parser.parseOptionalLParen())) {
2357  // Gang only keyword
2358  gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2360  gangOnlyDeviceType =
2361  ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
2362  return success();
2363  }
2364 
2365  // Parse gang only attributes
2366  if (succeeded(parser.parseOptionalLSquare())) {
2367  // Parse gang only attributes
2368  if (failed(parser.parseCommaSeparatedList([&]() {
2369  if (parser.parseAttribute(
2370  gangOnlyDeviceTypeAttributes.emplace_back()))
2371  return failure();
2372  return success();
2373  })))
2374  return failure();
2375  if (parser.parseRSquare())
2376  return failure();
2377  needCommaBeforeOperands = true;
2378  }
2379 
2380  auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2381  mlir::acc::GangArgType::Num);
2382  auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2383  mlir::acc::GangArgType::Dim);
2384  auto argStatic = mlir::acc::GangArgTypeAttr::get(
2385  parser.getContext(), mlir::acc::GangArgType::Static);
2386 
2387  do {
2388  if (needCommaBeforeOperands) {
2389  needCommaBeforeOperands = false;
2390  continue;
2391  }
2392 
2393  if (failed(parser.parseLBrace()))
2394  return failure();
2395 
2396  int32_t crtOperandsSize = gangOperands.size();
2397  while (true) {
2398  bool newValue = false;
2399  bool needValue = false;
2400  if (needCommaBetweenValues) {
2401  if (succeeded(parser.parseOptionalComma()))
2402  needValue = true; // expect a new value after comma.
2403  else
2404  break;
2405  }
2406 
2407  if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
2408  gangOperands, gangOperandsType,
2409  gangArgTypeAttributes, argNum,
2410  needCommaBetweenValues, newValue)))
2411  return failure();
2412  if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
2413  gangOperands, gangOperandsType,
2414  gangArgTypeAttributes, argDim,
2415  needCommaBetweenValues, newValue)))
2416  return failure();
2417  if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
2418  gangOperands, gangOperandsType,
2419  gangArgTypeAttributes, argStatic,
2420  needCommaBetweenValues, newValue)))
2421  return failure();
2422 
2423  if (!newValue && needValue) {
2424  parser.emitError(parser.getCurrentLocation(),
2425  "new value expected after comma");
2426  return failure();
2427  }
2428 
2429  if (!newValue)
2430  break;
2431  }
2432 
2433  if (gangOperands.empty())
2434  return parser.emitError(
2435  parser.getCurrentLocation(),
2436  "expect at least one of num, dim or static values");
2437 
2438  if (failed(parser.parseRBrace()))
2439  return failure();
2440 
2441  if (succeeded(parser.parseOptionalLSquare())) {
2442  if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
2443  parser.parseRSquare())
2444  return failure();
2445  } else {
2446  deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2448  }
2449 
2450  seg.push_back(gangOperands.size() - crtOperandsSize);
2451 
2452  } while (succeeded(parser.parseOptionalComma()));
2453 
2454  if (failed(parser.parseRParen()))
2455  return failure();
2456 
2457  llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
2458  gangArgTypeAttributes.end());
2459  gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
2460  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
2461 
2463  gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
2464  gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
2465 
2466  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2467  return success();
2468 }
2469 
2471  mlir::OperandRange operands, mlir::TypeRange types,
2472  std::optional<mlir::ArrayAttr> gangArgTypes,
2473  std::optional<mlir::ArrayAttr> deviceTypes,
2474  std::optional<mlir::DenseI32ArrayAttr> segments,
2475  std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
2476 
2477  if (operands.begin() == operands.end() &&
2478  hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
2479  return;
2480  }
2481 
2482  p << "(";
2483 
2484  printDeviceTypes(p, gangOnlyDeviceTypes);
2485 
2486  if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
2487  hasDeviceTypeValues(deviceTypes))
2488  p << ", ";
2489 
2490  if (hasDeviceTypeValues(deviceTypes)) {
2491  unsigned opIdx = 0;
2492  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2493  p << "{";
2494  llvm::interleaveComma(
2495  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2496  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2497  (*gangArgTypes)[opIdx]);
2498  if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
2499  p << LoopOp::getGangNumKeyword();
2500  else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
2501  p << LoopOp::getGangDimKeyword();
2502  else if (gangArgTypeAttr.getValue() ==
2503  mlir::acc::GangArgType::Static)
2504  p << LoopOp::getGangStaticKeyword();
2505  p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
2506  ++opIdx;
2507  });
2508  p << "}";
2509  printSingleDeviceType(p, it.value());
2510  });
2511  }
2512  p << ")";
2513 }
2514 
2516  std::optional<mlir::ArrayAttr> segments,
2517  llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2518  if (!segments)
2519  return false;
2520  for (auto attr : *segments) {
2521  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2522  if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2523  return true;
2524  }
2525  return false;
2526 }
2527 
2528 /// Check for duplicates in the DeviceType array attribute.
2529 LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
2530  llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2531  if (!deviceTypes)
2532  return success();
2533  for (auto attr : deviceTypes) {
2534  auto deviceTypeAttr =
2535  mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
2536  if (!deviceTypeAttr)
2537  return failure();
2538  if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
2539  return failure();
2540  }
2541  return success();
2542 }
2543 
2544 LogicalResult acc::LoopOp::verify() {
2545  if (getUpperbound().size() != getStep().size())
2546  return emitError() << "number of upperbounds expected to be the same as "
2547  "number of steps";
2548 
2549  if (getUpperbound().size() != getLowerbound().size())
2550  return emitError() << "number of upperbounds expected to be the same as "
2551  "number of lowerbounds";
2552 
2553  if (!getUpperbound().empty() && getInclusiveUpperbound() &&
2554  (getUpperbound().size() != getInclusiveUpperbound()->size()))
2555  return emitError() << "inclusiveUpperbound size is expected to be the same"
2556  << " as upperbound size";
2557 
2558  // Check collapse
2559  if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
2560  return emitOpError() << "collapse device_type attr must be define when"
2561  << " collapse attr is present";
2562 
2563  if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
2564  getCollapseAttr().getValue().size() !=
2565  getCollapseDeviceTypeAttr().getValue().size())
2566  return emitOpError() << "collapse attribute count must match collapse"
2567  << " device_type count";
2568  if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
2569  return emitOpError()
2570  << "duplicate device_type found in collapseDeviceType attribute";
2571 
2572  // Check gang
2573  if (!getGangOperands().empty()) {
2574  if (!getGangOperandsArgType())
2575  return emitOpError() << "gangOperandsArgType attribute must be defined"
2576  << " when gang operands are present";
2577 
2578  if (getGangOperands().size() !=
2579  getGangOperandsArgTypeAttr().getValue().size())
2580  return emitOpError() << "gangOperandsArgType attribute count must match"
2581  << " gangOperands count";
2582  }
2583  if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
2584  return emitOpError() << "duplicate device_type found in gang attribute";
2585 
2587  *this, getGangOperands(), getGangOperandsSegmentsAttr(),
2588  getGangOperandsDeviceTypeAttr(), "gang")))
2589  return failure();
2590 
2591  // Check worker
2592  if (failed(checkDeviceTypes(getWorkerAttr())))
2593  return emitOpError() << "duplicate device_type found in worker attribute";
2594  if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
2595  return emitOpError() << "duplicate device_type found in "
2596  "workerNumOperandsDeviceType attribute";
2597  if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
2598  getWorkerNumOperandsDeviceTypeAttr(),
2599  "worker")))
2600  return failure();
2601 
2602  // Check vector
2603  if (failed(checkDeviceTypes(getVectorAttr())))
2604  return emitOpError() << "duplicate device_type found in vector attribute";
2605  if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
2606  return emitOpError() << "duplicate device_type found in "
2607  "vectorOperandsDeviceType attribute";
2608  if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
2609  getVectorOperandsDeviceTypeAttr(),
2610  "vector")))
2611  return failure();
2612 
2614  *this, getTileOperands(), getTileOperandsSegmentsAttr(),
2615  getTileOperandsDeviceTypeAttr(), "tile")))
2616  return failure();
2617 
2618  // auto, independent and seq attribute are mutually exclusive.
2619  llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
2620  if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
2621  hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
2622  hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
2623  return emitError() << "only one of auto, independent, seq can be present "
2624  "at the same time";
2625  }
2626 
2627  // Check that at least one of auto, independent, or seq is present
2628  // for the device-independent default clauses.
2629  auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
2630  return attr.getValue() == mlir::acc::DeviceType::None;
2631  };
2632  bool hasDefaultSeq =
2633  getSeqAttr()
2634  ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2635  hasDeviceNone)
2636  : false;
2637  bool hasDefaultIndependent =
2638  getIndependentAttr()
2639  ? llvm::any_of(
2640  getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2641  hasDeviceNone)
2642  : false;
2643  bool hasDefaultAuto =
2644  getAuto_Attr()
2645  ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2646  hasDeviceNone)
2647  : false;
2648  if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
2649  return emitError()
2650  << "at least one of auto, independent, seq must be present";
2651  }
2652 
2653  // Gang, worker and vector are incompatible with seq.
2654  if (getSeqAttr()) {
2655  for (auto attr : getSeqAttr()) {
2656  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2657  if (hasVector(deviceTypeAttr.getValue()) ||
2658  getVectorValue(deviceTypeAttr.getValue()) ||
2659  hasWorker(deviceTypeAttr.getValue()) ||
2660  getWorkerValue(deviceTypeAttr.getValue()) ||
2661  hasGang(deviceTypeAttr.getValue()) ||
2662  getGangValue(mlir::acc::GangArgType::Num,
2663  deviceTypeAttr.getValue()) ||
2664  getGangValue(mlir::acc::GangArgType::Dim,
2665  deviceTypeAttr.getValue()) ||
2666  getGangValue(mlir::acc::GangArgType::Static,
2667  deviceTypeAttr.getValue()))
2668  return emitError() << "gang, worker or vector cannot appear with seq";
2669  }
2670  }
2671 
2672  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
2673  *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
2674  "privatizations", false)))
2675  return failure();
2676 
2677  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
2678  *this, getReductionRecipes(), getReductionOperands(), "reduction",
2679  "reductions", false)))
2680  return failure();
2681 
2682  if (getCombined().has_value() &&
2683  (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
2684  getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
2685  getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
2686  return emitError("unexpected combined constructs attribute");
2687  }
2688 
2689  // Check non-empty body().
2690  if (getRegion().empty())
2691  return emitError("expected non-empty body.");
2692 
2693  // When it is container-like - it is expected to hold a loop-like operation.
2694  if (isContainerLike()) {
2695  // Obtain the maximum collapse count - we use this to check that there
2696  // are enough loops contained.
2697  uint64_t collapseCount = getCollapseValue().value_or(1);
2698  if (getCollapseAttr()) {
2699  for (auto collapseEntry : getCollapseAttr()) {
2700  auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
2701  if (intAttr.getValue().getZExtValue() > collapseCount)
2702  collapseCount = intAttr.getValue().getZExtValue();
2703  }
2704  }
2705 
2706  // We want to check that we find enough loop-like operations inside.
2707  // PreOrder walk allows us to walk in a breadth-first manner at each nesting
2708  // level.
2709  mlir::Operation *expectedParent = this->getOperation();
2710  bool foundSibling = false;
2711  getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
2712  if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
2713  // This effectively checks that we are not looking at a sibling loop.
2714  if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
2715  expectedParent) {
2716  foundSibling = true;
2717  return mlir::WalkResult::interrupt();
2718  }
2719 
2720  collapseCount--;
2721  expectedParent = op;
2722  }
2723  // We found enough contained loops.
2724  if (collapseCount == 0)
2725  return mlir::WalkResult::interrupt();
2726  return mlir::WalkResult::advance();
2727  });
2728 
2729  if (foundSibling)
2730  return emitError("found sibling loops inside container-like acc.loop");
2731  if (collapseCount != 0)
2732  return emitError("failed to find enough loop-like operations inside "
2733  "container-like acc.loop");
2734  }
2735 
2736  return success();
2737 }
2738 
2739 unsigned LoopOp::getNumDataOperands() {
2740  return getReductionOperands().size() + getPrivateOperands().size();
2741 }
2742 
2743 Value LoopOp::getDataOperand(unsigned i) {
2744  unsigned numOptional =
2745  getLowerbound().size() + getUpperbound().size() + getStep().size();
2746  numOptional += getGangOperands().size();
2747  numOptional += getVectorOperands().size();
2748  numOptional += getWorkerNumOperands().size();
2749  numOptional += getTileOperands().size();
2750  numOptional += getCacheOperands().size();
2751  return getOperand(numOptional + i);
2752 }
2753 
2754 bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
2755 
2756 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
2757  return hasDeviceType(getAuto_(), deviceType);
2758 }
2759 
2760 bool LoopOp::hasIndependent() {
2761  return hasIndependent(mlir::acc::DeviceType::None);
2762 }
2763 
2764 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
2765  return hasDeviceType(getIndependent(), deviceType);
2766 }
2767 
2768 bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
2769 
2770 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
2771  return hasDeviceType(getSeq(), deviceType);
2772 }
2773 
2774 mlir::Value LoopOp::getVectorValue() {
2775  return getVectorValue(mlir::acc::DeviceType::None);
2776 }
2777 
2778 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
2779  return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
2780  getVectorOperands(), deviceType);
2781 }
2782 
2783 bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
2784 
2785 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
2786  return hasDeviceType(getVector(), deviceType);
2787 }
2788 
2789 mlir::Value LoopOp::getWorkerValue() {
2790  return getWorkerValue(mlir::acc::DeviceType::None);
2791 }
2792 
2793 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
2794  return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
2795  getWorkerNumOperands(), deviceType);
2796 }
2797 
2798 bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2799 
2800 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2801  return hasDeviceType(getWorker(), deviceType);
2802 }
2803 
2804 mlir::Operation::operand_range LoopOp::getTileValues() {
2805  return getTileValues(mlir::acc::DeviceType::None);
2806 }
2807 
2809 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
2810  return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
2811  getTileOperandsSegments(), deviceType);
2812 }
2813 
2814 std::optional<int64_t> LoopOp::getCollapseValue() {
2815  return getCollapseValue(mlir::acc::DeviceType::None);
2816 }
2817 
2818 std::optional<int64_t>
2819 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2820  if (!getCollapseAttr())
2821  return std::nullopt;
2822  if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2823  auto intAttr =
2824  mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2825  return intAttr.getValue().getZExtValue();
2826  }
2827  return std::nullopt;
2828 }
2829 
2830 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2831  return getGangValue(gangArgType, mlir::acc::DeviceType::None);
2832 }
2833 
2834 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2835  mlir::acc::DeviceType deviceType) {
2836  if (getGangOperands().empty())
2837  return {};
2838  if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
2839  int32_t nbOperandsBefore = 0;
2840  for (unsigned i = 0; i < *pos; ++i)
2841  nbOperandsBefore += (*getGangOperandsSegments())[i];
2843  getGangOperands()
2844  .drop_front(nbOperandsBefore)
2845  .take_front((*getGangOperandsSegments())[*pos]);
2846 
2847  int32_t argTypeIdx = nbOperandsBefore;
2848  for (auto value : values) {
2849  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2850  (*getGangOperandsArgType())[argTypeIdx]);
2851  if (gangArgTypeAttr.getValue() == gangArgType)
2852  return value;
2853  ++argTypeIdx;
2854  }
2855  }
2856  return {};
2857 }
2858 
2859 bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2860 
2861 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2862  return hasDeviceType(getGang(), deviceType);
2863 }
2864 
2865 llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
2866  return {&getRegion()};
2867 }
2868 
2869 /// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
2870 /// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
2871 /// `(` ssa-id-and-type-list `)`
2872 /// region
2873 ParseResult
2876  SmallVectorImpl<Type> &lowerboundType,
2878  SmallVectorImpl<Type> &upperboundType,
2880  SmallVectorImpl<Type> &stepType) {
2881 
2882  SmallVector<OpAsmParser::Argument> inductionVars;
2883  if (succeeded(
2884  parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
2885  if (parser.parseLParen() ||
2886  parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
2887  /*allowType=*/true) ||
2888  parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
2889  parser.parseOperandList(lowerbound, inductionVars.size(),
2891  parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
2892  parser.parseKeyword("to") || parser.parseLParen() ||
2893  parser.parseOperandList(upperbound, inductionVars.size(),
2895  parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
2896  parser.parseKeyword("step") || parser.parseLParen() ||
2897  parser.parseOperandList(step, inductionVars.size(),
2899  parser.parseColonTypeList(stepType) || parser.parseRParen())
2900  return failure();
2901  }
2902  return parser.parseRegion(region, inductionVars);
2903 }
2904 
2906  ValueRange lowerbound, TypeRange lowerboundType,
2907  ValueRange upperbound, TypeRange upperboundType,
2908  ValueRange steps, TypeRange stepType) {
2909  ValueRange regionArgs = region.front().getArguments();
2910  if (!regionArgs.empty()) {
2911  p << acc::LoopOp::getControlKeyword() << "(";
2912  llvm::interleaveComma(regionArgs, p,
2913  [&p](Value v) { p << v << " : " << v.getType(); });
2914  p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
2915  << upperbound << " : " << upperboundType << ") " << " step (" << steps
2916  << " : " << stepType << ") ";
2917  }
2918  p.printRegion(region, /*printEntryBlockArgs=*/false);
2919 }
2920 
2921 void acc::LoopOp::addSeq(MLIRContext *context,
2922  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2923  setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
2924  effectiveDeviceTypes));
2925 }
2926 
2927 void acc::LoopOp::addIndependent(
2928  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2929  setIndependentAttr(addDeviceTypeAffectedOperandHelper(
2930  context, getIndependentAttr(), effectiveDeviceTypes));
2931 }
2932 
2933 void acc::LoopOp::addAuto(MLIRContext *context,
2934  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2935  setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
2936  effectiveDeviceTypes));
2937 }
2938 
2939 void acc::LoopOp::setCollapseForDeviceTypes(
2940  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
2941  llvm::APInt value) {
2943  llvm::SmallVector<mlir::Attribute> newDeviceTypes;
2944 
2945  assert((getCollapseAttr() == nullptr) ==
2946  (getCollapseDeviceTypeAttr() == nullptr));
2947  assert(value.getBitWidth() == 64);
2948 
2949  if (getCollapseAttr()) {
2950  for (const auto &existing :
2951  llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
2952  newValues.push_back(std::get<0>(existing));
2953  newDeviceTypes.push_back(std::get<1>(existing));
2954  }
2955  }
2956 
2957  if (effectiveDeviceTypes.empty()) {
2958  // If the effective device-types list is empty, this is before there are any
2959  // being applied by device_type, so this should be added as a 'none'.
2960  newValues.push_back(
2961  mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
2962  newDeviceTypes.push_back(
2964  } else {
2965  for (DeviceType DT : effectiveDeviceTypes) {
2966  newValues.push_back(
2967  mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
2968  newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, DT));
2969  }
2970  }
2971 
2972  setCollapseAttr(ArrayAttr::get(context, newValues));
2973  setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
2974 }
2975 
2976 void acc::LoopOp::setTileForDeviceTypes(
2977  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
2978  ValueRange values) {
2979  llvm::SmallVector<int32_t> segments;
2980  if (getTileOperandsSegments())
2981  llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
2982 
2983  setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2984  context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
2985  getTileOperandsMutable(), segments));
2986 
2987  setTileOperandsSegments(segments);
2988 }
2989 
2990 void acc::LoopOp::addVectorOperand(
2991  MLIRContext *context, mlir::Value newValue,
2992  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2993  setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2994  context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
2995  newValue, getVectorOperandsMutable()));
2996 }
2997 
2998 void acc::LoopOp::addEmptyVector(
2999  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3000  setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3001  effectiveDeviceTypes));
3002 }
3003 
3004 void acc::LoopOp::addWorkerNumOperand(
3005  MLIRContext *context, mlir::Value newValue,
3006  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3007  setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3008  context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3009  newValue, getWorkerNumOperandsMutable()));
3010 }
3011 
3012 void acc::LoopOp::addEmptyWorker(
3013  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3014  setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3015  effectiveDeviceTypes));
3016 }
3017 
3018 void acc::LoopOp::addEmptyGang(
3019  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3020  setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3021  effectiveDeviceTypes));
3022 }
3023 
3024 bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3025  auto hasDevice = [=](DeviceTypeAttr attr) -> bool {
3026  return attr.getValue() == dt;
3027  };
3028  auto testFromArr = [=](ArrayAttr arr) -> bool {
3029  return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3030  };
3031 
3032  if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3033  return true;
3034  if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3035  return true;
3036  if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3037  return true;
3038 
3039  return false;
3040 }
3041 
3042 bool acc::LoopOp::hasDefaultGangWorkerVector() {
3043  return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3044  hasGang() || getGangValue(GangArgType::Num) ||
3045  getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3046 }
3047 
3048 acc::LoopParMode
3049 acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3050  if (hasSeq(deviceType))
3051  return LoopParMode::loop_seq;
3052  if (hasAuto(deviceType))
3053  return LoopParMode::loop_auto;
3054  if (hasIndependent(deviceType))
3055  return LoopParMode::loop_independent;
3056  if (hasSeq())
3057  return LoopParMode::loop_seq;
3058  if (hasAuto())
3059  return LoopParMode::loop_auto;
3060  assert(hasIndependent() &&
3061  "loop must have default auto, seq, or independent");
3062  return LoopParMode::loop_independent;
3063 }
3064 
3065 void acc::LoopOp::addGangOperands(
3066  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3067  llvm::ArrayRef<GangArgType> argTypes, mlir::ValueRange values) {
3068  llvm::SmallVector<int32_t> segments;
3069  if (std::optional<ArrayRef<int32_t>> existingSegments =
3070  getGangOperandsSegments())
3071  llvm::copy(*existingSegments, std::back_inserter(segments));
3072 
3073  unsigned beforeCount = segments.size();
3074 
3075  setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3076  context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3077  getGangOperandsMutable(), segments));
3078 
3079  setGangOperandsSegments(segments);
3080 
3081  // This is a bit of extra work to make sure we update the 'types' correctly by
3082  // adding to the types collection the correct number of times. We could
3083  // potentially add something similar to the
3084  // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
3085  // excessive for a one-off case.
3086  unsigned numAdded = segments.size() - beforeCount;
3087 
3088  if (numAdded > 0) {
3090  if (getGangOperandsArgTypeAttr())
3091  llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3092 
3093  for (auto i : llvm::index_range(0u, numAdded)) {
3094  llvm::transform(argTypes, std::back_inserter(gangTypes),
3095  [=](mlir::acc::GangArgType gangTy) {
3096  return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3097  });
3098  (void)i;
3099  }
3100 
3101  setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3102  }
3103 }
3104 
3105 void acc::LoopOp::addPrivatization(MLIRContext *context,
3106  mlir::acc::PrivateOp op,
3107  mlir::acc::PrivateRecipeOp recipe) {
3108  getPrivateOperandsMutable().append(op.getResult());
3109 
3111 
3112  if (getPrivatizationRecipesAttr())
3113  llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
3114 
3115  recipes.push_back(
3116  mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
3117  setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
3118 }
3119 
3120 void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
3121  mlir::acc::ReductionRecipeOp recipe) {
3122  getReductionOperandsMutable().append(op.getResult());
3123 
3125 
3126  if (getReductionRecipesAttr())
3127  llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
3128 
3129  recipes.push_back(
3130  mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
3131  setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
3132 }
3133 
3134 //===----------------------------------------------------------------------===//
3135 // DataOp
3136 //===----------------------------------------------------------------------===//
3137 
3138 LogicalResult acc::DataOp::verify() {
3139  // 2.6.5. Data Construct restriction
3140  // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
3141  // attach, or default clause must appear on a data construct.
3142  if (getOperands().empty() && !getDefaultAttr())
3143  return emitError("at least one operand or the default attribute "
3144  "must appear on the data operation");
3145 
3146  for (mlir::Value operand : getDataClauseOperands())
3147  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3148  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3149  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3150  operand.getDefiningOp()))
3151  return emitError("expect data entry/exit operation or acc.getdeviceptr "
3152  "as defining op");
3153 
3154  if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this)))
3155  return failure();
3156 
3157  return success();
3158 }
3159 
3160 unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
3161 
3162 Value DataOp::getDataOperand(unsigned i) {
3163  unsigned numOptional = getIfCond() ? 1 : 0;
3164  numOptional += getAsyncOperands().size() ? 1 : 0;
3165  numOptional += getWaitOperands().size();
3166  return getOperand(numOptional + i);
3167 }
3168 
3169 bool acc::DataOp::hasAsyncOnly() {
3170  return hasAsyncOnly(mlir::acc::DeviceType::None);
3171 }
3172 
3173 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3174  return hasDeviceType(getAsyncOnly(), deviceType);
3175 }
3176 
3177 mlir::Value DataOp::getAsyncValue() {
3178  return getAsyncValue(mlir::acc::DeviceType::None);
3179 }
3180 
3181 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3183  getAsyncOperands(), deviceType);
3184 }
3185 
3186 bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
3187 
3188 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3189  return hasDeviceType(getWaitOnly(), deviceType);
3190 }
3191 
3192 mlir::Operation::operand_range DataOp::getWaitValues() {
3193  return getWaitValues(mlir::acc::DeviceType::None);
3194 }
3195 
3197 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3199  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3200  getHasWaitDevnum(), deviceType);
3201 }
3202 
3203 mlir::Value DataOp::getWaitDevnum() {
3204  return getWaitDevnum(mlir::acc::DeviceType::None);
3205 }
3206 
3207 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3208  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3209  getWaitOperandsSegments(), getHasWaitDevnum(),
3210  deviceType);
3211 }
3212 
3213 void acc::DataOp::addAsyncOnly(
3214  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3215  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3216  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3217 }
3218 
3219 void acc::DataOp::addAsyncOperand(
3220  MLIRContext *context, mlir::Value newValue,
3221  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3222  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3223  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3224  getAsyncOperandsMutable()));
3225 }
3226 
3227 void acc::DataOp::addWaitOnly(MLIRContext *context,
3228  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3229  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3230  effectiveDeviceTypes));
3231 }
3232 
3233 void acc::DataOp::addWaitOperands(
3234  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3235  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3236 
3237  llvm::SmallVector<int32_t> segments;
3238  if (getWaitOperandsSegments())
3239  llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3240 
3241  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3242  context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3243  getWaitOperandsMutable(), segments));
3244  setWaitOperandsSegments(segments);
3245 
3247  if (getHasWaitDevnumAttr())
3248  llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3249  hasDevnums.insert(
3250  hasDevnums.end(),
3251  std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
3252  mlir::BoolAttr::get(context, hasDevnum));
3253  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3254 }
3255 
3256 //===----------------------------------------------------------------------===//
3257 // ExitDataOp
3258 //===----------------------------------------------------------------------===//
3259 
3260 LogicalResult acc::ExitDataOp::verify() {
3261  // 2.6.6. Data Exit Directive restriction
3262  // At least one copyout, delete, or detach clause must appear on an exit data
3263  // directive.
3264  if (getDataClauseOperands().empty())
3265  return emitError("at least one operand must be present in dataOperands on "
3266  "the exit data operation");
3267 
3268  // The async attribute represent the async clause without value. Therefore the
3269  // attribute and operand cannot appear at the same time.
3270  if (getAsyncOperand() && getAsync())
3271  return emitError("async attribute cannot appear with asyncOperand");
3272 
3273  // The wait attribute represent the wait clause without values. Therefore the
3274  // attribute and operands cannot appear at the same time.
3275  if (!getWaitOperands().empty() && getWait())
3276  return emitError("wait attribute cannot appear with waitOperands");
3277 
3278  if (getWaitDevnum() && getWaitOperands().empty())
3279  return emitError("wait_devnum cannot appear without waitOperands");
3280 
3281  return success();
3282 }
3283 
3284 unsigned ExitDataOp::getNumDataOperands() {
3285  return getDataClauseOperands().size();
3286 }
3287 
3288 Value ExitDataOp::getDataOperand(unsigned i) {
3289  unsigned numOptional = getIfCond() ? 1 : 0;
3290  numOptional += getAsyncOperand() ? 1 : 0;
3291  numOptional += getWaitDevnum() ? 1 : 0;
3292  return getOperand(getWaitOperands().size() + numOptional + i);
3293 }
3294 
3295 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3296  MLIRContext *context) {
3297  results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
3298 }
3299 
3300 void ExitDataOp::addAsyncOnly(MLIRContext *context,
3301  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3302  assert(effectiveDeviceTypes.empty());
3303  assert(!getAsyncAttr());
3304  assert(!getAsyncOperand());
3305 
3306  setAsyncAttr(mlir::UnitAttr::get(context));
3307 }
3308 
3309 void ExitDataOp::addAsyncOperand(
3310  MLIRContext *context, mlir::Value newValue,
3311  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3312  assert(effectiveDeviceTypes.empty());
3313  assert(!getAsyncAttr());
3314  assert(!getAsyncOperand());
3315 
3316  getAsyncOperandMutable().append(newValue);
3317 }
3318 
3319 void ExitDataOp::addWaitOnly(MLIRContext *context,
3320  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3321  assert(effectiveDeviceTypes.empty());
3322  assert(!getWaitAttr());
3323  assert(getWaitOperands().empty());
3324  assert(!getWaitDevnum());
3325 
3326  setWaitAttr(mlir::UnitAttr::get(context));
3327 }
3328 
3329 void ExitDataOp::addWaitOperands(
3330  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3331  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3332  assert(effectiveDeviceTypes.empty());
3333  assert(!getWaitAttr());
3334  assert(getWaitOperands().empty());
3335  assert(!getWaitDevnum());
3336 
3337  // if hasDevnum, the first value is the devnum. The 'rest' go into the
3338  // operands list.
3339  if (hasDevnum) {
3340  getWaitDevnumMutable().append(newValues.front());
3341  newValues = newValues.drop_front();
3342  }
3343 
3344  getWaitOperandsMutable().append(newValues);
3345 }
3346 
3347 //===----------------------------------------------------------------------===//
3348 // EnterDataOp
3349 //===----------------------------------------------------------------------===//
3350 
3351 LogicalResult acc::EnterDataOp::verify() {
3352  // 2.6.6. Data Enter Directive restriction
3353  // At least one copyin, create, or attach clause must appear on an enter data
3354  // directive.
3355  if (getDataClauseOperands().empty())
3356  return emitError("at least one operand must be present in dataOperands on "
3357  "the enter data operation");
3358 
3359  // The async attribute represent the async clause without value. Therefore the
3360  // attribute and operand cannot appear at the same time.
3361  if (getAsyncOperand() && getAsync())
3362  return emitError("async attribute cannot appear with asyncOperand");
3363 
3364  // The wait attribute represent the wait clause without values. Therefore the
3365  // attribute and operands cannot appear at the same time.
3366  if (!getWaitOperands().empty() && getWait())
3367  return emitError("wait attribute cannot appear with waitOperands");
3368 
3369  if (getWaitDevnum() && getWaitOperands().empty())
3370  return emitError("wait_devnum cannot appear without waitOperands");
3371 
3372  for (mlir::Value operand : getDataClauseOperands())
3373  if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3374  operand.getDefiningOp()))
3375  return emitError("expect data entry operation as defining op");
3376 
3377  return success();
3378 }
3379 
3380 unsigned EnterDataOp::getNumDataOperands() {
3381  return getDataClauseOperands().size();
3382 }
3383 
3384 Value EnterDataOp::getDataOperand(unsigned i) {
3385  unsigned numOptional = getIfCond() ? 1 : 0;
3386  numOptional += getAsyncOperand() ? 1 : 0;
3387  numOptional += getWaitDevnum() ? 1 : 0;
3388  return getOperand(getWaitOperands().size() + numOptional + i);
3389 }
3390 
3391 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3392  MLIRContext *context) {
3393  results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
3394 }
3395 
3396 void EnterDataOp::addAsyncOnly(
3397  MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3398  assert(effectiveDeviceTypes.empty());
3399  assert(!getAsyncAttr());
3400  assert(!getAsyncOperand());
3401 
3402  setAsyncAttr(mlir::UnitAttr::get(context));
3403 }
3404 
3405 void EnterDataOp::addAsyncOperand(
3406  MLIRContext *context, mlir::Value newValue,
3407  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3408  assert(effectiveDeviceTypes.empty());
3409  assert(!getAsyncAttr());
3410  assert(!getAsyncOperand());
3411 
3412  getAsyncOperandMutable().append(newValue);
3413 }
3414 
3415 void EnterDataOp::addWaitOnly(MLIRContext *context,
3416  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3417  assert(effectiveDeviceTypes.empty());
3418  assert(!getWaitAttr());
3419  assert(getWaitOperands().empty());
3420  assert(!getWaitDevnum());
3421 
3422  setWaitAttr(mlir::UnitAttr::get(context));
3423 }
3424 
3425 void EnterDataOp::addWaitOperands(
3426  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3427  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3428  assert(effectiveDeviceTypes.empty());
3429  assert(!getWaitAttr());
3430  assert(getWaitOperands().empty());
3431  assert(!getWaitDevnum());
3432 
3433  // if hasDevnum, the first value is the devnum. The 'rest' go into the
3434  // operands list.
3435  if (hasDevnum) {
3436  getWaitDevnumMutable().append(newValues.front());
3437  newValues = newValues.drop_front();
3438  }
3439 
3440  getWaitOperandsMutable().append(newValues);
3441 }
3442 
3443 //===----------------------------------------------------------------------===//
3444 // AtomicReadOp
3445 //===----------------------------------------------------------------------===//
3446 
3447 LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
3448 
3449 //===----------------------------------------------------------------------===//
3450 // AtomicWriteOp
3451 //===----------------------------------------------------------------------===//
3452 
3453 LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
3454 
3455 //===----------------------------------------------------------------------===//
3456 // AtomicUpdateOp
3457 //===----------------------------------------------------------------------===//
3458 
3459 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3460  PatternRewriter &rewriter) {
3461  if (op.isNoOp()) {
3462  rewriter.eraseOp(op);
3463  return success();
3464  }
3465 
3466  if (Value writeVal = op.getWriteOpVal()) {
3467  rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
3468  return success();
3469  }
3470 
3471  return failure();
3472 }
3473 
3474 LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
3475 
3476 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3477 
3478 //===----------------------------------------------------------------------===//
3479 // AtomicCaptureOp
3480 //===----------------------------------------------------------------------===//
3481 
3482 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3483  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3484  return op;
3485  return dyn_cast<AtomicReadOp>(getSecondOp());
3486 }
3487 
3488 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3489  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3490  return op;
3491  return dyn_cast<AtomicWriteOp>(getSecondOp());
3492 }
3493 
3494 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3495  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3496  return op;
3497  return dyn_cast<AtomicUpdateOp>(getSecondOp());
3498 }
3499 
3500 LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
3501 
3502 //===----------------------------------------------------------------------===//
3503 // DeclareEnterOp
3504 //===----------------------------------------------------------------------===//
3505 
3506 template <typename Op>
3507 static LogicalResult
3509  bool requireAtLeastOneOperand = true) {
3510  if (operands.empty() && requireAtLeastOneOperand)
3511  return emitError(
3512  op->getLoc(),
3513  "at least one operand must appear on the declare operation");
3514 
3515  for (mlir::Value operand : operands) {
3516  if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3517  acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
3518  acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
3519  operand.getDefiningOp()))
3520  return op.emitError(
3521  "expect valid declare data entry operation or acc.getdeviceptr "
3522  "as defining op");
3523 
3524  mlir::Value var{getVar(operand.getDefiningOp())};
3525  assert(var && "declare operands can only be data entry operations which "
3526  "must have var");
3527  (void)var;
3528  std::optional<mlir::acc::DataClause> dataClauseOptional{
3529  getDataClause(operand.getDefiningOp())};
3530  assert(dataClauseOptional.has_value() &&
3531  "declare operands can only be data entry operations which must have "
3532  "dataClause");
3533  (void)dataClauseOptional;
3534  }
3535 
3536  return success();
3537 }
3538 
3539 LogicalResult acc::DeclareEnterOp::verify() {
3540  return checkDeclareOperands(*this, this->getDataClauseOperands());
3541 }
3542 
3543 //===----------------------------------------------------------------------===//
3544 // DeclareExitOp
3545 //===----------------------------------------------------------------------===//
3546 
3547 LogicalResult acc::DeclareExitOp::verify() {
3548  if (getToken())
3549  return checkDeclareOperands(*this, this->getDataClauseOperands(),
3550  /*requireAtLeastOneOperand=*/false);
3551  return checkDeclareOperands(*this, this->getDataClauseOperands());
3552 }
3553 
3554 //===----------------------------------------------------------------------===//
3555 // DeclareOp
3556 //===----------------------------------------------------------------------===//
3557 
3558 LogicalResult acc::DeclareOp::verify() {
3559  return checkDeclareOperands(*this, this->getDataClauseOperands());
3560 }
3561 
3562 //===----------------------------------------------------------------------===//
3563 // RoutineOp
3564 //===----------------------------------------------------------------------===//
3565 
3566 static unsigned getParallelismForDeviceType(acc::RoutineOp op,
3567  acc::DeviceType dtype) {
3568  unsigned parallelism = 0;
3569  parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
3570  parallelism += op.hasWorker(dtype) ? 1 : 0;
3571  parallelism += op.hasVector(dtype) ? 1 : 0;
3572  parallelism += op.hasSeq(dtype) ? 1 : 0;
3573  return parallelism;
3574 }
3575 
3576 LogicalResult acc::RoutineOp::verify() {
3577  unsigned baseParallelism =
3579 
3580  if (baseParallelism > 1)
3581  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
3582  "be present at the same time";
3583 
3584  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
3585  ++dtypeInt) {
3586  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
3587  if (dtype == acc::DeviceType::None)
3588  continue;
3589  unsigned parallelism = getParallelismForDeviceType(*this, dtype);
3590 
3591  if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
3592  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
3593  "be present at the same time";
3594  }
3595 
3596  return success();
3597 }
3598 
3599 static ParseResult parseBindName(OpAsmParser &parser,
3600  mlir::ArrayAttr &bindIdName,
3601  mlir::ArrayAttr &bindStrName,
3602  mlir::ArrayAttr &deviceIdTypes,
3603  mlir::ArrayAttr &deviceStrTypes) {
3604  llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
3605  llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
3606  llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
3607  llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
3608 
3609  if (failed(parser.parseCommaSeparatedList([&]() {
3610  mlir::Attribute newAttr;
3611  bool isSymbolRefAttr;
3612  auto parseResult = parser.parseAttribute(newAttr);
3613  if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
3614  bindIdNameAttrs.push_back(symbolRefAttr);
3615  isSymbolRefAttr = true;
3616  } else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
3617  bindStrNameAttrs.push_back(stringAttr);
3618  isSymbolRefAttr = false;
3619  }
3620  if (parseResult)
3621  return failure();
3622  if (failed(parser.parseOptionalLSquare())) {
3623  if (isSymbolRefAttr) {
3624  deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3625  parser.getContext(), mlir::acc::DeviceType::None));
3626  } else {
3627  deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3628  parser.getContext(), mlir::acc::DeviceType::None));
3629  }
3630  } else {
3631  if (isSymbolRefAttr) {
3632  if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
3633  parser.parseRSquare())
3634  return failure();
3635  } else {
3636  if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
3637  parser.parseRSquare())
3638  return failure();
3639  }
3640  }
3641  return success();
3642  })))
3643  return failure();
3644 
3645  bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
3646  bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
3647  deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
3648  deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
3649 
3650  return success();
3651 }
3652 
3654  std::optional<mlir::ArrayAttr> bindIdName,
3655  std::optional<mlir::ArrayAttr> bindStrName,
3656  std::optional<mlir::ArrayAttr> deviceIdTypes,
3657  std::optional<mlir::ArrayAttr> deviceStrTypes) {
3658  // Create combined vectors for all bind names and device types
3660  llvm::SmallVector<mlir::Attribute> allDeviceTypes;
3661 
3662  // Append bindIdName and deviceIdTypes
3663  if (hasDeviceTypeValues(deviceIdTypes)) {
3664  allBindNames.append(bindIdName->begin(), bindIdName->end());
3665  allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
3666  }
3667 
3668  // Append bindStrName and deviceStrTypes
3669  if (hasDeviceTypeValues(deviceStrTypes)) {
3670  allBindNames.append(bindStrName->begin(), bindStrName->end());
3671  allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
3672  }
3673 
3674  // Print the combined sequence
3675  if (!allBindNames.empty())
3676  llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
3677  [&](const auto &pair) {
3678  p << std::get<0>(pair);
3679  printSingleDeviceType(p, std::get<1>(pair));
3680  });
3681 }
3682 
3683 static ParseResult parseRoutineGangClause(OpAsmParser &parser,
3684  mlir::ArrayAttr &gang,
3685  mlir::ArrayAttr &gangDim,
3686  mlir::ArrayAttr &gangDimDeviceTypes) {
3687 
3688  llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
3689  gangDimDeviceTypeAttrs;
3690  bool needCommaBeforeOperands = false;
3691 
3692  // Gang keyword only
3693  if (failed(parser.parseOptionalLParen())) {
3694  gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3696  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
3697  return success();
3698  }
3699 
3700  // Parse keyword only attributes
3701  if (succeeded(parser.parseOptionalLSquare())) {
3702  if (failed(parser.parseCommaSeparatedList([&]() {
3703  if (parser.parseAttribute(gangAttrs.emplace_back()))
3704  return failure();
3705  return success();
3706  })))
3707  return failure();
3708  if (parser.parseRSquare())
3709  return failure();
3710  needCommaBeforeOperands = true;
3711  }
3712 
3713  if (needCommaBeforeOperands && failed(parser.parseComma()))
3714  return failure();
3715 
3716  if (failed(parser.parseCommaSeparatedList([&]() {
3717  if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
3718  parser.parseColon() ||
3719  parser.parseAttribute(gangDimAttrs.emplace_back()))
3720  return failure();
3721  if (succeeded(parser.parseOptionalLSquare())) {
3722  if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
3723  parser.parseRSquare())
3724  return failure();
3725  } else {
3726  gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3727  parser.getContext(), mlir::acc::DeviceType::None));
3728  }
3729  return success();
3730  })))
3731  return failure();
3732 
3733  if (failed(parser.parseRParen()))
3734  return failure();
3735 
3736  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
3737  gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
3738  gangDimDeviceTypes =
3739  ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
3740 
3741  return success();
3742 }
3743 
3745  std::optional<mlir::ArrayAttr> gang,
3746  std::optional<mlir::ArrayAttr> gangDim,
3747  std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
3748 
3749  if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
3750  gang->size() == 1) {
3751  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
3752  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
3753  return;
3754  }
3755 
3756  p << "(";
3757 
3758  printDeviceTypes(p, gang);
3759 
3760  if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
3761  p << ", ";
3762 
3763  if (hasDeviceTypeValues(gangDimDeviceTypes))
3764  llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
3765  [&](const auto &pair) {
3766  p << acc::RoutineOp::getGangDimKeyword() << ": ";
3767  p << std::get<0>(pair);
3768  printSingleDeviceType(p, std::get<1>(pair));
3769  });
3770 
3771  p << ")";
3772 }
3773 
3774 static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
3775  mlir::ArrayAttr &deviceTypes) {
3777  // Keyword only
3778  if (failed(parser.parseOptionalLParen())) {
3779  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
3781  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
3782  return success();
3783  }
3784 
3785  // Parse device type attributes
3786  if (succeeded(parser.parseOptionalLSquare())) {
3787  if (failed(parser.parseCommaSeparatedList([&]() {
3788  if (parser.parseAttribute(attributes.emplace_back()))
3789  return failure();
3790  return success();
3791  })))
3792  return failure();
3793  if (parser.parseRSquare() || parser.parseRParen())
3794  return failure();
3795  }
3796  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
3797  return success();
3798 }
3799 
3800 static void
3802  std::optional<mlir::ArrayAttr> deviceTypes) {
3803 
3804  if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
3805  auto deviceTypeAttr =
3806  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
3807  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
3808  return;
3809  }
3810 
3811  if (!hasDeviceTypeValues(deviceTypes))
3812  return;
3813 
3814  p << "([";
3815  llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
3816  auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3817  p << dTypeAttr;
3818  });
3819  p << "])";
3820 }
3821 
3822 bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
3823 
3824 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
3825  return hasDeviceType(getWorker(), deviceType);
3826 }
3827 
3828 bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
3829 
3830 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
3831  return hasDeviceType(getVector(), deviceType);
3832 }
3833 
3834 bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
3835 
3836 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
3837  return hasDeviceType(getSeq(), deviceType);
3838 }
3839 
3840 std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
3841 RoutineOp::getBindNameValue() {
3842  return getBindNameValue(mlir::acc::DeviceType::None);
3843 }
3844 
3845 std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
3846 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
3847  if (!hasDeviceTypeValues(getBindIdNameDeviceType()) &&
3848  !hasDeviceTypeValues(getBindStrNameDeviceType())) {
3849  return std::nullopt;
3850  }
3851 
3852  if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
3853  auto attr = (*getBindIdName())[*pos];
3854  auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
3855  assert(symbolRefAttr && "expected SymbolRef");
3856  return symbolRefAttr;
3857  }
3858 
3859  if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
3860  auto attr = (*getBindStrName())[*pos];
3861  auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
3862  assert(stringAttr && "expected String");
3863  return stringAttr;
3864  }
3865 
3866  return std::nullopt;
3867 }
3868 
3869 bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
3870 
3871 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
3872  return hasDeviceType(getGang(), deviceType);
3873 }
3874 
3875 std::optional<int64_t> RoutineOp::getGangDimValue() {
3876  return getGangDimValue(mlir::acc::DeviceType::None);
3877 }
3878 
3879 std::optional<int64_t>
3880 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
3881  if (!hasDeviceTypeValues(getGangDimDeviceType()))
3882  return std::nullopt;
3883  if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
3884  auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
3885  return intAttr.getInt();
3886  }
3887  return std::nullopt;
3888 }
3889 
3890 //===----------------------------------------------------------------------===//
3891 // InitOp
3892 //===----------------------------------------------------------------------===//
3893 
3894 LogicalResult acc::InitOp::verify() {
3895  Operation *currOp = *this;
3896  while ((currOp = currOp->getParentOp()))
3897  if (isComputeOperation(currOp))
3898  return emitOpError("cannot be nested in a compute operation");
3899  return success();
3900 }
3901 
3902 void acc::InitOp::addDeviceType(MLIRContext *context,
3903  mlir::acc::DeviceType deviceType) {
3905  if (getDeviceTypesAttr())
3906  llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
3907 
3908  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
3909  setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
3910 }
3911 
3912 //===----------------------------------------------------------------------===//
3913 // ShutdownOp
3914 //===----------------------------------------------------------------------===//
3915 
3916 LogicalResult acc::ShutdownOp::verify() {
3917  Operation *currOp = *this;
3918  while ((currOp = currOp->getParentOp()))
3919  if (isComputeOperation(currOp))
3920  return emitOpError("cannot be nested in a compute operation");
3921  return success();
3922 }
3923 
3924 void acc::ShutdownOp::addDeviceType(MLIRContext *context,
3925  mlir::acc::DeviceType deviceType) {
3927  if (getDeviceTypesAttr())
3928  llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
3929 
3930  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
3931  setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
3932 }
3933 
3934 //===----------------------------------------------------------------------===//
3935 // SetOp
3936 //===----------------------------------------------------------------------===//
3937 
3938 LogicalResult acc::SetOp::verify() {
3939  Operation *currOp = *this;
3940  while ((currOp = currOp->getParentOp()))
3941  if (isComputeOperation(currOp))
3942  return emitOpError("cannot be nested in a compute operation");
3943  if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
3944  return emitOpError("at least one default_async, device_num, or device_type "
3945  "operand must appear");
3946  return success();
3947 }
3948 
3949 //===----------------------------------------------------------------------===//
3950 // UpdateOp
3951 //===----------------------------------------------------------------------===//
3952 
3953 LogicalResult acc::UpdateOp::verify() {
3954  // At least one of host or device should have a value.
3955  if (getDataClauseOperands().empty())
3956  return emitError("at least one value must be present in dataOperands");
3957 
3959  getAsyncOperandsDeviceTypeAttr(),
3960  "async")))
3961  return failure();
3962 
3964  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
3965  getWaitOperandsDeviceTypeAttr(), "wait")))
3966  return failure();
3967 
3968  if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this)))
3969  return failure();
3970 
3971  for (mlir::Value operand : getDataClauseOperands())
3972  if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
3973  operand.getDefiningOp()))
3974  return emitError("expect data entry/exit operation or acc.getdeviceptr "
3975  "as defining op");
3976 
3977  return success();
3978 }
3979 
3980 unsigned UpdateOp::getNumDataOperands() {
3981  return getDataClauseOperands().size();
3982 }
3983 
3984 Value UpdateOp::getDataOperand(unsigned i) {
3985  unsigned numOptional = getAsyncOperands().size();
3986  numOptional += getIfCond() ? 1 : 0;
3987  return getOperand(getWaitOperands().size() + numOptional + i);
3988 }
3989 
3990 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
3991  MLIRContext *context) {
3992  results.add<RemoveConstantIfCondition<UpdateOp>>(context);
3993 }
3994 
3995 bool UpdateOp::hasAsyncOnly() {
3996  return hasAsyncOnly(mlir::acc::DeviceType::None);
3997 }
3998 
3999 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4000  return hasDeviceType(getAsyncOnly(), deviceType);
4001 }
4002 
4003 mlir::Value UpdateOp::getAsyncValue() {
4004  return getAsyncValue(mlir::acc::DeviceType::None);
4005 }
4006 
4007 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4009  return {};
4010 
4011  if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
4012  return getAsyncOperands()[*pos];
4013 
4014  return {};
4015 }
4016 
4017 bool UpdateOp::hasWaitOnly() {
4018  return hasWaitOnly(mlir::acc::DeviceType::None);
4019 }
4020 
4021 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4022  return hasDeviceType(getWaitOnly(), deviceType);
4023 }
4024 
4025 mlir::Operation::operand_range UpdateOp::getWaitValues() {
4026  return getWaitValues(mlir::acc::DeviceType::None);
4027 }
4028 
4030 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4032  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4033  getHasWaitDevnum(), deviceType);
4034 }
4035 
4036 mlir::Value UpdateOp::getWaitDevnum() {
4037  return getWaitDevnum(mlir::acc::DeviceType::None);
4038 }
4039 
4040 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4041  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
4042  getWaitOperandsSegments(), getHasWaitDevnum(),
4043  deviceType);
4044 }
4045 
4046 void UpdateOp::addAsyncOnly(MLIRContext *context,
4047  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4048  setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4049  context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4050 }
4051 
4052 void UpdateOp::addAsyncOperand(
4053  MLIRContext *context, mlir::Value newValue,
4054  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4055  setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4056  context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4057  getAsyncOperandsMutable()));
4058 }
4059 
4060 void UpdateOp::addWaitOnly(MLIRContext *context,
4061  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4062  setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4063  effectiveDeviceTypes));
4064 }
4065 
4066 void UpdateOp::addWaitOperands(
4067  MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4068  llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4069 
4070  llvm::SmallVector<int32_t> segments;
4071  if (getWaitOperandsSegments())
4072  llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4073 
4074  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4075  context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4076  getWaitOperandsMutable(), segments));
4077  setWaitOperandsSegments(segments);
4078 
4080  if (getHasWaitDevnumAttr())
4081  llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4082  hasDevnums.insert(
4083  hasDevnums.end(),
4084  std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
4085  mlir::BoolAttr::get(context, hasDevnum));
4086  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4087 }
4088 
4089 //===----------------------------------------------------------------------===//
4090 // WaitOp
4091 //===----------------------------------------------------------------------===//
4092 
4093 LogicalResult acc::WaitOp::verify() {
4094  // The async attribute represent the async clause without value. Therefore the
4095  // attribute and operand cannot appear at the same time.
4096  if (getAsyncOperand() && getAsync())
4097  return emitError("async attribute cannot appear with asyncOperand");
4098 
4099  if (getWaitDevnum() && getWaitOperands().empty())
4100  return emitError("wait_devnum cannot appear without waitOperands");
4101 
4102  return success();
4103 }
4104 
4105 #define GET_OP_CLASSES
4106 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
4107 
4108 #define GET_ATTRDEF_CLASSES
4109 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
4110 
4111 #define GET_TYPEDEF_CLASSES
4112 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
4113 
4114 //===----------------------------------------------------------------------===//
4115 // acc dialect utilities
4116 //===----------------------------------------------------------------------===//
4117 
4120  auto varPtr{llvm::TypeSwitch<mlir::Operation *,
4122  accDataClauseOp)
4123  .Case<ACC_DATA_ENTRY_OPS>(
4124  [&](auto entry) { return entry.getVarPtr(); })
4125  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4126  [&](auto exit) { return exit.getVarPtr(); })
4127  .Default([&](mlir::Operation *) {
4129  })};
4130  return varPtr;
4131 }
4132 
4134  auto varPtr{
4136  .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
4137  .Default([&](mlir::Operation *) { return mlir::Value(); })};
4138  return varPtr;
4139 }
4140 
4142  auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
4143  .Case<ACC_DATA_ENTRY_OPS>(
4144  [&](auto entry) { return entry.getVarType(); })
4145  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4146  [&](auto exit) { return exit.getVarType(); })
4147  .Default([&](mlir::Operation *) { return mlir::Type(); })};
4148  return varType;
4149 }
4150 
4153  auto accPtr{llvm::TypeSwitch<mlir::Operation *,
4155  accDataClauseOp)
4156  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4157  [&](auto dataClause) { return dataClause.getAccPtr(); })
4158  .Default([&](mlir::Operation *) {
4160  })};
4161  return accPtr;
4162 }
4163 
4165  auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
4167  [&](auto dataClause) { return dataClause.getAccVar(); })
4168  .Default([&](mlir::Operation *) { return mlir::Value(); })};
4169  return accPtr;
4170 }
4171 
4173  auto varPtrPtr{
4175  .Case<ACC_DATA_ENTRY_OPS>(
4176  [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
4177  .Default([&](mlir::Operation *) { return mlir::Value(); })};
4178  return varPtrPtr;
4179 }
4180 
4185  accDataClauseOp)
4186  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4188  dataClause.getBounds().begin(), dataClause.getBounds().end());
4189  })
4190  .Default([&](mlir::Operation *) {
4192  })};
4193  return bounds;
4194 }
4195 
4199  accDataClauseOp)
4200  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4202  dataClause.getAsyncOperands().begin(),
4203  dataClause.getAsyncOperands().end());
4204  })
4205  .Default([&](mlir::Operation *) {
4207  });
4208 }
4209 
4210 mlir::ArrayAttr
4213  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4214  return dataClause.getAsyncOperandsDeviceTypeAttr();
4215  })
4216  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4217 }
4218 
4219 mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
4222  [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
4223  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4224 }
4225 
4226 std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
4227  auto name{
4229  .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
4230  .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
4231  return {};
4232  })};
4233  return name;
4234 }
4235 
4236 std::optional<mlir::acc::DataClause>
4238  auto dataClause{
4240  accDataEntryOp)
4241  .Case<ACC_DATA_ENTRY_OPS>(
4242  [&](auto entry) { return entry.getDataClause(); })
4243  .Default([&](mlir::Operation *) { return std::nullopt; })};
4244  return dataClause;
4245 }
4246 
4248  auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
4249  .Case<ACC_DATA_ENTRY_OPS>(
4250  [&](auto entry) { return entry.getImplicit(); })
4251  .Default([&](mlir::Operation *) { return false; })};
4252  return implicit;
4253 }
4254 
4256  auto dataOperands{
4259  [&](auto entry) { return entry.getDataClauseOperands(); })
4260  .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
4261  return dataOperands;
4262 }
4263 
4266  auto dataOperands{
4269  [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
4270  .Default([&](mlir::Operation *) { return nullptr; })};
4271  return dataOperands;
4272 }
4273 
4275  mlir::Operation *parentOp = region.getParentOp();
4276  while (parentOp) {
4277  if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp)) {
4278  return parentOp;
4279  }
4280  parentOp = parentOp->getParentOp();
4281  }
4282  return nullptr;
4283 }
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:113
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2366
@ None
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition: OpenACC.cpp:3744
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition: OpenACC.cpp:811
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition: OpenACC.cpp:2515
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition: OpenACC.cpp:1113
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
Definition: OpenACC.cpp:3599
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition: OpenACC.cpp:2529
static bool isComputeOperation(Operation *op)
Definition: OpenACC.cpp:825
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition: OpenACC.cpp:1635
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition: OpenACC.cpp:380
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition: OpenACC.cpp:349
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition: OpenACC.cpp:1646
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition: OpenACC.cpp:1551
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition: OpenACC.cpp:175
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:3801
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition: OpenACC.cpp:2324
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition: OpenACC.cpp:1886
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition: OpenACC.cpp:3508
static LogicalResult checkVarAndAccVar(Op op)
Definition: OpenACC.cpp:309
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
Definition: OpenACC.cpp:1840
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:193
static LogicalResult checkVarAndVarType(Op op)
Definition: OpenACC.cpp:291
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
Definition: OpenACC.cpp:325
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition: OpenACC.cpp:2874
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition: OpenACC.cpp:1042
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:1682
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:1195
static LogicalResult checkNoModifier(Op op)
Definition: OpenACC.cpp:317
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition: OpenACC.cpp:358
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:217
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:1421
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition: OpenACC.cpp:334
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition: OpenACC.cpp:2905
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:3774
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition: OpenACC.cpp:3683
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1534
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:1709
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
Definition: OpenACC.cpp:1825
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:1488
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
Definition: OpenACC.cpp:1026
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:249
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
Definition: OpenACC.cpp:1801
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition: OpenACC.cpp:421
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition: OpenACC.cpp:2343
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition: OpenACC.cpp:901
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
Definition: OpenACC.cpp:1870
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition: OpenACC.cpp:1465
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:204
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
Definition: OpenACC.cpp:1057
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition: OpenACC.cpp:1782
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:179
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition: OpenACC.cpp:2470
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:233
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition: OpenACC.cpp:1720
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition: OpenACC.cpp:392
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition: OpenACC.cpp:269
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition: OpenACC.cpp:1123
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition: OpenACC.cpp:3566
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1471
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition: OpenACC.cpp:1906
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
Definition: OpenACC.cpp:1006
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
Definition: OpenACC.cpp:3653
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition: OpenACC.h:68
#define ACC_DATA_ENTRY_OPS
Definition: OpenACC.h:45
#define ACC_DATA_EXIT_OPS
Definition: OpenACC.h:53
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
Definition: Builders.h:205
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:830
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:836
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:129
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:845
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition: OpenACC.cpp:4164
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition: OpenACC.cpp:4133
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition: OpenACC.cpp:4152
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition: OpenACC.cpp:4237
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition: OpenACC.cpp:4265
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition: OpenACC.cpp:4182
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition: OpenACC.cpp:4255
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition: OpenACC.cpp:4226
mlir::Operation * getEnclosingComputeOp(mlir::Region &region)
Used to obtain the enclosing compute construct operation that contains the provided region.
Definition: OpenACC.cpp:4274
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition: OpenACC.cpp:4247
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition: OpenACC.cpp:4197
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition: OpenACC.cpp:4172
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:4219
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
Definition: OpenACC.cpp:4141
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition: OpenACC.cpp:4119
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:4211
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.