MLIR  21.0.0git
OpenACC.cpp
Go to the documentation of this file.
1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
13 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
19 #include "mlir/Support/LLVM.h"
21 #include "llvm/ADT/SmallSet.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/LogicalResult.h"
24 
25 using namespace mlir;
26 using namespace acc;
27 
28 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
31 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
32 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
33 
34 namespace {
35 
36 static bool isScalarLikeType(Type type) {
37  return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
38 }
39 
40 struct MemRefPointerLikeModel
41  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
42  MemRefType> {
43  Type getElementType(Type pointer) const {
44  return cast<MemRefType>(pointer).getElementType();
45  }
46  mlir::acc::VariableTypeCategory
47  getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
48  Type varType) const {
49  if (auto mappableTy = dyn_cast<MappableType>(varType)) {
50  return mappableTy.getTypeCategory(varPtr);
51  }
52  auto memrefTy = cast<MemRefType>(pointer);
53  if (!memrefTy.hasRank()) {
54  // This memref is unranked - aka it could have any rank, including a
55  // rank of 0 which could mean scalar. For now, return uncategorized.
56  return mlir::acc::VariableTypeCategory::uncategorized;
57  }
58 
59  if (memrefTy.getRank() == 0) {
60  if (isScalarLikeType(memrefTy.getElementType())) {
61  return mlir::acc::VariableTypeCategory::scalar;
62  }
63  // Zero-rank non-scalar - need further analysis to determine the type
64  // category. For now, return uncategorized.
65  return mlir::acc::VariableTypeCategory::uncategorized;
66  }
67 
68  // It has a rank - must be an array.
69  assert(memrefTy.getRank() > 0 && "rank expected to be positive");
70  return mlir::acc::VariableTypeCategory::array;
71  }
72 };
73 
74 struct LLVMPointerPointerLikeModel
75  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
76  LLVM::LLVMPointerType> {
77  Type getElementType(Type pointer) const { return Type(); }
78 };
79 } // namespace
80 
81 //===----------------------------------------------------------------------===//
82 // OpenACC operations
83 //===----------------------------------------------------------------------===//
84 
85 void OpenACCDialect::initialize() {
86  addOperations<
87 #define GET_OP_LIST
88 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
89  >();
90  addAttributes<
91 #define GET_ATTRDEF_LIST
92 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
93  >();
94  addTypes<
95 #define GET_TYPEDEF_LIST
96 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
97  >();
98 
99  // By attaching interfaces here, we make the OpenACC dialect dependent on
100  // the other dialects. This is probably better than having dialects like LLVM
101  // and memref be dependent on OpenACC.
102  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
103  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
104  *getContext());
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // device_type support helpers
109 //===----------------------------------------------------------------------===//
110 
111 static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
112  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
113  return true;
114  return false;
115 }
116 
117 static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
118  mlir::acc::DeviceType deviceType) {
119  if (!hasDeviceTypeValues(arrayAttr))
120  return false;
121 
122  for (auto attr : *arrayAttr) {
123  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
124  if (deviceTypeAttr.getValue() == deviceType)
125  return true;
126  }
127 
128  return false;
129 }
130 
132  std::optional<mlir::ArrayAttr> deviceTypes) {
133  if (!hasDeviceTypeValues(deviceTypes))
134  return;
135 
136  p << "[";
137  llvm::interleaveComma(*deviceTypes, p,
138  [&](mlir::Attribute attr) { p << attr; });
139  p << "]";
140 }
141 
142 static std::optional<unsigned> findSegment(ArrayAttr segments,
143  mlir::acc::DeviceType deviceType) {
144  unsigned segmentIdx = 0;
145  for (auto attr : segments) {
146  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
147  if (deviceTypeAttr.getValue() == deviceType)
148  return std::make_optional(segmentIdx);
149  ++segmentIdx;
150  }
151  return std::nullopt;
152 }
153 
155 getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
157  std::optional<llvm::ArrayRef<int32_t>> segments,
158  mlir::acc::DeviceType deviceType) {
159  if (!arrayAttr)
160  return range.take_front(0);
161  if (auto pos = findSegment(*arrayAttr, deviceType)) {
162  int32_t nbOperandsBefore = 0;
163  for (unsigned i = 0; i < *pos; ++i)
164  nbOperandsBefore += (*segments)[i];
165  return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
166  }
167  return range.take_front(0);
168 }
169 
170 static mlir::Value
171 getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
173  std::optional<llvm::ArrayRef<int32_t>> segments,
174  std::optional<mlir::ArrayAttr> hasWaitDevnum,
175  mlir::acc::DeviceType deviceType) {
176  if (!hasDeviceTypeValues(deviceTypeAttr))
177  return {};
178  if (auto pos = findSegment(*deviceTypeAttr, deviceType))
179  if (hasWaitDevnum->getValue()[*pos])
180  return getValuesFromSegments(deviceTypeAttr, operands, segments,
181  deviceType)
182  .front();
183  return {};
184 }
185 
187 getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
189  std::optional<llvm::ArrayRef<int32_t>> segments,
190  std::optional<mlir::ArrayAttr> hasWaitDevnum,
191  mlir::acc::DeviceType deviceType) {
192  auto range =
193  getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
194  if (range.empty())
195  return range;
196  if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
197  if (hasWaitDevnum && *hasWaitDevnum) {
198  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
199  if (boolAttr.getValue())
200  return range.drop_front(1); // first value is devnum
201  }
202  }
203  return range;
204 }
205 
206 template <typename Op>
207 static LogicalResult checkWaitAndAsyncConflict(Op op) {
208  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
209  ++dtypeInt) {
210  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
211 
212  // The async attribute represent the async clause without value. Therefore
213  // the attribute and operand cannot appear at the same time.
214  if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
215  op.hasAsyncOnly(dtype))
216  return op.emitError("async attribute cannot appear with asyncOperand");
217 
218  // The wait attribute represent the wait clause without values. Therefore
219  // the attribute and operands cannot appear at the same time.
220  if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
221  op.hasWaitOnly(dtype))
222  return op.emitError("wait attribute cannot appear with waitOperands");
223  }
224  return success();
225 }
226 
227 template <typename Op>
228 static LogicalResult checkVarAndVarType(Op op) {
229  if (!op.getVar())
230  return op.emitError("must have var operand");
231 
232  if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
233  mlir::isa<mlir::acc::MappableType>(op.getVar().getType())) {
234  // TODO: If a type implements both interfaces (mappable and pointer-like),
235  // it is unclear which semantics to apply without additional info which
236  // would need captured in the data operation. For now restrict this case
237  // unless a compelling reason to support disambiguating between the two.
238  return op.emitError("var must be mappable or pointer-like (not both)");
239  }
240 
241  if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
242  !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
243  return op.emitError("var must be mappable or pointer-like");
244 
245  if (mlir::isa<mlir::acc::MappableType>(op.getVar().getType()) &&
246  op.getVarType() != op.getVar().getType())
247  return op.emitError("varType must match when var is mappable");
248 
249  return success();
250 }
251 
252 template <typename Op>
253 static LogicalResult checkVarAndAccVar(Op op) {
254  if (op.getVar().getType() != op.getAccVar().getType())
255  return op.emitError("input and output types must match");
256 
257  return success();
258 }
259 
260 static ParseResult parseVar(mlir::OpAsmParser &parser,
262  // Either `var` or `varPtr` keyword is required.
263  if (failed(parser.parseOptionalKeyword("varPtr"))) {
264  if (failed(parser.parseKeyword("var")))
265  return failure();
266  }
267  if (failed(parser.parseLParen()))
268  return failure();
269  if (failed(parser.parseOperand(var)))
270  return failure();
271 
272  return success();
273 }
274 
276  mlir::Value var) {
277  if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
278  p << "varPtr(";
279  else
280  p << "var(";
281  p.printOperand(var);
282 }
283 
284 static ParseResult parseAccVar(mlir::OpAsmParser &parser,
286  mlir::Type &accVarType) {
287  // Either `accVar` or `accPtr` keyword is required.
288  if (failed(parser.parseOptionalKeyword("accPtr"))) {
289  if (failed(parser.parseKeyword("accVar")))
290  return failure();
291  }
292  if (failed(parser.parseLParen()))
293  return failure();
294  if (failed(parser.parseOperand(var)))
295  return failure();
296  if (failed(parser.parseColon()))
297  return failure();
298  if (failed(parser.parseType(accVarType)))
299  return failure();
300  if (failed(parser.parseRParen()))
301  return failure();
302 
303  return success();
304 }
305 
307  mlir::Value accVar, mlir::Type accVarType) {
308  if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
309  p << "accPtr(";
310  else
311  p << "accVar(";
312  p.printOperand(accVar);
313  p << " : ";
314  p.printType(accVarType);
315  p << ")";
316 }
317 
318 static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
319  mlir::Type &varPtrType,
320  mlir::TypeAttr &varTypeAttr) {
321  if (failed(parser.parseType(varPtrType)))
322  return failure();
323  if (failed(parser.parseRParen()))
324  return failure();
325 
326  if (succeeded(parser.parseOptionalKeyword("varType"))) {
327  if (failed(parser.parseLParen()))
328  return failure();
329  mlir::Type varType;
330  if (failed(parser.parseType(varType)))
331  return failure();
332  varTypeAttr = mlir::TypeAttr::get(varType);
333  if (failed(parser.parseRParen()))
334  return failure();
335  } else {
336  // Set `varType` from the element type of the type of `varPtr`.
337  if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
338  varTypeAttr = mlir::TypeAttr::get(
339  mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
340  else
341  varTypeAttr = mlir::TypeAttr::get(varPtrType);
342  }
343 
344  return success();
345 }
346 
348  mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
349  p.printType(varPtrType);
350  p << ")";
351 
352  // Print the `varType` only if it differs from the element type of
353  // `varPtr`'s type.
354  mlir::Type varType = varTypeAttr.getValue();
355  mlir::Type typeToCheckAgainst =
356  mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
357  ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
358  : varPtrType;
359  if (typeToCheckAgainst != varType) {
360  p << " varType(";
361  p.printType(varType);
362  p << ")";
363  }
364 }
365 
366 //===----------------------------------------------------------------------===//
367 // DataBoundsOp
368 //===----------------------------------------------------------------------===//
369 LogicalResult acc::DataBoundsOp::verify() {
370  auto extent = getExtent();
371  auto upperbound = getUpperbound();
372  if (!extent && !upperbound)
373  return emitError("expected extent or upperbound.");
374  return success();
375 }
376 
377 //===----------------------------------------------------------------------===//
378 // PrivateOp
379 //===----------------------------------------------------------------------===//
380 LogicalResult acc::PrivateOp::verify() {
381  if (getDataClause() != acc::DataClause::acc_private)
382  return emitError(
383  "data clause associated with private operation must match its intent");
384  if (failed(checkVarAndVarType(*this)))
385  return failure();
386  return success();
387 }
388 
389 //===----------------------------------------------------------------------===//
390 // FirstprivateOp
391 //===----------------------------------------------------------------------===//
392 LogicalResult acc::FirstprivateOp::verify() {
393  if (getDataClause() != acc::DataClause::acc_firstprivate)
394  return emitError("data clause associated with firstprivate operation must "
395  "match its intent");
396  if (failed(checkVarAndVarType(*this)))
397  return failure();
398  return success();
399 }
400 
401 //===----------------------------------------------------------------------===//
402 // ReductionOp
403 //===----------------------------------------------------------------------===//
404 LogicalResult acc::ReductionOp::verify() {
405  if (getDataClause() != acc::DataClause::acc_reduction)
406  return emitError("data clause associated with reduction operation must "
407  "match its intent");
408  if (failed(checkVarAndVarType(*this)))
409  return failure();
410  return success();
411 }
412 
413 //===----------------------------------------------------------------------===//
414 // DevicePtrOp
415 //===----------------------------------------------------------------------===//
416 LogicalResult acc::DevicePtrOp::verify() {
417  if (getDataClause() != acc::DataClause::acc_deviceptr)
418  return emitError("data clause associated with deviceptr operation must "
419  "match its intent");
420  if (failed(checkVarAndVarType(*this)))
421  return failure();
422  if (failed(checkVarAndAccVar(*this)))
423  return failure();
424  return success();
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // PresentOp
429 //===----------------------------------------------------------------------===//
430 LogicalResult acc::PresentOp::verify() {
431  if (getDataClause() != acc::DataClause::acc_present)
432  return emitError(
433  "data clause associated with present operation must match its intent");
434  if (failed(checkVarAndVarType(*this)))
435  return failure();
436  if (failed(checkVarAndAccVar(*this)))
437  return failure();
438  return success();
439 }
440 
441 //===----------------------------------------------------------------------===//
442 // CopyinOp
443 //===----------------------------------------------------------------------===//
444 LogicalResult acc::CopyinOp::verify() {
445  // Test for all clauses this operation can be decomposed from:
446  if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
447  getDataClause() != acc::DataClause::acc_copyin_readonly &&
448  getDataClause() != acc::DataClause::acc_copy &&
449  getDataClause() != acc::DataClause::acc_reduction)
450  return emitError(
451  "data clause associated with copyin operation must match its intent"
452  " or specify original clause this operation was decomposed from");
453  if (failed(checkVarAndVarType(*this)))
454  return failure();
455  if (failed(checkVarAndAccVar(*this)))
456  return failure();
457  return success();
458 }
459 
460 bool acc::CopyinOp::isCopyinReadonly() {
461  return getDataClause() == acc::DataClause::acc_copyin_readonly;
462 }
463 
464 //===----------------------------------------------------------------------===//
465 // CreateOp
466 //===----------------------------------------------------------------------===//
467 LogicalResult acc::CreateOp::verify() {
468  // Test for all clauses this operation can be decomposed from:
469  if (getDataClause() != acc::DataClause::acc_create &&
470  getDataClause() != acc::DataClause::acc_create_zero &&
471  getDataClause() != acc::DataClause::acc_copyout &&
472  getDataClause() != acc::DataClause::acc_copyout_zero)
473  return emitError(
474  "data clause associated with create operation must match its intent"
475  " or specify original clause this operation was decomposed from");
476  if (failed(checkVarAndVarType(*this)))
477  return failure();
478  if (failed(checkVarAndAccVar(*this)))
479  return failure();
480  return success();
481 }
482 
483 bool acc::CreateOp::isCreateZero() {
484  // The zero modifier is encoded in the data clause.
485  return getDataClause() == acc::DataClause::acc_create_zero ||
486  getDataClause() == acc::DataClause::acc_copyout_zero;
487 }
488 
489 //===----------------------------------------------------------------------===//
490 // NoCreateOp
491 //===----------------------------------------------------------------------===//
492 LogicalResult acc::NoCreateOp::verify() {
493  if (getDataClause() != acc::DataClause::acc_no_create)
494  return emitError("data clause associated with no_create operation must "
495  "match its intent");
496  if (failed(checkVarAndVarType(*this)))
497  return failure();
498  if (failed(checkVarAndAccVar(*this)))
499  return failure();
500  return success();
501 }
502 
503 //===----------------------------------------------------------------------===//
504 // AttachOp
505 //===----------------------------------------------------------------------===//
506 LogicalResult acc::AttachOp::verify() {
507  if (getDataClause() != acc::DataClause::acc_attach)
508  return emitError(
509  "data clause associated with attach operation must match its intent");
510  if (failed(checkVarAndVarType(*this)))
511  return failure();
512  if (failed(checkVarAndAccVar(*this)))
513  return failure();
514  return success();
515 }
516 
517 //===----------------------------------------------------------------------===//
518 // DeclareDeviceResidentOp
519 //===----------------------------------------------------------------------===//
520 
521 LogicalResult acc::DeclareDeviceResidentOp::verify() {
522  if (getDataClause() != acc::DataClause::acc_declare_device_resident)
523  return emitError("data clause associated with device_resident operation "
524  "must match its intent");
525  if (failed(checkVarAndVarType(*this)))
526  return failure();
527  if (failed(checkVarAndAccVar(*this)))
528  return failure();
529  return success();
530 }
531 
532 //===----------------------------------------------------------------------===//
533 // DeclareLinkOp
534 //===----------------------------------------------------------------------===//
535 
536 LogicalResult acc::DeclareLinkOp::verify() {
537  if (getDataClause() != acc::DataClause::acc_declare_link)
538  return emitError(
539  "data clause associated with link operation must match its intent");
540  if (failed(checkVarAndVarType(*this)))
541  return failure();
542  if (failed(checkVarAndAccVar(*this)))
543  return failure();
544  return success();
545 }
546 
547 //===----------------------------------------------------------------------===//
548 // CopyoutOp
549 //===----------------------------------------------------------------------===//
550 LogicalResult acc::CopyoutOp::verify() {
551  // Test for all clauses this operation can be decomposed from:
552  if (getDataClause() != acc::DataClause::acc_copyout &&
553  getDataClause() != acc::DataClause::acc_copyout_zero &&
554  getDataClause() != acc::DataClause::acc_copy &&
555  getDataClause() != acc::DataClause::acc_reduction)
556  return emitError(
557  "data clause associated with copyout operation must match its intent"
558  " or specify original clause this operation was decomposed from");
559  if (!getVar() || !getAccVar())
560  return emitError("must have both host and device pointers");
561  if (failed(checkVarAndVarType(*this)))
562  return failure();
563  if (failed(checkVarAndAccVar(*this)))
564  return failure();
565  return success();
566 }
567 
568 bool acc::CopyoutOp::isCopyoutZero() {
569  return getDataClause() == acc::DataClause::acc_copyout_zero;
570 }
571 
572 //===----------------------------------------------------------------------===//
573 // DeleteOp
574 //===----------------------------------------------------------------------===//
575 LogicalResult acc::DeleteOp::verify() {
576  // Test for all clauses this operation can be decomposed from:
577  if (getDataClause() != acc::DataClause::acc_delete &&
578  getDataClause() != acc::DataClause::acc_create &&
579  getDataClause() != acc::DataClause::acc_create_zero &&
580  getDataClause() != acc::DataClause::acc_copyin &&
581  getDataClause() != acc::DataClause::acc_copyin_readonly &&
582  getDataClause() != acc::DataClause::acc_present &&
583  getDataClause() != acc::DataClause::acc_no_create &&
584  getDataClause() != acc::DataClause::acc_declare_device_resident &&
585  getDataClause() != acc::DataClause::acc_declare_link)
586  return emitError(
587  "data clause associated with delete operation must match its intent"
588  " or specify original clause this operation was decomposed from");
589  if (!getAccVar())
590  return emitError("must have device pointer");
591  return success();
592 }
593 
594 //===----------------------------------------------------------------------===//
595 // DetachOp
596 //===----------------------------------------------------------------------===//
597 LogicalResult acc::DetachOp::verify() {
598  // Test for all clauses this operation can be decomposed from:
599  if (getDataClause() != acc::DataClause::acc_detach &&
600  getDataClause() != acc::DataClause::acc_attach)
601  return emitError(
602  "data clause associated with detach operation must match its intent"
603  " or specify original clause this operation was decomposed from");
604  if (!getAccVar())
605  return emitError("must have device pointer");
606  return success();
607 }
608 
609 //===----------------------------------------------------------------------===//
610 // HostOp
611 //===----------------------------------------------------------------------===//
612 LogicalResult acc::UpdateHostOp::verify() {
613  // Test for all clauses this operation can be decomposed from:
614  if (getDataClause() != acc::DataClause::acc_update_host &&
615  getDataClause() != acc::DataClause::acc_update_self)
616  return emitError(
617  "data clause associated with host operation must match its intent"
618  " or specify original clause this operation was decomposed from");
619  if (!getVar() || !getAccVar())
620  return emitError("must have both host and device pointers");
621  if (failed(checkVarAndVarType(*this)))
622  return failure();
623  if (failed(checkVarAndAccVar(*this)))
624  return failure();
625  return success();
626 }
627 
628 //===----------------------------------------------------------------------===//
629 // DeviceOp
630 //===----------------------------------------------------------------------===//
631 LogicalResult acc::UpdateDeviceOp::verify() {
632  // Test for all clauses this operation can be decomposed from:
633  if (getDataClause() != acc::DataClause::acc_update_device)
634  return emitError(
635  "data clause associated with device operation must match its intent"
636  " or specify original clause this operation was decomposed from");
637  if (failed(checkVarAndVarType(*this)))
638  return failure();
639  if (failed(checkVarAndAccVar(*this)))
640  return failure();
641  return success();
642 }
643 
644 //===----------------------------------------------------------------------===//
645 // UseDeviceOp
646 //===----------------------------------------------------------------------===//
647 LogicalResult acc::UseDeviceOp::verify() {
648  // Test for all clauses this operation can be decomposed from:
649  if (getDataClause() != acc::DataClause::acc_use_device)
650  return emitError(
651  "data clause associated with use_device operation must match its intent"
652  " or specify original clause this operation was decomposed from");
653  if (failed(checkVarAndVarType(*this)))
654  return failure();
655  if (failed(checkVarAndAccVar(*this)))
656  return failure();
657  return success();
658 }
659 
660 //===----------------------------------------------------------------------===//
661 // CacheOp
662 //===----------------------------------------------------------------------===//
663 LogicalResult acc::CacheOp::verify() {
664  // Test for all clauses this operation can be decomposed from:
665  if (getDataClause() != acc::DataClause::acc_cache &&
666  getDataClause() != acc::DataClause::acc_cache_readonly)
667  return emitError(
668  "data clause associated with cache operation must match its intent"
669  " or specify original clause this operation was decomposed from");
670  if (failed(checkVarAndVarType(*this)))
671  return failure();
672  if (failed(checkVarAndAccVar(*this)))
673  return failure();
674  return success();
675 }
676 
677 template <typename StructureOp>
678 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
679  unsigned nRegions = 1) {
680 
681  SmallVector<Region *, 2> regions;
682  for (unsigned i = 0; i < nRegions; ++i)
683  regions.push_back(state.addRegion());
684 
685  for (Region *region : regions)
686  if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
687  return failure();
688 
689  return success();
690 }
691 
692 static bool isComputeOperation(Operation *op) {
693  return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
694 }
695 
696 namespace {
697 /// Pattern to remove operation without region that have constant false `ifCond`
698 /// and remove the condition from the operation if the `ifCond` is a true
699 /// constant.
700 template <typename OpTy>
701 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
703 
704  LogicalResult matchAndRewrite(OpTy op,
705  PatternRewriter &rewriter) const override {
706  // Early return if there is no condition.
707  Value ifCond = op.getIfCond();
708  if (!ifCond)
709  return failure();
710 
711  IntegerAttr constAttr;
712  if (!matchPattern(ifCond, m_Constant(&constAttr)))
713  return failure();
714  if (constAttr.getInt())
715  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
716  else
717  rewriter.eraseOp(op);
718 
719  return success();
720  }
721 };
722 
723 /// Replaces the given op with the contents of the given single-block region,
724 /// using the operands of the block terminator to replace operation results.
725 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
726  Region &region, ValueRange blockArgs = {}) {
727  assert(llvm::hasSingleElement(region) && "expected single-region block");
728  Block *block = &region.front();
729  Operation *terminator = block->getTerminator();
730  ValueRange results = terminator->getOperands();
731  rewriter.inlineBlockBefore(block, op, blockArgs);
732  rewriter.replaceOp(op, results);
733  rewriter.eraseOp(terminator);
734 }
735 
736 /// Pattern to remove operation with region that have constant false `ifCond`
737 /// and remove the condition from the operation if the `ifCond` is constant
738 /// true.
739 template <typename OpTy>
740 struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
742 
743  LogicalResult matchAndRewrite(OpTy op,
744  PatternRewriter &rewriter) const override {
745  // Early return if there is no condition.
746  Value ifCond = op.getIfCond();
747  if (!ifCond)
748  return failure();
749 
750  IntegerAttr constAttr;
751  if (!matchPattern(ifCond, m_Constant(&constAttr)))
752  return failure();
753  if (constAttr.getInt())
754  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
755  else
756  replaceOpWithRegion(rewriter, op, op.getRegion());
757 
758  return success();
759  }
760 };
761 
762 } // namespace
763 
764 //===----------------------------------------------------------------------===//
765 // PrivateRecipeOp
766 //===----------------------------------------------------------------------===//
767 
768 static LogicalResult verifyInitLikeSingleArgRegion(
769  Operation *op, Region &region, StringRef regionType, StringRef regionName,
770  Type type, bool verifyYield, bool optional = false) {
771  if (optional && region.empty())
772  return success();
773 
774  if (region.empty())
775  return op->emitOpError() << "expects non-empty " << regionName << " region";
776  Block &firstBlock = region.front();
777  if (firstBlock.getNumArguments() < 1 ||
778  firstBlock.getArgument(0).getType() != type)
779  return op->emitOpError() << "expects " << regionName
780  << " region first "
781  "argument of the "
782  << regionType << " type";
783 
784  if (verifyYield) {
785  for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
786  if (yieldOp.getOperands().size() != 1 ||
787  yieldOp.getOperands().getTypes()[0] != type)
788  return op->emitOpError() << "expects " << regionName
789  << " region to "
790  "yield a value of the "
791  << regionType << " type";
792  }
793  }
794  return success();
795 }
796 
797 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
798  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
799  "privatization", "init", getType(),
800  /*verifyYield=*/false)))
801  return failure();
803  *this, getDestroyRegion(), "privatization", "destroy", getType(),
804  /*verifyYield=*/false, /*optional=*/true)))
805  return failure();
806  return success();
807 }
808 
809 //===----------------------------------------------------------------------===//
810 // FirstprivateRecipeOp
811 //===----------------------------------------------------------------------===//
812 
813 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
814  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
815  "privatization", "init", getType(),
816  /*verifyYield=*/false)))
817  return failure();
818 
819  if (getCopyRegion().empty())
820  return emitOpError() << "expects non-empty copy region";
821 
822  Block &firstBlock = getCopyRegion().front();
823  if (firstBlock.getNumArguments() < 2 ||
824  firstBlock.getArgument(0).getType() != getType())
825  return emitOpError() << "expects copy region with two arguments of the "
826  "privatization type";
827 
828  if (getDestroyRegion().empty())
829  return success();
830 
831  if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
832  "privatization", "destroy",
833  getType(), /*verifyYield=*/false)))
834  return failure();
835 
836  return success();
837 }
838 
839 //===----------------------------------------------------------------------===//
840 // ReductionRecipeOp
841 //===----------------------------------------------------------------------===//
842 
843 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
844  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
845  "init", getType(),
846  /*verifyYield=*/false)))
847  return failure();
848 
849  if (getCombinerRegion().empty())
850  return emitOpError() << "expects non-empty combiner region";
851 
852  Block &reductionBlock = getCombinerRegion().front();
853  if (reductionBlock.getNumArguments() < 2 ||
854  reductionBlock.getArgument(0).getType() != getType() ||
855  reductionBlock.getArgument(1).getType() != getType())
856  return emitOpError() << "expects combiner region with the first two "
857  << "arguments of the reduction type";
858 
859  for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
860  if (yieldOp.getOperands().size() != 1 ||
861  yieldOp.getOperands().getTypes()[0] != getType())
862  return emitOpError() << "expects combiner region to yield a value "
863  "of the reduction type";
864  }
865 
866  return success();
867 }
868 
869 //===----------------------------------------------------------------------===//
// Custom parser and printer for the private clause
871 //===----------------------------------------------------------------------===//
872 
/// Parses a comma-separated list of one or more
/// `<symbol-attr> -> <operand> : <type>` entries. Operands and types are
/// appended to `operands` / `types`. The parsed symbol attributes are stored
/// in `symbols` as an ArrayAttr, in parse order.
static ParseResult parseSymOperandList(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
  if (failed(parser.parseCommaSeparatedList([&]() {
        // One entry: attribute, `->`, SSA operand, `:` type.
        if (parser.parseAttribute(attributes.emplace_back()) ||
            parser.parseArrow() ||
            parser.parseOperand(operands.emplace_back()) ||
            parser.parseColonType(types.emplace_back()))
          return failure();
        return success();
      })))
    return failure();
  // Repackage the collected attributes as the ArrayAttr result.
  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
  return success();
}
892 
    mlir::OperandRange operands,
    mlir::TypeRange types,
    std::optional<mlir::ArrayAttr> attributes) {
  // Prints `<symbol> -> <operand> : <type>` for each (symbol, operand) pair,
  // comma-separated, mirroring parseSymOperandList. The printed type is the
  // operand's own type; `types` is not consulted.
  // NOTE(review): `*attributes` is dereferenced unconditionally. Presumably
  // this is only invoked when the symbol attribute is present; confirm.
  llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
    p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
      << std::get<1>(it).getType();
  });
}
902 
903 //===----------------------------------------------------------------------===//
904 // ParallelOp
905 //===----------------------------------------------------------------------===//
906 
907 /// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
908 template <typename Op>
909 static LogicalResult checkDataOperands(Op op,
910  const mlir::ValueRange &operands) {
911  for (mlir::Value operand : operands)
912  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
913  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
914  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
915  operand.getDefiningOp()))
916  return op.emitError(
917  "expect data entry/exit operation or acc.getdeviceptr "
918  "as defining op");
919  return success();
920 }
921 
/// Verifies a clause whose operands are paired with recipe symbol references
/// (used here for private, firstprivate and reduction). The checks are:
/// - When `operands` is non-empty, `attributes` must be present with exactly
///   one symbol per operand.
/// - When `operands` is empty, `attributes` must be absent.
/// - Each operand may appear at most once.
/// - Each symbol must resolve, via lookupNearestSymbolFrom, to an `Op`.
/// - If `checkOperandType` is set and the recipe has a type, the operand type
///   must equal that type.
/// `operandName` and `symbolName` are only used to build diagnostics.
template <typename Op>
static LogicalResult
checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
                    mlir::OperandRange operands, llvm::StringRef operandName,
                    llvm::StringRef symbolName, bool checkOperandType = true) {
  if (!operands.empty()) {
    if (!attributes || attributes->size() != operands.size())
      return op->emitOpError()
             << "expected as many " << symbolName << " symbol reference as "
             << operandName << " operands";
  } else {
    if (attributes)
      return op->emitOpError()
             << "unexpected " << symbolName << " symbol reference";
    return success();
  }

  // `set` records operands already seen, so duplicates are rejected.
  for (auto args : llvm::zip(operands, *attributes)) {
    mlir::Value operand = std::get<0>(args);

    if (!set.insert(operand).second)
      return op->emitOpError()
             << operandName << " operand appears more than once";

    mlir::Type varType = operand.getType();
    auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
    auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
    if (!decl)
      return op->emitOpError()
             << "expected symbol reference " << symbolRef << " to point to a "
             << operandName << " declaration";

    // A recipe without a type accepts any operand type.
    if (checkOperandType && decl.getType() && decl.getType() != varType)
      return op->emitOpError() << "expected " << operandName << " (" << varType
                               << ") to be the same type as " << operandName
                               << " declaration (" << decl.getType() << ")";
  }

  return success();
}
963 
964 unsigned ParallelOp::getNumDataOperands() {
965  return getReductionOperands().size() + getPrivateOperands().size() +
966  getFirstprivateOperands().size() + getDataClauseOperands().size();
967 }
968 
969 Value ParallelOp::getDataOperand(unsigned i) {
970  unsigned numOptional = getAsyncOperands().size();
971  numOptional += getNumGangs().size();
972  numOptional += getNumWorkers().size();
973  numOptional += getVectorLength().size();
974  numOptional += getIfCond() ? 1 : 0;
975  numOptional += getSelfCond() ? 1 : 0;
976  return getOperand(getWaitOperands().size() + numOptional + i);
977 }
978 
979 template <typename Op>
980 static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
981  ArrayAttr deviceTypes,
982  llvm::StringRef keyword) {
983  if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
984  return op.emitOpError() << keyword << " operands count must match "
985  << keyword << " device_type count";
986  return success();
987 }
988 
/// Verifies a segmented, device_type-keyed clause (num_gangs, wait):
/// - When `maxInSegment` is non-zero, no segment may be larger than it.
/// - The segment sizes must sum to the number of operands.
/// - Non-empty operands require a `deviceTypes` attribute.
/// - When present, `deviceTypes` must have one entry per segment.
template <typename Op>
    Op op, OperandRange operands, DenseI32ArrayAttr segments,
    ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
  std::size_t numOperandsInSegments = 0;
  std::size_t nbOfSegments = 0;

  if (segments) {
    for (auto segCount : segments.asArrayRef()) {
      // maxInSegment == 0 means "no per-segment limit".
      if (maxInSegment != 0 && segCount > maxInSegment)
        return op.emitOpError() << keyword << " expects a maximum of "
                                << maxInSegment << " values per segment";
      numOperandsInSegments += segCount;
      ++nbOfSegments;
    }
  }

  if ((numOperandsInSegments != operands.size()) ||
      (!deviceTypes && !operands.empty()))
    return op.emitOpError()
           << keyword << " operand count does not match count in segments";
  if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
    return op.emitOpError()
           << keyword << " segment count does not match device_type count";
  return success();
}
1015 
/// Verifies acc.parallel. The checks, in order:
/// - private/firstprivate/reduction operands against their recipe symbols
///   (operand types are not compared);
/// - num_gangs segments (at most 3 values per segment) and wait segments
///   against their device_type lists;
/// - num_workers, vector_length and async operand counts against their
///   device_type lists;
/// - wait/async conflicts (checkWaitAndAsyncConflict);
/// - the producers of the data clause operands (checkDataOperands).
LogicalResult acc::ParallelOp::verify() {
  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
          *this, getPrivatizations(), getPrivateOperands(), "private",
          "privatizations", /*checkOperandType=*/false)))
    return failure();
  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
          *this, getFirstprivatizations(), getFirstprivateOperands(),
          "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
    return failure();
  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
          *this, getReductionRecipes(), getReductionOperands(), "reduction",
          "reductions", /*checkOperandType=*/false)))
    return failure();
  // Segmented clauses: num_gangs (max 3 values per segment), then wait.

          *this, getNumGangs(), getNumGangsSegmentsAttr(),
          getNumGangsDeviceTypeAttr(), "num_gangs", /*maxInSegment=*/3)))
    return failure();

          *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
          getWaitOperandsDeviceTypeAttr(), "wait")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
                                        getNumWorkersDeviceTypeAttr(),
                                        "num_workers")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
                                        getVectorLengthDeviceTypeAttr(),
                                        "vector_length")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
                                        getAsyncOperandsDeviceTypeAttr(),
                                        "async")))
    return failure();

  if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
    return failure();

  return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
}
1060 
/// Returns the element of `range` at the position that findSegment() reports
/// for `deviceType` within `arrayAttr`. Returns a null Value when `arrayAttr`
/// is absent or findSegment() yields no position.
static mlir::Value
getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
                            mlir::acc::DeviceType deviceType) {
  if (!arrayAttr)
    return {};
  if (auto pos = findSegment(*arrayAttr, deviceType))
    return range[*pos];
  return {};
}
1071 
1072 bool acc::ParallelOp::hasAsyncOnly() {
1073  return hasAsyncOnly(mlir::acc::DeviceType::None);
1074 }
1075 
1076 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1077  return hasDeviceType(getAsyncOnly(), deviceType);
1078 }
1079 
1080 mlir::Value acc::ParallelOp::getAsyncValue() {
1081  return getAsyncValue(mlir::acc::DeviceType::None);
1082 }
1083 
/// Returns the async operand associated with `deviceType`, looked up among
/// getAsyncOperands() by device_type.
mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
                                     getAsyncOperands(), deviceType);
}
1088 
1089 mlir::Value acc::ParallelOp::getNumWorkersValue() {
1090  return getNumWorkersValue(mlir::acc::DeviceType::None);
1091 }
1092 
acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
  // num_workers operand recorded for `deviceType`, or a null Value if the
  // device_type attribute is absent or has no matching entry.
  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
                                     deviceType);
}
1098 
1099 mlir::Value acc::ParallelOp::getVectorLengthValue() {
1100  return getVectorLengthValue(mlir::acc::DeviceType::None);
1101 }
1102 
acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
  // vector_length operand recorded for `deviceType`, or a null Value if the
  // device_type attribute is absent or has no matching entry.
  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
                                     getVectorLength(), deviceType);
}
1108 
1109 mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
1110  return getNumGangsValues(mlir::acc::DeviceType::None);
1111 }
1112 
ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
  // Uses getValuesFromSegments() with the num_gangs device_type and segment
  // attributes to pick the values for `deviceType`.
  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
                               getNumGangsSegments(), deviceType);
}
1118 
1119 bool acc::ParallelOp::hasWaitOnly() {
1120  return hasWaitOnly(mlir::acc::DeviceType::None);
1121 }
1122 
1123 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1124  return hasDeviceType(getWaitOnly(), deviceType);
1125 }
1126 
1127 mlir::Operation::operand_range ParallelOp::getWaitValues() {
1128  return getWaitValues(mlir::acc::DeviceType::None);
1129 }
1130 
ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  // Wait values for `deviceType`, derived from the wait operands and their
  // device_type / segment / devnum attributes.
      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
      getHasWaitDevnum(), deviceType);
}
1137 
1138 mlir::Value ParallelOp::getWaitDevnum() {
1139  return getWaitDevnum(mlir::acc::DeviceType::None);
1140 }
1141 
1142 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1143  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1144  getWaitOperandsSegments(), getHasWaitDevnum(),
1145  deviceType);
1146 }
1147 
/// Convenience builder that takes only operand lists and conditions.
///
/// Every device_type, segment, recipe-symbol and flag attribute is forwarded
/// to the generated builder as null. `gangPrivateOperands` and
/// `gangFirstPrivateOperands` go in the private and firstprivate operand
/// positions.
///
/// NOTE(review): the segment, device_type and recipe attributes are left
/// null. So non-empty numGangs, numWorkers, vectorLength, asyncOperands,
/// waitOperands, reductionOperands, gangPrivateOperands or
/// gangFirstPrivateOperands give an op that does not satisfy the
/// segment/device_type/recipe checks in ParallelOp::verify(). Confirm that
/// callers only pass empty ranges for those.
void ParallelOp::build(mlir::OpBuilder &odsBuilder,
                       mlir::OperationState &odsState,
                       mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
                       mlir::ValueRange vectorLength,
                       mlir::ValueRange asyncOperands,
                       mlir::ValueRange waitOperands, mlir::Value ifCond,
                       mlir::Value selfCond, mlir::ValueRange reductionOperands,
                       mlir::ValueRange gangPrivateOperands,
                       mlir::ValueRange gangFirstPrivateOperands,
                       mlir::ValueRange dataClauseOperands) {

  // Positional call into the generated builder. Presumably the argument order
  // follows the ODS definition, so keep it in sync when the op changes.
  ParallelOp::build(
      odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
      /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
      /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
      /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
      /*numGangsDeviceType=*/nullptr, numWorkers,
      /*numWorkersDeviceType=*/nullptr, vectorLength,
      /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
      /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
      gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
      /*firstprivatizations=*/nullptr, dataClauseOperands,
      /*defaultAttr=*/nullptr, /*combined=*/nullptr);
}
1172 
/// Parses one or more comma-separated groups, each of the form
/// `{<operand> : <type>, ...}` with an optional trailing `[<attr>]`.
/// For each group:
/// - its size goes into `segments`;
/// - its device_type attribute goes into `deviceTypes`. When `[...]` is
///   omitted a DeviceTypeAttr is pushed instead, presumably the default
///   DeviceType::None as in parseDeviceTypeOperands.
static ParseResult parseNumGangs(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::DenseI32ArrayAttr &segments) {

  do {
    if (failed(parser.parseLBrace()))
      return failure();

    // Remember where this group starts so its size can be recorded.
    int32_t crtOperandsSize = operands.size();
    if (failed(parser.parseCommaSeparatedList(
            if (parser.parseOperand(operands.emplace_back()) ||
                parser.parseColonType(types.emplace_back()))
              return failure();
            return success();
          })))
      return failure();
    seg.push_back(operands.size() - crtOperandsSize);

    if (failed(parser.parseRBrace()))
      return failure();

    if (succeeded(parser.parseOptionalLSquare())) {
      if (parser.parseAttribute(attributes.emplace_back()) ||
          parser.parseRSquare())
        return failure();
    } else {
      // No explicit device_type for this group.
      attributes.push_back(mlir::acc::DeviceTypeAttr::get(
    }
  } while (succeeded(parser.parseOptionalComma()));

  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);

  return success();
}
1216 
  // Prints the attribute as ` [<attr>]`, omitting it entirely when it is
  // DeviceType::None. parseDeviceTypeOperands, for example, records None when
  // no `[...]` is given.
  // NOTE(review): the dyn_cast result is used without a null check, so an
  // `attr` that is not a DeviceTypeAttr would be dereferenced as null.
  // Consider cast<> or an explicit check.
  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
  if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
    p << " [" << attr << "]";
}
1222 
    mlir::OperandRange operands, mlir::TypeRange types,
    std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::DenseI32ArrayAttr> segments) {
  // Printer counterpart of parseNumGangs. For each device_type entry i it
  // prints `{<operand> : <type>, ...}` with the next (*segments)[i] operands,
  // then the entry's device_type via printSingleDeviceType. The operands'
  // own types are printed; `types` is not consulted. `deviceTypes` and
  // `segments` are dereferenced unconditionally.
  unsigned opIdx = 0;
  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
    p << "{";
    // The inner `it` (a segment position) shadows the outer one and is
    // unused; `opIdx` walks the flat operand list across all segments.
    llvm::interleaveComma(
        llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
          p << operands[opIdx] << " : " << operands[opIdx].getType();
          ++opIdx;
        });
    p << "}";
    printSingleDeviceType(p, it.value());
  });
}
1239 
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::DenseI32ArrayAttr &segments) {

  // Parses one or more comma-separated `{<operand> : <type>, ...}` groups.
  // Each group may be followed by `[<attr>]`. Group sizes go into `segments`
  // and per-group device_type attributes go into `deviceTypes`.
  do {
    if (failed(parser.parseLBrace()))
      return failure();

    // Remember where this group starts so its size can be recorded.
    int32_t crtOperandsSize = operands.size();

    if (failed(parser.parseCommaSeparatedList(
            if (parser.parseOperand(operands.emplace_back()) ||
                parser.parseColonType(types.emplace_back()))
              return failure();
            return success();
          })))
      return failure();

    seg.push_back(operands.size() - crtOperandsSize);

    if (failed(parser.parseRBrace()))
      return failure();

    if (succeeded(parser.parseOptionalLSquare())) {
      if (parser.parseAttribute(attributes.emplace_back()) ||
          parser.parseRSquare())
        return failure();
    } else {
      // No explicit device_type for this group.
      attributes.push_back(mlir::acc::DeviceTypeAttr::get(
    }
  } while (succeeded(parser.parseOptionalComma()));

  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);

  return success();
}
1285 
    mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::DenseI32ArrayAttr> segments) {
  // Prints one `{<operand> : <type>, ...}` group per device_type entry, using
  // (*segments)[i] consecutive operands for entry i, followed by the entry's
  // device_type. `deviceTypes` and `segments` are dereferenced
  // unconditionally.
  unsigned opIdx = 0;
  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
    p << "{";
    // `opIdx` walks the flat operand list; the inner `it` is unused.
    llvm::interleaveComma(
        llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
          p << operands[opIdx] << " : " << operands[opIdx].getType();
          ++opIdx;
        });
    p << "}";
    printSingleDeviceType(p, it.value());
  });
}
1302 
/// Parses the operands of a `wait` clause. Accepted forms:
/// - no `(`: keyword only; a single device_type entry is recorded in
///   `keywordOnly`;
/// - `(` optional `[<attr>, ...] ,` then one or more groups, then `)`.
///   Each group is `{ [devnum :] <operand> : <type>, ... }` optionally
///   followed by `[<attr>]`.
/// Results per group:
/// - the group size goes into `segments`;
/// - whether `devnum:` was given goes into `hasDevNum` as a BoolAttr;
/// - the device_type goes into `deviceTypes`.
/// NOTE(review): after a `[...]` keyword-only list, a `,` and at least one
/// `{...}` group are required, so `([...])` alone is rejected. Check that
/// this matches what printWaitClause can emit.
static ParseResult parseWaitClause(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
    mlir::ArrayAttr &keywordOnly) {
  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;

  bool needCommaBeforeOperands = false;

  // Keyword only: `wait` with no parenthesized list.
  if (failed(parser.parseOptionalLParen())) {
    keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
    keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
    return success();
  }

  // Parse keyword only attributes
  if (succeeded(parser.parseOptionalLSquare())) {
    if (failed(parser.parseCommaSeparatedList([&]() {
          if (parser.parseAttribute(keywordAttrs.emplace_back()))
            return failure();
          return success();
        })))
      return failure();
    if (parser.parseRSquare())
      return failure();
    needCommaBeforeOperands = true;
  }

  if (needCommaBeforeOperands && failed(parser.parseComma()))
    return failure();

  do {
    if (failed(parser.parseLBrace()))
      return failure();

    // Remember where this group starts so its size can be recorded.
    int32_t crtOperandsSize = operands.size();

    // Optional `devnum:` prefix, recorded per group.
    if (succeeded(parser.parseOptionalKeyword("devnum"))) {
      if (failed(parser.parseColon()))
        return failure();
      devnum.push_back(BoolAttr::get(parser.getContext(), true));
    } else {
      devnum.push_back(BoolAttr::get(parser.getContext(), false));
    }

    if (failed(parser.parseCommaSeparatedList(
            if (parser.parseOperand(operands.emplace_back()) ||
                parser.parseColonType(types.emplace_back()))
              return failure();
            return success();
          })))
      return failure();

    seg.push_back(operands.size() - crtOperandsSize);

    if (failed(parser.parseRBrace()))
      return failure();

    if (succeeded(parser.parseOptionalLSquare())) {
      if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
          parser.parseRSquare())
        return failure();
    } else {
      // No explicit device_type for this group.
      deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
    }
  } while (succeeded(parser.parseOptionalComma()));

  if (failed(parser.parseRParen()))
    return failure();

  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
  hasDevNum = ArrayAttr::get(parser.getContext(), devnum);

  return success();
}
1386 
1387 static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
1388  if (!hasDeviceTypeValues(attrs))
1389  return false;
1390  if (attrs->size() != 1)
1391  return false;
1392  if (auto deviceTypeAttr =
1393  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1394  return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
1395  return false;
1396 }
1397 
    mlir::OperandRange operands, mlir::TypeRange types,
    std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::DenseI32ArrayAttr> segments,
    std::optional<mlir::ArrayAttr> hasDevNum,
    std::optional<mlir::ArrayAttr> keywordOnly) {

  // Printer counterpart of parseWaitClause. Nothing is printed for a bare
  // `wait`, i.e. no operands and keywordOnly == [None].
  if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
    return;

  p << "(";

  printDeviceTypes(p, keywordOnly);
  if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
    p << ", ";

  // NOTE(review): `deviceTypes`, `hasDevNum` and `segments` are dereferenced
  // without checks below. This assumes they are present whenever this point
  // is reached (e.g. not for keyword-only, non-None wait without operands);
  // confirm.
  unsigned opIdx = 0;
  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
    p << "{";
    auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
    if (boolAttr && boolAttr.getValue())
      p << "devnum: ";
    llvm::interleaveComma(
        llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
          p << operands[opIdx] << " : " << operands[opIdx].getType();
          ++opIdx;
        });
    p << "}";
    printSingleDeviceType(p, it.value());
  });

  p << ")";
}
1431 
/// Parses a comma-separated list of one or more
/// `<operand> : <type> [ '[' <attr> ']' ]` entries. Each entry contributes one
/// device_type to `deviceTypes`; entries without `[...]` get
/// DeviceType::None.
static ParseResult parseDeviceTypeOperands(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseOperand(operands.emplace_back()) ||
            parser.parseColonType(types.emplace_back()))
          return failure();
        if (succeeded(parser.parseOptionalLSquare())) {
          if (parser.parseAttribute(attributes.emplace_back()) ||
              parser.parseRSquare())
            return failure();
        } else {
          attributes.push_back(mlir::acc::DeviceTypeAttr::get(
              parser.getContext(), mlir::acc::DeviceType::None));
        }
        return success();
      })))
    return failure();
  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
  return success();
}
1457 
/// Printer counterpart of parseDeviceTypeOperands. Prints
/// `<operand> : <type>` for each (device_type, operand) pair, followed by the
/// device_type unless it is None. Prints nothing when hasDeviceTypeValues()
/// is false for `deviceTypes`.
static void
    mlir::OperandRange operands, mlir::TypeRange types,
    std::optional<mlir::ArrayAttr> deviceTypes) {
  if (!hasDeviceTypeValues(deviceTypes))
    return;
  llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
    p << std::get<1>(it) << " : " << std::get<1>(it).getType();
    printSingleDeviceType(p, std::get<0>(it));
  });
}
1469 
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::ArrayAttr &keywordOnlyDeviceType) {

  // Parses either a bare keyword (no `(`) or a parenthesized list:
  //   `(` [ `[<attr>, ...]` `,` ] <operand> : <type> [ `[<attr>]` ], ... `)`
  // NOTE(review): keywordOnlyDeviceType is only assigned in the no-`(` path
  // below. Attributes parsed from a `[...]` list inside the parentheses are
  // collected into keywordOnlyDeviceTypeAttributes but not stored anywhere in
  // this function. Unlike parseWaitClause, which stores them in keywordOnly,
  // they appear to be dropped; confirm.
  llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
  bool needCommaBeforeOperands = false;

  if (failed(parser.parseOptionalLParen())) {
    // Keyword only
    keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
    keywordOnlyDeviceType =
        ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
    return success();
  }

  // Parse keyword only attributes
  if (succeeded(parser.parseOptionalLSquare())) {
    // Parse keyword only attributes
    if (failed(parser.parseCommaSeparatedList([&]() {
          if (parser.parseAttribute(
                  keywordOnlyDeviceTypeAttributes.emplace_back()))
            return failure();
          return success();
        })))
      return failure();
    if (parser.parseRSquare())
      return failure();
    needCommaBeforeOperands = true;
  }

  // After a keyword-only list, a `,` and at least one operand are required.
  if (needCommaBeforeOperands && failed(parser.parseComma()))
    return failure();

  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseOperand(operands.emplace_back()) ||
            parser.parseColonType(types.emplace_back()))
          return failure();
        if (succeeded(parser.parseOptionalLSquare())) {
          if (parser.parseAttribute(attributes.emplace_back()) ||
              parser.parseRSquare())
            return failure();
        } else {
          attributes.push_back(mlir::acc::DeviceTypeAttr::get(
              parser.getContext(), mlir::acc::DeviceType::None));
        }
        return success();
      })))
    return failure();

  if (failed(parser.parseRParen()))
    return failure();

  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
  return success();
}
1531 
    mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {

  // Nothing is printed for the bare keyword: no operands and keyword-only
  // device types == [None].
  if (operands.begin() == operands.end() &&
      hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
    return;
  }

  // `(` keyword-only device types, a separating `, ` when both lists have
  // values, the operand list, then `)`.
  p << "(";
  printDeviceTypes(p, keywordOnlyDeviceTypes);
  if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
      hasDeviceTypeValues(deviceTypes))
    p << ", ";
  printDeviceTypeOperands(p, op, operands, types, deviceTypes);
  p << ")";
}
1550 
/// Parses an optional `combined(kernels|parallel|serial)` marker and sets
/// `attr` to the matching CombinedConstructsType. If the `combined` keyword
/// is absent, nothing is consumed and success is returned. An unknown
/// construct name inside the parentheses is a parse error.
static ParseResult
    mlir::acc::CombinedConstructsTypeAttr &attr) {
  if (succeeded(parser.parseOptionalKeyword("combined"))) {
    if (parser.parseLParen())
      return failure();
    if (succeeded(parser.parseOptionalKeyword("kernels"))) {
          parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
    } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
          parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
    } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
          parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
    } else {
      parser.emitError(parser.getCurrentLocation(),
                       "expected compute construct name");
      return failure();
    }
    if (parser.parseRParen())
      return failure();
  }
  return success();
}
1576 
/// Printer counterpart of parseCombinedConstructsLoop. Prints
/// `combined(kernels)`, `combined(parallel)` or `combined(serial)` when
/// `attr` is set, and nothing otherwise.
static void
    mlir::acc::CombinedConstructsTypeAttr attr) {
  if (attr) {
    switch (attr.getValue()) {
    case mlir::acc::CombinedConstructsType::KernelsLoop:
      p << "combined(kernels)";
      break;
    case mlir::acc::CombinedConstructsType::ParallelLoop:
      p << "combined(parallel)";
      break;
    case mlir::acc::CombinedConstructsType::SerialLoop:
      p << "combined(serial)";
      break;
    };
  }
}
1594 
1595 //===----------------------------------------------------------------------===//
1596 // SerialOp
1597 //===----------------------------------------------------------------------===//
1598 
1599 unsigned SerialOp::getNumDataOperands() {
1600  return getReductionOperands().size() + getPrivateOperands().size() +
1601  getFirstprivateOperands().size() + getDataClauseOperands().size();
1602 }
1603 
1604 Value SerialOp::getDataOperand(unsigned i) {
1605  unsigned numOptional = getAsyncOperands().size();
1606  numOptional += getIfCond() ? 1 : 0;
1607  numOptional += getSelfCond() ? 1 : 0;
1608  return getOperand(getWaitOperands().size() + numOptional + i);
1609 }
1610 
1611 bool acc::SerialOp::hasAsyncOnly() {
1612  return hasAsyncOnly(mlir::acc::DeviceType::None);
1613 }
1614 
1615 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1616  return hasDeviceType(getAsyncOnly(), deviceType);
1617 }
1618 
1619 mlir::Value acc::SerialOp::getAsyncValue() {
1620  return getAsyncValue(mlir::acc::DeviceType::None);
1621 }
1622 
/// Returns the async operand associated with `deviceType`, looked up among
/// getAsyncOperands() by device_type.
mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
                                     getAsyncOperands(), deviceType);
}
1627 
1628 bool acc::SerialOp::hasWaitOnly() {
1629  return hasWaitOnly(mlir::acc::DeviceType::None);
1630 }
1631 
1632 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1633  return hasDeviceType(getWaitOnly(), deviceType);
1634 }
1635 
1636 mlir::Operation::operand_range SerialOp::getWaitValues() {
1637  return getWaitValues(mlir::acc::DeviceType::None);
1638 }
1639 
SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  // Wait values for `deviceType`, derived from the wait operands and their
  // device_type / segment / devnum attributes.
      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
      getHasWaitDevnum(), deviceType);
}
1646 
1647 mlir::Value SerialOp::getWaitDevnum() {
1648  return getWaitDevnum(mlir::acc::DeviceType::None);
1649 }
1650 
1651 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1652  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1653  getWaitOperandsSegments(), getHasWaitDevnum(),
1654  deviceType);
1655 }
1656 
/// Verifies acc.serial. The checks, in order:
/// - private/firstprivate/reduction operands against their recipe symbols
///   (operand types are not compared);
/// - wait segments against their device_type list;
/// - async operand count against its device_type list;
/// - wait/async conflicts (checkWaitAndAsyncConflict);
/// - the producers of the data clause operands (checkDataOperands).
LogicalResult acc::SerialOp::verify() {
  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
          *this, getPrivatizations(), getPrivateOperands(), "private",
          "privatizations", /*checkOperandType=*/false)))
    return failure();
  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
          *this, getFirstprivatizations(), getFirstprivateOperands(),
          "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
    return failure();
  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
          *this, getReductionRecipes(), getReductionOperands(), "reduction",
          "reductions", /*checkOperandType=*/false)))
    return failure();
  // Segmented wait clause.

          *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
          getWaitOperandsDeviceTypeAttr(), "wait")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
                                        getAsyncOperandsDeviceTypeAttr(),
                                        "async")))
    return failure();

  if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
    return failure();

  return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
}
1686 
1687 //===----------------------------------------------------------------------===//
1688 // KernelsOp
1689 //===----------------------------------------------------------------------===//
1690 
1691 unsigned KernelsOp::getNumDataOperands() {
1692  return getDataClauseOperands().size();
1693 }
1694 
1695 Value KernelsOp::getDataOperand(unsigned i) {
1696  unsigned numOptional = getAsyncOperands().size();
1697  numOptional += getWaitOperands().size();
1698  numOptional += getNumGangs().size();
1699  numOptional += getNumWorkers().size();
1700  numOptional += getVectorLength().size();
1701  numOptional += getIfCond() ? 1 : 0;
1702  numOptional += getSelfCond() ? 1 : 0;
1703  return getOperand(numOptional + i);
1704 }
1705 
1706 bool acc::KernelsOp::hasAsyncOnly() {
1707  return hasAsyncOnly(mlir::acc::DeviceType::None);
1708 }
1709 
1710 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1711  return hasDeviceType(getAsyncOnly(), deviceType);
1712 }
1713 
1714 mlir::Value acc::KernelsOp::getAsyncValue() {
1715  return getAsyncValue(mlir::acc::DeviceType::None);
1716 }
1717 
/// Returns the async operand associated with `deviceType`, looked up among
/// getAsyncOperands() by device_type.
mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
                                     getAsyncOperands(), deviceType);
}
1722 
1723 mlir::Value acc::KernelsOp::getNumWorkersValue() {
1724  return getNumWorkersValue(mlir::acc::DeviceType::None);
1725 }
1726 
acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
  // num_workers operand recorded for `deviceType`, or a null Value if the
  // device_type attribute is absent or has no matching entry.
  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
                                     deviceType);
}
1732 
1733 mlir::Value acc::KernelsOp::getVectorLengthValue() {
1734  return getVectorLengthValue(mlir::acc::DeviceType::None);
1735 }
1736 
acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
  // vector_length operand recorded for `deviceType`, or a null Value if the
  // device_type attribute is absent or has no matching entry.
  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
                                     getVectorLength(), deviceType);
}
1742 
1743 mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
1744  return getNumGangsValues(mlir::acc::DeviceType::None);
1745 }
1746 
KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
  // Uses getValuesFromSegments() with the num_gangs device_type and segment
  // attributes to pick the values for `deviceType`.
  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
                               getNumGangsSegments(), deviceType);
}
1752 
1753 bool acc::KernelsOp::hasWaitOnly() {
1754  return hasWaitOnly(mlir::acc::DeviceType::None);
1755 }
1756 
1757 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1758  return hasDeviceType(getWaitOnly(), deviceType);
1759 }
1760 
1761 mlir::Operation::operand_range KernelsOp::getWaitValues() {
1762  return getWaitValues(mlir::acc::DeviceType::None);
1763 }
1764 
KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  // Wait values for `deviceType`, derived from the wait operands and their
  // device_type / segment / devnum attributes.
      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
      getHasWaitDevnum(), deviceType);
}
1771 
1772 mlir::Value KernelsOp::getWaitDevnum() {
1773  return getWaitDevnum(mlir::acc::DeviceType::None);
1774 }
1775 
1776 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1777  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1778  getWaitOperandsSegments(), getHasWaitDevnum(),
1779  deviceType);
1780 }
1781 
/// Verifies acc.kernels. The checks, in order:
/// - num_gangs segments (at most 3 values per segment) and wait segments
///   against their device_type lists;
/// - num_workers, vector_length and async operand counts against their
///   device_type lists;
/// - wait/async conflicts (checkWaitAndAsyncConflict);
/// - the producers of the data clause operands (checkDataOperands).
LogicalResult acc::KernelsOp::verify() {
          *this, getNumGangs(), getNumGangsSegmentsAttr(),
          getNumGangsDeviceTypeAttr(), "num_gangs", /*maxInSegment=*/3)))
    return failure();

          *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
          getWaitOperandsDeviceTypeAttr(), "wait")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
                                        getNumWorkersDeviceTypeAttr(),
                                        "num_workers")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
                                        getVectorLengthDeviceTypeAttr(),
                                        "vector_length")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
                                        getAsyncOperandsDeviceTypeAttr(),
                                        "async")))
    return failure();

  if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
    return failure();

  return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
}
1813 
1814 //===----------------------------------------------------------------------===//
1815 // HostDataOp
1816 //===----------------------------------------------------------------------===//
1817 
1818 LogicalResult acc::HostDataOp::verify() {
1819  if (getDataClauseOperands().empty())
1820  return emitError("at least one operand must appear on the host_data "
1821  "operation");
1822 
1823  for (mlir::Value operand : getDataClauseOperands())
1824  if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
1825  return emitError("expect data entry operation as defining op");
1826  return success();
1827 }
1828 
1829 void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
1830  MLIRContext *context) {
1831  results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
1832 }
1833 
1834 //===----------------------------------------------------------------------===//
1835 // LoopOp
1836 //===----------------------------------------------------------------------===//
1837 
// Parses one optional `keyword = %operand : type` gang value.
// If `keyword` is not the next token, nothing is consumed and success is
// returned with all outputs untouched. On a match:
//  - the operand is appended to `operands` and its type to `types`;
//  - `gangArgType` is appended to `attributes`;
//  - both `needCommaBetweenValues` and `newValue` are set to true.
// NOTE(review): the declarations of `operands` and `types` (listing lines
// 1840-1841) are missing from this extract.
1838 static ParseResult parseGangValue(
1839  OpAsmParser &parser, llvm::StringRef keyword,
1842  llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
1843  bool &needCommaBetweenValues, bool &newValue) {
1844  if (succeeded(parser.parseOptionalKeyword(keyword))) {
1845  if (parser.parseEqual())
1846  return failure();
  // emplace_back() creates the slot that the parser fills in place.
1847  if (parser.parseOperand(operands.emplace_back()) ||
1848  parser.parseColonType(types.emplace_back()))
1849  return failure();
1850  attributes.push_back(gangArgType);
1851  needCommaBetweenValues = true;
1852  newValue = true;
1853  }
1854  return success();
1855 }
1856 
// Parses the custom assembly of the loop `gang` clause.
//  - Without a `(`, the bare gang keyword is recorded in gangOnlyDeviceType;
//    the pushed device type is on a line missing from this extract,
//    presumably `none` — confirm.
//  - Otherwise it parses `(` [optional `[`dt, ...`]` list of gang-only device
//    types] followed by comma-separated `{num=..., dim=..., static=...}`
//    groups, each optionally followed by `[`device_type`]`, then `)`.
// Outputs: gang operands/types, the per-operand GangArgType array, one device
// type per group, and the operand count of each group in `segments`.
// NOTE(review): listing lines 1859 (gangOperands parameter), 1866 (the `seg`
// declaration), 1873, 1961 and 1976 are missing from this extract.
1857 static ParseResult parseGangClause(
1858  OpAsmParser &parser,
1860  llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
1861  mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
1862  mlir::ArrayAttr &gangOnlyDeviceType) {
1863  llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
1864  llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
1865  llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
1867  bool needCommaBetweenValues = false;
1868  bool needCommaBeforeOperands = false;
1869 
1870  if (failed(parser.parseOptionalLParen())) {
1871  // Gang only keyword
1872  gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1874  gangOnlyDeviceType =
1875  ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
1876  return success();
1877  }
1878 
1879  // Parse gang only attributes
1880  if (succeeded(parser.parseOptionalLSquare())) {
1881  // Parse gang only attributes
1882  if (failed(parser.parseCommaSeparatedList([&]() {
1883  if (parser.parseAttribute(
1884  gangOnlyDeviceTypeAttributes.emplace_back()))
1885  return failure();
1886  return success();
1887  })))
1888  return failure();
1889  if (parser.parseRSquare())
1890  return failure();
  // A `{...}` group list may follow the gang-only list; the first pass of the
  // do/while below only consumes the separating comma.
1891  needCommaBeforeOperands = true;
1892  }
1893 
1894  auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
1895  mlir::acc::GangArgType::Num);
1896  auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
1897  mlir::acc::GangArgType::Dim);
1898  auto argStatic = mlir::acc::GangArgTypeAttr::get(
1899  parser.getContext(), mlir::acc::GangArgType::Static);
1900 
  // One iteration per `{...}` group. `continue` in a do/while jumps to the
  // loop condition, so the skipped first pass still consumes a comma.
  // NOTE(review): needCommaBetweenValues is never reset between groups. After
  // the first group it stays true, so the first value of a later group is
  // looked for after a comma that is not there and the group appears to fail
  // at parseRBrace. Verify that multi-group input round-trips.
1901  do {
1902  if (needCommaBeforeOperands) {
1903  needCommaBeforeOperands = false;
1904  continue;
1905  }
1906 
1907  if (failed(parser.parseLBrace()))
1908  return failure();
1909 
1910  int32_t crtOperandsSize = gangOperands.size();
  // Parse `kw=%v : type` entries. Each pass tries num, dim and static in that
  // order, so a comma is required only between passes.
1911  while (true) {
1912  bool newValue = false;
1913  bool needValue = false;
1914  if (needCommaBetweenValues) {
1915  if (succeeded(parser.parseOptionalComma()))
1916  needValue = true; // expect a new value after comma.
1917  else
1918  break;
1919  }
1920 
1921  if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
1922  gangOperands, gangOperandsType,
1923  gangArgTypeAttributes, argNum,
1924  needCommaBetweenValues, newValue)))
1925  return failure();
1926  if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
1927  gangOperands, gangOperandsType,
1928  gangArgTypeAttributes, argDim,
1929  needCommaBetweenValues, newValue)))
1930  return failure();
1931  if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
1932  gangOperands, gangOperandsType,
1933  gangArgTypeAttributes, argStatic,
1934  needCommaBetweenValues, newValue)))
1935  return failure();
1936 
1937  if (!newValue && needValue) {
1938  parser.emitError(parser.getCurrentLocation(),
1939  "new value expected after comma");
1940  return failure();
1941  }
1942 
1943  if (!newValue)
1944  break;
1945  }
1946 
  // NOTE(review): this tests every operand parsed so far, not only the
  // current group, so an empty `{}` after a non-empty group is accepted.
  // Comparing against crtOperandsSize would reject it.
1947  if (gangOperands.empty())
1948  return parser.emitError(
1949  parser.getCurrentLocation(),
1950  "expect at least one of num, dim or static values");
1951 
1952  if (failed(parser.parseRBrace()))
1953  return failure();
1954 
1955  if (succeeded(parser.parseOptionalLSquare())) {
1956  if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
1957  parser.parseRSquare())
1958  return failure();
1959  } else {
1960  deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1962  }
1963 
  // Record how many operands this group contributed.
1964  seg.push_back(gangOperands.size() - crtOperandsSize);
1965 
1966  } while (succeeded(parser.parseOptionalComma()));
1967 
1968  if (failed(parser.parseRParen()))
1969  return failure();
1970 
1971  llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
1972  gangArgTypeAttributes.end());
1973  gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
1974  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
1975 
1977  gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
1978  gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
1979 
1980  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1981  return success();
1982 }
1983 
1985  mlir::OperandRange operands, mlir::TypeRange types,
1986  std::optional<mlir::ArrayAttr> gangArgTypes,
1987  std::optional<mlir::ArrayAttr> deviceTypes,
1988  std::optional<mlir::DenseI32ArrayAttr> segments,
1989  std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
  // Prints the gang clause in the form read by parseGangClause: nothing when
  // there are no operands and the gang-only list holds just `none`; otherwise
  // `(`gang-only device types, then `{kw=%v : type, ...}` groups each followed
  // by their device type`)`.
  // NOTE(review): the function-name line (listing line 1984) is missing from
  // this extract. `types` is unused; each type is printed from
  // operands[opIdx].getType().
1990 
1991  if (operands.begin() == operands.end() &&
1992  hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
1993  return;
1994  }
1995 
1996  p << "(";
1997 
1998  printDeviceTypes(p, gangOnlyDeviceTypes);
1999 
2000  if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
2001  hasDeviceTypeValues(deviceTypes))
2002  p << ", ";
2003 
2004  if (hasDeviceTypeValues(deviceTypes)) {
  // opIdx walks the flat operand/arg-type arrays across all groups; each
  // group prints (*segments)[group index] entries.
2005  unsigned opIdx = 0;
2006  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2007  p << "{";
  // NOTE(review): the inner lambda parameter `it` shadows the outer `it`,
  // and the dyn_cast result below is used without a null check.
2008  llvm::interleaveComma(
2009  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2010  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2011  (*gangArgTypes)[opIdx]);
2012  if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
2013  p << LoopOp::getGangNumKeyword();
2014  else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
2015  p << LoopOp::getGangDimKeyword();
2016  else if (gangArgTypeAttr.getValue() ==
2017  mlir::acc::GangArgType::Static)
2018  p << LoopOp::getGangStaticKeyword();
2019  p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
2020  ++opIdx;
2021  });
2022  p << "}";
2023  printSingleDeviceType(p, it.value());
2024  });
2025  }
2026  p << ")";
2027 }
2028 
2030  std::optional<mlir::ArrayAttr> segments,
2031  llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
  // Inserts every device type of `segments` into the caller-owned
  // `deviceTypes` set. Returns true as soon as one is already present, and
  // false when `segments` is absent or all insertions succeed. Because the
  // set is shared across calls, repeated calls also detect a device type
  // that appears in more than one list.
  // NOTE(review): the signature line (listing line 2029) is missing from this
  // extract; `segments` actually holds device types. The dyn_cast result is
  // dereferenced without a null check.
2032  if (!segments)
2033  return false;
2034  for (auto attr : *segments) {
2035  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2036  if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2037  return true;
2038  }
2039  return false;
2040 }
2041 
2042 /// Check for duplicates in the DeviceType array attribute.
2043 LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
2044  llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2045  if (!deviceTypes)
2046  return success();
2047  for (auto attr : deviceTypes) {
2048  auto deviceTypeAttr =
2049  mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
2050  if (!deviceTypeAttr)
2051  return failure();
2052  if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
2053  return failure();
2054  }
2055  return success();
2056 }
2057 
// Verifies acc.loop, in order:
//  - bound sizes;
//  - collapse, gang, worker, vector and tile device-type bookkeeping;
//  - auto/independent/seq mutual exclusion;
//  - the seq vs. gang/worker/vector conflict;
//  - private and reduction recipes;
//  - the combined-construct attribute;
//  - a non-empty body.
// NOTE(review): listing lines 2092 and 2119 (the openings of the gang and
// tile segment checks, presumably `if (failed(<segment-count verifier>(`) are
// missing from this extract.
2058 LogicalResult acc::LoopOp::verify() {
2059  if (!getUpperbound().empty() && getInclusiveUpperbound() &&
2060  (getUpperbound().size() != getInclusiveUpperbound()->size()))
2061  return emitError() << "inclusiveUpperbound size is expected to be the same"
2062  << " as upperbound size";
2063 
2064  // Check collapse
  // NOTE(review): the message below reads "must be define"; it should read
  // "defined". Changing it may affect tests that match the text.
2065  if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
2066  return emitOpError() << "collapse device_type attr must be define when"
2067  << " collapse attr is present";
2068 
2069  if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
2070  getCollapseAttr().getValue().size() !=
2071  getCollapseDeviceTypeAttr().getValue().size())
2072  return emitOpError() << "collapse attribute count must match collapse"
2073  << " device_type count";
2074  if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
2075  return emitOpError()
2076  << "duplicate device_type found in collapseDeviceType attribute";
2077 
2078  // Check gang
2079  if (!getGangOperands().empty()) {
2080  if (!getGangOperandsArgType())
2081  return emitOpError() << "gangOperandsArgType attribute must be defined"
2082  << " when gang operands are present";
2083 
2084  if (getGangOperands().size() !=
2085  getGangOperandsArgTypeAttr().getValue().size())
2086  return emitOpError() << "gangOperandsArgType attribute count must match"
2087  << " gangOperands count";
2088  }
  // checkDeviceTypes already accepts a null attribute, so the getGangAttr()
  // guard is redundant but harmless.
2089  if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
2090  return emitOpError() << "duplicate device_type found in gang attribute";
2091 
2093  *this, getGangOperands(), getGangOperandsSegmentsAttr(),
2094  getGangOperandsDeviceTypeAttr(), "gang")))
2095  return failure();
2096 
2097  // Check worker
2098  if (failed(checkDeviceTypes(getWorkerAttr())))
2099  return emitOpError() << "duplicate device_type found in worker attribute";
2100  if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
2101  return emitOpError() << "duplicate device_type found in "
2102  "workerNumOperandsDeviceType attribute";
2103  if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
2104  getWorkerNumOperandsDeviceTypeAttr(),
2105  "worker")))
2106  return failure();
2107 
2108  // Check vector
2109  if (failed(checkDeviceTypes(getVectorAttr())))
2110  return emitOpError() << "duplicate device_type found in vector attribute";
2111  if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
2112  return emitOpError() << "duplicate device_type found in "
2113  "vectorOperandsDeviceType attribute";
2114  if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
2115  getVectorOperandsDeviceTypeAttr(),
2116  "vector")))
2117  return failure();
2118 
2120  *this, getTileOperands(), getTileOperandsSegmentsAttr(),
2121  getTileOperandsDeviceTypeAttr(), "tile")))
2122  return failure();
2123 
2124  // auto, independent and seq attribute are mutually exclusive.
  // One shared set across the three lists: a device type that appears in two
  // of them, or twice in one, counts as a duplicate.
  // NOTE(review): only the auto name is quoted in the message below.
2125  llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
2126  if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
2127  hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
2128  hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
2129  return emitError() << "only one of \"" << acc::LoopOp::getAutoAttrStrName()
2130  << "\", " << getIndependentAttrName() << ", "
2131  << getSeqAttrName()
2132  << " can be present at the same time";
2133  }
2134 
2135  // Gang, worker and vector are incompatible with seq.
  // For each seq device type, both the value-less attribute forms and any
  // operand values of gang/worker/vector are rejected.
  // NOTE(review): the dyn_cast result is used without a null check.
2136  if (getSeqAttr()) {
2137  for (auto attr : getSeqAttr()) {
2138  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2139  if (hasVector(deviceTypeAttr.getValue()) ||
2140  getVectorValue(deviceTypeAttr.getValue()) ||
2141  hasWorker(deviceTypeAttr.getValue()) ||
2142  getWorkerValue(deviceTypeAttr.getValue()) ||
2143  hasGang(deviceTypeAttr.getValue()) ||
2144  getGangValue(mlir::acc::GangArgType::Num,
2145  deviceTypeAttr.getValue()) ||
2146  getGangValue(mlir::acc::GangArgType::Dim,
2147  deviceTypeAttr.getValue()) ||
2148  getGangValue(mlir::acc::GangArgType::Static,
2149  deviceTypeAttr.getValue()))
2150  return emitError()
2151  << "gang, worker or vector cannot appear with the seq attr";
2152  }
2153  }
2154 
2155  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
2156  *this, getPrivatizations(), getPrivateOperands(), "private",
2157  "privatizations", false)))
2158  return failure();
2159 
2160  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
2161  *this, getReductionRecipes(), getReductionOperands(), "reduction",
2162  "reductions", false)))
2163  return failure();
2164 
  // Only the three *Loop combined-construct kinds are valid on acc.loop.
2165  if (getCombined().has_value() &&
2166  (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
2167  getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
2168  getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
2169  return emitError("unexpected combined constructs attribute");
2170  }
2171 
2172  // Check non-empty body().
2173  if (getRegion().empty())
2174  return emitError("expected non-empty body.");
2175 
2176  return success();
2177 }
2178 
2179 unsigned LoopOp::getNumDataOperands() {
2180  return getReductionOperands().size() + getPrivateOperands().size();
2181 }
2182 
2183 Value LoopOp::getDataOperand(unsigned i) {
2184  unsigned numOptional =
2185  getLowerbound().size() + getUpperbound().size() + getStep().size();
2186  numOptional += getGangOperands().size();
2187  numOptional += getVectorOperands().size();
2188  numOptional += getWorkerNumOperands().size();
2189  numOptional += getTileOperands().size();
2190  numOptional += getCacheOperands().size();
2191  return getOperand(numOptional + i);
2192 }
2193 
2194 bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
2195 
2196 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
2197  return hasDeviceType(getAuto_(), deviceType);
2198 }
2199 
2200 bool LoopOp::hasIndependent() {
2201  return hasIndependent(mlir::acc::DeviceType::None);
2202 }
2203 
2204 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
2205  return hasDeviceType(getIndependent(), deviceType);
2206 }
2207 
2208 bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
2209 
2210 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
2211  return hasDeviceType(getSeq(), deviceType);
2212 }
2213 
2214 mlir::Value LoopOp::getVectorValue() {
2215  return getVectorValue(mlir::acc::DeviceType::None);
2216 }
2217 
2218 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
2219  return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
2220  getVectorOperands(), deviceType);
2221 }
2222 
2223 bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
2224 
2225 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
2226  return hasDeviceType(getVector(), deviceType);
2227 }
2228 
2229 mlir::Value LoopOp::getWorkerValue() {
2230  return getWorkerValue(mlir::acc::DeviceType::None);
2231 }
2232 
2233 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
2234  return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
2235  getWorkerNumOperands(), deviceType);
2236 }
2237 
2238 bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2239 
2240 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2241  return hasDeviceType(getWorker(), deviceType);
2242 }
2243 
2244 mlir::Operation::operand_range LoopOp::getTileValues() {
2245  return getTileValues(mlir::acc::DeviceType::None);
2246 }
2247 
2249 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
  // Returns the tile operands in the segment tagged with `deviceType`, looked
  // up by getValuesFromSegments from the tile device-type list and segment
  // sizes.
  // NOTE(review): the return-type line (listing line 2248) is missing from
  // this extract.
2250  return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
2251  getTileOperandsSegments(), deviceType);
2252 }
2253 
2254 std::optional<int64_t> LoopOp::getCollapseValue() {
2255  return getCollapseValue(mlir::acc::DeviceType::None);
2256 }
2257 
2258 std::optional<int64_t>
2259 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2260  if (!getCollapseAttr())
2261  return std::nullopt;
2262  if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2263  auto intAttr =
2264  mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2265  return intAttr.getValue().getZExtValue();
2266  }
2267  return std::nullopt;
2268 }
2269 
2270 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2271  return getGangValue(gangArgType, mlir::acc::DeviceType::None);
2272 }
2273 
// Returns the first gang operand of kind `gangArgType` inside the segment of
// `deviceType`. Returns a null Value when there are no gang operands, the
// device type has no segment, or no operand of that kind is found.
// NOTE(review): listing line 2282 (presumably the declaration of `values`) is
// missing from this extract.
2274 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2275  mlir::acc::DeviceType deviceType) {
2276  if (getGangOperands().empty())
2277  return {};
2278  if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
  // Sum the sizes of the earlier segments to find where this one begins in
  // the flat operand list.
2279  int32_t nbOperandsBefore = 0;
2280  for (unsigned i = 0; i < *pos; ++i)
2281  nbOperandsBefore += (*getGangOperandsSegments())[i];
2283  getGangOperands()
2284  .drop_front(nbOperandsBefore)
2285  .take_front((*getGangOperandsSegments())[*pos]);
2286 
  // gangOperandsArgType runs parallel to the flat operand list, so it is
  // indexed from the same starting offset.
  // NOTE(review): the dyn_cast result is used without a null check.
2287  int32_t argTypeIdx = nbOperandsBefore;
2288  for (auto value : values) {
2289  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2290  (*getGangOperandsArgType())[argTypeIdx]);
2291  if (gangArgTypeAttr.getValue() == gangArgType)
2292  return value;
2293  ++argTypeIdx;
2294  }
2295  }
2296  return {};
2297 }
2298 
2299 bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2300 
2301 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2302  return hasDeviceType(getGang(), deviceType);
2303 }
2304 
2305 llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
2306  return {&getRegion()};
2307 }
2308 
2309 /// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
2310 /// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
2311 /// `(` ssa-id-and-type-list `)`
2312 /// region
// The whole `control(...) = (...) to (...) step (...)` prefix is optional.
// Without it, the region is parsed with no induction variables. With it,
// lowerbound, upperbound and step must each list exactly as many operands as
// there are induction variables.
// NOTE(review): listing lines 2314-2315, 2317, 2319, 2330, 2334 and 2338
// (part of the signature and the delimiter arguments of parseOperandList)
// are missing from this extract.
2313 ParseResult
2316  SmallVectorImpl<Type> &lowerboundType,
2318  SmallVectorImpl<Type> &upperboundType,
2320  SmallVectorImpl<Type> &stepType) {
2321 
2322  SmallVector<OpAsmParser::Argument> inductionVars;
2323  if (succeeded(
2324  parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
2325  if (parser.parseLParen() ||
2326  parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
2327  /*allowType=*/true) ||
2328  parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
2329  parser.parseOperandList(lowerbound, inductionVars.size(),
2331  parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
2332  parser.parseKeyword("to") || parser.parseLParen() ||
2333  parser.parseOperandList(upperbound, inductionVars.size(),
2335  parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
2336  parser.parseKeyword("step") || parser.parseLParen() ||
2337  parser.parseOperandList(step, inductionVars.size(),
2339  parser.parseColonTypeList(stepType) || parser.parseRParen())
2340  return failure();
2341  }
  // The parsed induction variables become the region's entry block arguments.
2342  return parser.parseRegion(region, inductionVars);
2343 }
2344 
2346  ValueRange lowerbound, TypeRange lowerboundType,
2347  ValueRange upperbound, TypeRange upperboundType,
2348  ValueRange steps, TypeRange stepType) {
  // Prints the form read by parseLoopControl. The control prefix is emitted
  // only when the body's entry block has arguments (the induction
  // variables); the region follows without its entry block arguments.
  // NOTE(review): the signature line (listing line 2345) is missing from this
  // extract. `") " << " step ("` prints two spaces before `step`; the
  // parser tolerates it, but one space would be tidier.
2349  ValueRange regionArgs = region.front().getArguments();
2350  if (!regionArgs.empty()) {
2351  p << acc::LoopOp::getControlKeyword() << "(";
2352  llvm::interleaveComma(regionArgs, p,
2353  [&p](Value v) { p << v << " : " << v.getType(); });
2354  p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
2355  << upperbound << " : " << upperboundType << ") " << " step (" << steps
2356  << " : " << stepType << ") ";
2357  }
2358  p.printRegion(region, /*printEntryBlockArgs=*/false);
2359 }
2360 
2361 //===----------------------------------------------------------------------===//
2362 // DataOp
2363 //===----------------------------------------------------------------------===//
2364 
2365 LogicalResult acc::DataOp::verify() {
2366  // 2.6.5. Data Construct restriction
2367  // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
2368  // attach, or default clause must appear on a data construct.
2369  if (getOperands().empty() && !getDefaultAttr())
2370  return emitError("at least one operand or the default attribute "
2371  "must appear on the data operation");
2372 
2373  for (mlir::Value operand : getDataClauseOperands())
2374  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2375  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2376  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2377  operand.getDefiningOp()))
2378  return emitError("expect data entry/exit operation or acc.getdeviceptr "
2379  "as defining op");
2380 
2381  if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this)))
2382  return failure();
2383 
2384  return success();
2385 }
2386 
2387 unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
2388 
2389 Value DataOp::getDataOperand(unsigned i) {
2390  unsigned numOptional = getIfCond() ? 1 : 0;
2391  numOptional += getAsyncOperands().size() ? 1 : 0;
2392  numOptional += getWaitOperands().size();
2393  return getOperand(numOptional + i);
2394 }
2395 
2396 bool acc::DataOp::hasAsyncOnly() {
2397  return hasAsyncOnly(mlir::acc::DeviceType::None);
2398 }
2399 
2400 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2401  return hasDeviceType(getAsyncOnly(), deviceType);
2402 }
2403 
2404 mlir::Value DataOp::getAsyncValue() {
2405  return getAsyncValue(mlir::acc::DeviceType::None);
2406 }
2407 
2408 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
  // Returns the async operand attached to `deviceType`.
  // NOTE(review): the start of the return statement (listing line 2409,
  // presumably a getValueInDeviceTypeSegment call over the async device-type
  // list) is missing from this extract.
2410  getAsyncOperands(), deviceType);
2411 }
2412 
2413 bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
2414 
2415 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2416  return hasDeviceType(getWaitOnly(), deviceType);
2417 }
2418 
2419 mlir::Operation::operand_range DataOp::getWaitValues() {
2420  return getWaitValues(mlir::acc::DeviceType::None);
2421 }
2422 
2424 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  // Returns the wait values for `deviceType`. getHasWaitDevnum() is passed
  // along, presumably so the helper can leave out a leading devnum operand —
  // confirm against the helper's definition.
  // NOTE(review): listing lines 2423 (return type) and 2425 (start of the
  // return statement) are missing from this extract.
2426  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2427  getHasWaitDevnum(), deviceType);
2428 }
2429 
2430 mlir::Value DataOp::getWaitDevnum() {
2431  return getWaitDevnum(mlir::acc::DeviceType::None);
2432 }
2433 
2434 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2435  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2436  getWaitOperandsSegments(), getHasWaitDevnum(),
2437  deviceType);
2438 }
2439 
2440 //===----------------------------------------------------------------------===//
2441 // ExitDataOp
2442 //===----------------------------------------------------------------------===//
2443 
2444 LogicalResult acc::ExitDataOp::verify() {
2445  // 2.6.6. Data Exit Directive restriction
2446  // At least one copyout, delete, or detach clause must appear on an exit data
2447  // directive.
2448  if (getDataClauseOperands().empty())
2449  return emitError("at least one operand must be present in dataOperands on "
2450  "the exit data operation");
2451 
2452  // The async attribute represent the async clause without value. Therefore the
2453  // attribute and operand cannot appear at the same time.
2454  if (getAsyncOperand() && getAsync())
2455  return emitError("async attribute cannot appear with asyncOperand");
2456 
2457  // The wait attribute represent the wait clause without values. Therefore the
2458  // attribute and operands cannot appear at the same time.
2459  if (!getWaitOperands().empty() && getWait())
2460  return emitError("wait attribute cannot appear with waitOperands");
2461 
2462  if (getWaitDevnum() && getWaitOperands().empty())
2463  return emitError("wait_devnum cannot appear without waitOperands");
2464 
2465  return success();
2466 }
2467 
2468 unsigned ExitDataOp::getNumDataOperands() {
2469  return getDataClauseOperands().size();
2470 }
2471 
2472 Value ExitDataOp::getDataOperand(unsigned i) {
2473  unsigned numOptional = getIfCond() ? 1 : 0;
2474  numOptional += getAsyncOperand() ? 1 : 0;
2475  numOptional += getWaitDevnum() ? 1 : 0;
2476  return getOperand(getWaitOperands().size() + numOptional + i);
2477 }
2478 
2479 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2480  MLIRContext *context) {
2481  results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
2482 }
2483 
2484 //===----------------------------------------------------------------------===//
2485 // EnterDataOp
2486 //===----------------------------------------------------------------------===//
2487 
2488 LogicalResult acc::EnterDataOp::verify() {
2489  // 2.6.6. Data Enter Directive restriction
2490  // At least one copyin, create, or attach clause must appear on an enter data
2491  // directive.
2492  if (getDataClauseOperands().empty())
2493  return emitError("at least one operand must be present in dataOperands on "
2494  "the enter data operation");
2495 
2496  // The async attribute represent the async clause without value. Therefore the
2497  // attribute and operand cannot appear at the same time.
2498  if (getAsyncOperand() && getAsync())
2499  return emitError("async attribute cannot appear with asyncOperand");
2500 
2501  // The wait attribute represent the wait clause without values. Therefore the
2502  // attribute and operands cannot appear at the same time.
2503  if (!getWaitOperands().empty() && getWait())
2504  return emitError("wait attribute cannot appear with waitOperands");
2505 
2506  if (getWaitDevnum() && getWaitOperands().empty())
2507  return emitError("wait_devnum cannot appear without waitOperands");
2508 
2509  for (mlir::Value operand : getDataClauseOperands())
2510  if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
2511  operand.getDefiningOp()))
2512  return emitError("expect data entry operation as defining op");
2513 
2514  return success();
2515 }
2516 
2517 unsigned EnterDataOp::getNumDataOperands() {
2518  return getDataClauseOperands().size();
2519 }
2520 
2521 Value EnterDataOp::getDataOperand(unsigned i) {
2522  unsigned numOptional = getIfCond() ? 1 : 0;
2523  numOptional += getAsyncOperand() ? 1 : 0;
2524  numOptional += getWaitDevnum() ? 1 : 0;
2525  return getOperand(getWaitOperands().size() + numOptional + i);
2526 }
2527 
2528 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2529  MLIRContext *context) {
2530  results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
2531 }
2532 
2533 //===----------------------------------------------------------------------===//
2534 // AtomicReadOp
2535 //===----------------------------------------------------------------------===//
2536 
2537 LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
2538 
2539 //===----------------------------------------------------------------------===//
2540 // AtomicWriteOp
2541 //===----------------------------------------------------------------------===//
2542 
2543 LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
2544 
2545 //===----------------------------------------------------------------------===//
2546 // AtomicUpdateOp
2547 //===----------------------------------------------------------------------===//
2548 
2549 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2550  PatternRewriter &rewriter) {
2551  if (op.isNoOp()) {
2552  rewriter.eraseOp(op);
2553  return success();
2554  }
2555 
2556  if (Value writeVal = op.getWriteOpVal()) {
2557  rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
2558  return success();
2559  }
2560 
2561  return failure();
2562 }
2563 
2564 LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
2565 
2566 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
2567 
2568 //===----------------------------------------------------------------------===//
2569 // AtomicCaptureOp
2570 //===----------------------------------------------------------------------===//
2571 
2572 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2573  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2574  return op;
2575  return dyn_cast<AtomicReadOp>(getSecondOp());
2576 }
2577 
2578 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2579  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2580  return op;
2581  return dyn_cast<AtomicWriteOp>(getSecondOp());
2582 }
2583 
2584 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2585  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2586  return op;
2587  return dyn_cast<AtomicUpdateOp>(getSecondOp());
2588 }
2589 
2590 LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
2591 
2592 //===----------------------------------------------------------------------===//
2593 // DeclareEnterOp
2594 //===----------------------------------------------------------------------===//
2595 
// Shared verifier for the declare operations. Checks, in order:
//  - at least one operand is present (unless requireAtLeastOneOperand is
//    false);
//  - each operand is produced by an allowed declare data entry op or
//    acc.getdeviceptr;
//  - when the underlying variable has a defining op, that op carries a
//    declare attribute whose data clause matches the operand's;
//  - the implicit flag agrees when the variable is marked implicit.
// NOTE(review): the signature line (listing line 2598, presumably the `op`
// and `operands` parameters) is missing from this extract.
2596 template <typename Op>
2597 static LogicalResult
2599  bool requireAtLeastOneOperand = true) {
2600  if (operands.empty() && requireAtLeastOneOperand)
2601  return emitError(
2602  op->getLoc(),
2603  "at least one operand must appear on the declare operation");
2604 
2605  for (mlir::Value operand : operands) {
  // NOTE(review): plain isa<> asserts when getDefiningOp() is null (the
  // operand is a block argument); isa_and_nonnull<> would turn that into
  // this verifier error.
2606  if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2607  acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
2608  acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
2609  operand.getDefiningOp()))
2610  return op.emitError(
2611  "expect valid declare data entry operation or acc.getdeviceptr "
2612  "as defining op");
2613 
2614  mlir::Value var{getVar(operand.getDefiningOp())};
2615  assert(var && "declare operands can only be data entry operations which "
2616  "must have var");
2617  std::optional<mlir::acc::DataClause> dataClauseOptional{
2618  getDataClause(operand.getDefiningOp())};
2619  assert(dataClauseOptional.has_value() &&
2620  "declare operands can only be data entry operations which must have "
2621  "dataClause");
2622 
2623  // If varPtr has no defining op - there is nothing to check further.
2624  if (!var.getDefiningOp())
2625  continue;
2626 
2627  // Check that the varPtr has a declare attribute.
2628  auto declareAttribute{
2629  var.getDefiningOp()->getAttr(mlir::acc::getDeclareAttrName())};
2630  if (!declareAttribute)
2631  return op.emitError(
2632  "expect declare attribute on variable in declare operation");
2633 
2634  auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute);
2635  if (declAttr.getDataClause().getValue() != dataClauseOptional.value())
2636  return op.emitError(
2637  "expect matching declare attribute on variable in declare operation");
2638 
2639  // If the variable is marked with implicit attribute, the matching declare
2640  // data action must also be marked implicit. The reverse is not checked
2641  // since implicit data action may be inserted to do actions like updating
2642  // device copy, in which case the variable is not necessarily implicitly
2643  // declare'd.
2644  if (declAttr.getImplicit() &&
2645  declAttr.getImplicit() != acc::getImplicitFlag(operand.getDefiningOp()))
2646  return op.emitError(
2647  "implicitness must match between declare op and flag on variable");
2648  }
2649 
2650  return success();
2651 }
2652 
2653 LogicalResult acc::DeclareEnterOp::verify() {
2654  return checkDeclareOperands(*this, this->getDataClauseOperands());
2655 }
2656 
2657 //===----------------------------------------------------------------------===//
2658 // DeclareExitOp
2659 //===----------------------------------------------------------------------===//
2660 
2661 LogicalResult acc::DeclareExitOp::verify() {
2662  if (getToken())
2663  return checkDeclareOperands(*this, this->getDataClauseOperands(),
2664  /*requireAtLeastOneOperand=*/false);
2665  return checkDeclareOperands(*this, this->getDataClauseOperands());
2666 }
2667 
2668 //===----------------------------------------------------------------------===//
2669 // DeclareOp
2670 //===----------------------------------------------------------------------===//
2671 
2672 LogicalResult acc::DeclareOp::verify() {
2673  return checkDeclareOperands(*this, this->getDataClauseOperands());
2674 }
2675 
2676 //===----------------------------------------------------------------------===//
2677 // RoutineOp
2678 //===----------------------------------------------------------------------===//
2679 
2680 static unsigned getParallelismForDeviceType(acc::RoutineOp op,
2681  acc::DeviceType dtype) {
2682  unsigned parallelism = 0;
2683  parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
2684  parallelism += op.hasWorker(dtype) ? 1 : 0;
2685  parallelism += op.hasVector(dtype) ? 1 : 0;
2686  parallelism += op.hasSeq(dtype) ? 1 : 0;
2687  return parallelism;
2688 }
2689 
2690 LogicalResult acc::RoutineOp::verify() {
2691  unsigned baseParallelism =
2693 
2694  if (baseParallelism > 1)
2695  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2696  "be present at the same time";
2697 
2698  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
2699  ++dtypeInt) {
2700  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
2701  if (dtype == acc::DeviceType::None)
2702  continue;
2703  unsigned parallelism = getParallelismForDeviceType(*this, dtype);
2704 
2705  if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
2706  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2707  "be present at the same time";
2708  }
2709 
2710  return success();
2711 }
2712 
2713 static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
2714  mlir::ArrayAttr &deviceTypes) {
2715  llvm::SmallVector<mlir::Attribute> bindNameAttrs;
2716  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
2717 
2718  if (failed(parser.parseCommaSeparatedList([&]() {
2719  if (parser.parseAttribute(bindNameAttrs.emplace_back()))
2720  return failure();
2721  if (failed(parser.parseOptionalLSquare())) {
2722  deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2723  parser.getContext(), mlir::acc::DeviceType::None));
2724  } else {
2725  if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2726  parser.parseRSquare())
2727  return failure();
2728  }
2729  return success();
2730  })))
2731  return failure();
2732 
2733  bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
2734  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2735 
2736  return success();
2737 }
2738 
2740  std::optional<mlir::ArrayAttr> bindName,
2741  std::optional<mlir::ArrayAttr> deviceTypes) {
2742  llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
2743  [&](const auto &pair) {
2744  p << std::get<0>(pair);
2745  printSingleDeviceType(p, std::get<1>(pair));
2746  });
2747 }
2748 
2749 static ParseResult parseRoutineGangClause(OpAsmParser &parser,
2750  mlir::ArrayAttr &gang,
2751  mlir::ArrayAttr &gangDim,
2752  mlir::ArrayAttr &gangDimDeviceTypes) {
2753 
2754  llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
2755  gangDimDeviceTypeAttrs;
2756  bool needCommaBeforeOperands = false;
2757 
2758  // Gang keyword only
2759  if (failed(parser.parseOptionalLParen())) {
2760  gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2762  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2763  return success();
2764  }
2765 
2766  // Parse keyword only attributes
2767  if (succeeded(parser.parseOptionalLSquare())) {
2768  if (failed(parser.parseCommaSeparatedList([&]() {
2769  if (parser.parseAttribute(gangAttrs.emplace_back()))
2770  return failure();
2771  return success();
2772  })))
2773  return failure();
2774  if (parser.parseRSquare())
2775  return failure();
2776  needCommaBeforeOperands = true;
2777  }
2778 
2779  if (needCommaBeforeOperands && failed(parser.parseComma()))
2780  return failure();
2781 
2782  if (failed(parser.parseCommaSeparatedList([&]() {
2783  if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
2784  parser.parseColon() ||
2785  parser.parseAttribute(gangDimAttrs.emplace_back()))
2786  return failure();
2787  if (succeeded(parser.parseOptionalLSquare())) {
2788  if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
2789  parser.parseRSquare())
2790  return failure();
2791  } else {
2792  gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2793  parser.getContext(), mlir::acc::DeviceType::None));
2794  }
2795  return success();
2796  })))
2797  return failure();
2798 
2799  if (failed(parser.parseRParen()))
2800  return failure();
2801 
2802  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2803  gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
2804  gangDimDeviceTypes =
2805  ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
2806 
2807  return success();
2808 }
2809 
2811  std::optional<mlir::ArrayAttr> gang,
2812  std::optional<mlir::ArrayAttr> gangDim,
2813  std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
2814 
2815  if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
2816  gang->size() == 1) {
2817  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
2818  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
2819  return;
2820  }
2821 
2822  p << "(";
2823 
2824  printDeviceTypes(p, gang);
2825 
2826  if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
2827  p << ", ";
2828 
2829  if (hasDeviceTypeValues(gangDimDeviceTypes))
2830  llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
2831  [&](const auto &pair) {
2832  p << acc::RoutineOp::getGangDimKeyword() << ": ";
2833  p << std::get<0>(pair);
2834  printSingleDeviceType(p, std::get<1>(pair));
2835  });
2836 
2837  p << ")";
2838 }
2839 
2840 static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
2841  mlir::ArrayAttr &deviceTypes) {
2843  // Keyword only
2844  if (failed(parser.parseOptionalLParen())) {
2845  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2847  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
2848  return success();
2849  }
2850 
2851  // Parse device type attributes
2852  if (succeeded(parser.parseOptionalLSquare())) {
2853  if (failed(parser.parseCommaSeparatedList([&]() {
2854  if (parser.parseAttribute(attributes.emplace_back()))
2855  return failure();
2856  return success();
2857  })))
2858  return failure();
2859  if (parser.parseRSquare() || parser.parseRParen())
2860  return failure();
2861  }
2862  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
2863  return success();
2864 }
2865 
2866 static void
2868  std::optional<mlir::ArrayAttr> deviceTypes) {
2869 
2870  if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
2871  auto deviceTypeAttr =
2872  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
2873  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
2874  return;
2875  }
2876 
2877  if (!hasDeviceTypeValues(deviceTypes))
2878  return;
2879 
2880  p << "([";
2881  llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
2882  auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2883  p << dTypeAttr;
2884  });
2885  p << "])";
2886 }
2887 
2888 bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2889 
2890 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
2891  return hasDeviceType(getWorker(), deviceType);
2892 }
2893 
2894 bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
2895 
2896 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
2897  return hasDeviceType(getVector(), deviceType);
2898 }
2899 
2900 bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
2901 
2902 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
2903  return hasDeviceType(getSeq(), deviceType);
2904 }
2905 
2906 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
2907  return getBindNameValue(mlir::acc::DeviceType::None);
2908 }
2909 
2910 std::optional<llvm::StringRef>
2911 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
2912  if (!hasDeviceTypeValues(getBindNameDeviceType()))
2913  return std::nullopt;
2914  if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
2915  auto attr = (*getBindName())[*pos];
2916  auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
2917  return stringAttr.getValue();
2918  }
2919  return std::nullopt;
2920 }
2921 
2922 bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2923 
2924 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
2925  return hasDeviceType(getGang(), deviceType);
2926 }
2927 
2928 std::optional<int64_t> RoutineOp::getGangDimValue() {
2929  return getGangDimValue(mlir::acc::DeviceType::None);
2930 }
2931 
2932 std::optional<int64_t>
2933 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
2934  if (!hasDeviceTypeValues(getGangDimDeviceType()))
2935  return std::nullopt;
2936  if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
2937  auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
2938  return intAttr.getInt();
2939  }
2940  return std::nullopt;
2941 }
2942 
2943 //===----------------------------------------------------------------------===//
2944 // InitOp
2945 //===----------------------------------------------------------------------===//
2946 
2947 LogicalResult acc::InitOp::verify() {
2948  Operation *currOp = *this;
2949  while ((currOp = currOp->getParentOp()))
2950  if (isComputeOperation(currOp))
2951  return emitOpError("cannot be nested in a compute operation");
2952  return success();
2953 }
2954 
2955 //===----------------------------------------------------------------------===//
2956 // ShutdownOp
2957 //===----------------------------------------------------------------------===//
2958 
2959 LogicalResult acc::ShutdownOp::verify() {
2960  Operation *currOp = *this;
2961  while ((currOp = currOp->getParentOp()))
2962  if (isComputeOperation(currOp))
2963  return emitOpError("cannot be nested in a compute operation");
2964  return success();
2965 }
2966 
2967 //===----------------------------------------------------------------------===//
2968 // SetOp
2969 //===----------------------------------------------------------------------===//
2970 
2971 LogicalResult acc::SetOp::verify() {
2972  Operation *currOp = *this;
2973  while ((currOp = currOp->getParentOp()))
2974  if (isComputeOperation(currOp))
2975  return emitOpError("cannot be nested in a compute operation");
2976  if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
2977  return emitOpError("at least one default_async, device_num, or device_type "
2978  "operand must appear");
2979  return success();
2980 }
2981 
2982 //===----------------------------------------------------------------------===//
2983 // UpdateOp
2984 //===----------------------------------------------------------------------===//
2985 
2986 LogicalResult acc::UpdateOp::verify() {
2987  // At least one of host or device should have a value.
2988  if (getDataClauseOperands().empty())
2989  return emitError("at least one value must be present in dataOperands");
2990 
2991  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
2992  getAsyncOperandsDeviceTypeAttr(),
2993  "async")))
2994  return failure();
2995 
2997  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2998  getWaitOperandsDeviceTypeAttr(), "wait")))
2999  return failure();
3000 
3001  if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this)))
3002  return failure();
3003 
3004  for (mlir::Value operand : getDataClauseOperands())
3005  if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
3006  operand.getDefiningOp()))
3007  return emitError("expect data entry/exit operation or acc.getdeviceptr "
3008  "as defining op");
3009 
3010  return success();
3011 }
3012 
3013 unsigned UpdateOp::getNumDataOperands() {
3014  return getDataClauseOperands().size();
3015 }
3016 
3017 Value UpdateOp::getDataOperand(unsigned i) {
3018  unsigned numOptional = getAsyncOperands().size();
3019  numOptional += getIfCond() ? 1 : 0;
3020  return getOperand(getWaitOperands().size() + numOptional + i);
3021 }
3022 
3023 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
3024  MLIRContext *context) {
3025  results.add<RemoveConstantIfCondition<UpdateOp>>(context);
3026 }
3027 
3028 bool UpdateOp::hasAsyncOnly() {
3029  return hasAsyncOnly(mlir::acc::DeviceType::None);
3030 }
3031 
3032 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3033  return hasDeviceType(getAsync(), deviceType);
3034 }
3035 
3036 mlir::Value UpdateOp::getAsyncValue() {
3037  return getAsyncValue(mlir::acc::DeviceType::None);
3038 }
3039 
3040 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3042  return {};
3043 
3044  if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
3045  return getAsyncOperands()[*pos];
3046 
3047  return {};
3048 }
3049 
3050 bool UpdateOp::hasWaitOnly() {
3051  return hasWaitOnly(mlir::acc::DeviceType::None);
3052 }
3053 
3054 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3055  return hasDeviceType(getWaitOnly(), deviceType);
3056 }
3057 
3058 mlir::Operation::operand_range UpdateOp::getWaitValues() {
3059  return getWaitValues(mlir::acc::DeviceType::None);
3060 }
3061 
3063 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3065  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3066  getHasWaitDevnum(), deviceType);
3067 }
3068 
3069 mlir::Value UpdateOp::getWaitDevnum() {
3070  return getWaitDevnum(mlir::acc::DeviceType::None);
3071 }
3072 
3073 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3074  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3075  getWaitOperandsSegments(), getHasWaitDevnum(),
3076  deviceType);
3077 }
3078 
3079 //===----------------------------------------------------------------------===//
3080 // WaitOp
3081 //===----------------------------------------------------------------------===//
3082 
3083 LogicalResult acc::WaitOp::verify() {
3084  // The async attribute represent the async clause without value. Therefore the
3085  // attribute and operand cannot appear at the same time.
3086  if (getAsyncOperand() && getAsync())
3087  return emitError("async attribute cannot appear with asyncOperand");
3088 
3089  if (getWaitDevnum() && getWaitOperands().empty())
3090  return emitError("wait_devnum cannot appear without waitOperands");
3091 
3092  return success();
3093 }
3094 
3095 #define GET_OP_CLASSES
3096 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
3097 
3098 #define GET_ATTRDEF_CLASSES
3099 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
3100 
3101 #define GET_TYPEDEF_CLASSES
3102 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
3103 
3104 //===----------------------------------------------------------------------===//
3105 // acc dialect utilities
3106 //===----------------------------------------------------------------------===//
3107 
// Dispatches on the data clause op. Data entry ops (ACC_DATA_ENTRY_OPS),
// CopyoutOp and UpdateHostOp forward their getVarPtr(). Every other op takes
// the Default branch.
// NOTE(review): this rendered listing skips lines 3108-3109, 3111 and 3118
// (presumably the signature, the TypeSwitch result type and the Default return
// value). Restore them from the original source before editing.
3110  auto varPtr{llvm::TypeSwitch<mlir::Operation *,
3112  accDataClauseOp)
3113  .Case<ACC_DATA_ENTRY_OPS>(
3114  [&](auto entry) { return entry.getVarPtr(); })
3115  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3116  [&](auto exit) { return exit.getVarPtr(); })
3117  .Default([&](mlir::Operation *) {
3119  })};
3120  return varPtr;
3121 }
3122 
// Returns entry.getVar() for data entry ops (ACC_DATA_ENTRY_OPS) and a null
// mlir::Value for any other op. The local is named `varPtr` although it holds
// the `var` value.
// NOTE(review): this listing skips lines 3123 and 3125 (presumably the
// signature and the TypeSwitch head).
3124  auto varPtr{
3126  .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
3127  .Default([&](mlir::Operation *) { return mlir::Value(); })};
3128  return varPtr;
3129 }
3130 
// Returns getVarType() for data entry ops (ACC_DATA_ENTRY_OPS), CopyoutOp and
// UpdateHostOp, and a null mlir::Type for any other op.
// NOTE(review): this listing skips line 3131 (presumably the signature).
3132  auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
3133  .Case<ACC_DATA_ENTRY_OPS>(
3134  [&](auto entry) { return entry.getVarType(); })
3135  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3136  [&](auto exit) { return exit.getVarType(); })
3137  .Default([&](mlir::Operation *) { return mlir::Type(); })};
3138  return varType;
3139 }
3140 
// Returns getAccPtr() for data entry and data exit ops (ACC_DATA_ENTRY_OPS,
// ACC_DATA_EXIT_OPS). Any other op takes the Default branch.
// NOTE(review): this listing skips lines 3141-3142, 3144 and 3149
// (presumably the signature, the TypeSwitch result type and the Default
// return value).
3143  auto accPtr{llvm::TypeSwitch<mlir::Operation *,
3145  accDataClauseOp)
3146  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
3147  [&](auto dataClause) { return dataClause.getAccPtr(); })
3148  .Default([&](mlir::Operation *) {
3150  })};
3151  return accPtr;
3152 }
3153 
// Returns dataClause.getAccVar() for the matched data clause ops and a null
// mlir::Value for any other op. The local is named `accPtr` although it holds
// the accVar.
// NOTE(review): this listing skips lines 3154 and 3156 (presumably the
// signature and the `.Case<...>` line naming the matched ops).
3155  auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
3157  [&](auto dataClause) { return dataClause.getAccVar(); })
3158  .Default([&](mlir::Operation *) { return mlir::Value(); })};
3159  return accPtr;
3160 }
3161 
// Returns getVarPtrPtr() for data entry ops (ACC_DATA_ENTRY_OPS) and a null
// mlir::Value for any other op.
// NOTE(review): this listing skips lines 3162 and 3164 (presumably the
// signature and the TypeSwitch head).
3163  auto varPtrPtr{
3165  .Case<ACC_DATA_ENTRY_OPS>(
3166  [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
3167  .Default([&](mlir::Operation *) { return mlir::Value(); })};
3168  return varPtrPtr;
3169 }
3170 
// For data entry and data exit ops (ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS),
// copies the op's bounds (getBounds() begin..end) into the result. Any other
// op takes the Default branch.
// NOTE(review): this listing skips lines 3171-3174, 3177 and 3181
// (presumably the signature, the result declaration and the container
// construction in both branches).
3175  accDataClauseOp)
3176  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3178  dataClause.getBounds().begin(), dataClause.getBounds().end());
3179  })
3180  .Default([&](mlir::Operation *) {
3182  })};
3183  return bounds;
3184 }
3185 
// For data entry and data exit ops (ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS),
// copies the op's async operands (getAsyncOperands() begin..end) into the
// result. Any other op takes the Default branch.
// NOTE(review): this listing skips lines 3186-3188, 3191 and 3196
// (presumably the signature, the TypeSwitch head and the container
// construction in both branches).
3189  accDataClauseOp)
3190  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3192  dataClause.getAsyncOperands().begin(),
3193  dataClause.getAsyncOperands().end());
3194  })
3195  .Default([&](mlir::Operation *) {
3197  });
3198 }
3199 
// Returns getAsyncOperandsDeviceTypeAttr() for data entry and data exit ops
// (ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS) and a null mlir::ArrayAttr for any
// other op.
// NOTE(review): this listing skips lines 3201-3202 (presumably the function
// name/parameter and the TypeSwitch head).
3200 mlir::ArrayAttr
3203  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3204  return dataClause.getAsyncOperandsDeviceTypeAttr();
3205  })
3206  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
3207 }
3208 
/// Returns the op's getAsyncOnlyAttr() for the matched data clause ops, and a
/// null mlir::ArrayAttr for any other op.
// NOTE(review): this listing skips lines 3210-3211 (presumably the TypeSwitch
// head and the `.Case<...>` line naming the matched ops).
3209 mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
3212  [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
3213  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
3214 }
3215 
/// Returns entry.getName() for data entry ops (ACC_DATA_ENTRY_OPS) and an
/// empty optional for any other op.
// NOTE(review): this listing skips line 3218 (presumably the TypeSwitch head).
3216 std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
3217  auto name{
3219  .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
3220  .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
3221  return {};
3222  })};
3223  return name;
3224 }
3225 
/// Returns entry.getDataClause() for data entry ops (ACC_DATA_ENTRY_OPS) and
/// std::nullopt for any other op.
// NOTE(review): this listing skips lines 3227 and 3229 (presumably the
// function name/parameter and the TypeSwitch head).
3226 std::optional<mlir::acc::DataClause>
3228  auto dataClause{
3230  accDataEntryOp)
3231  .Case<ACC_DATA_ENTRY_OPS>(
3232  [&](auto entry) { return entry.getDataClause(); })
3233  .Default([&](mlir::Operation *) { return std::nullopt; })};
3234  return dataClause;
3235 }
3236 
// Returns entry.getImplicit() for data entry ops (ACC_DATA_ENTRY_OPS) and
// false for any other op.
// NOTE(review): this listing skips line 3237 (presumably the signature).
3238  auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
3239  .Case<ACC_DATA_ENTRY_OPS>(
3240  [&](auto entry) { return entry.getImplicit(); })
3241  .Default([&](mlir::Operation *) { return false; })};
3242  return implicit;
3243 }
3244 
// Returns entry.getDataClauseOperands() for the matched ops and an empty
// mlir::ValueRange for any other op.
// NOTE(review): this listing skips lines 3245 and 3247-3248 (presumably the
// signature, the TypeSwitch head and the `.Case<...>` line).
3246  auto dataOperands{
3249  [&](auto entry) { return entry.getDataClauseOperands(); })
3250  .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
3251  return dataOperands;
3252 }
3253 
// Returns entry.getDataClauseOperandsMutable() for the matched ops. The
// Default branch returns nullptr for any other op.
// NOTE(review): this listing skips lines 3254-3255 and 3257-3258 (presumably
// the signature, the TypeSwitch head and the `.Case<...>` line).
3256  auto dataOperands{
3259  [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
3260  .Default([&](mlir::Operation *) { return nullptr; })};
3261  return dataOperands;
3262 }
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:112
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2236
@ None
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition: OpenACC.cpp:2810
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition: OpenACC.cpp:678
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition: OpenACC.cpp:2029
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition: OpenACC.cpp:980
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition: OpenACC.cpp:2043
static bool isComputeOperation(Operation *op)
Definition: OpenACC.cpp:692
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition: OpenACC.cpp:1387
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition: OpenACC.cpp:306
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:2713
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition: OpenACC.cpp:275
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition: OpenACC.cpp:1398
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition: OpenACC.cpp:1303
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition: OpenACC.cpp:111
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:2867
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition: OpenACC.cpp:1838
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition: OpenACC.cpp:1552
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition: OpenACC.cpp:2598
static LogicalResult checkVarAndAccVar(Op op)
Definition: OpenACC.cpp:253
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:131
static LogicalResult checkVarAndVarType(Op op)
Definition: OpenACC.cpp:228
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition: OpenACC.cpp:2314
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition: OpenACC.cpp:909
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:1432
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:1062
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition: OpenACC.cpp:284
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:155
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:1173
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition: OpenACC.cpp:260
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition: OpenACC.cpp:2345
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:2840
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition: OpenACC.cpp:2749
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1286
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:1459
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindName, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:2739
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:1240
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
Definition: OpenACC.cpp:893
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:187
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition: OpenACC.cpp:347
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition: OpenACC.cpp:1857
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition: OpenACC.cpp:768
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition: OpenACC.cpp:1217
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:142
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
Definition: OpenACC.cpp:924
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition: OpenACC.cpp:1532
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:117
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition: OpenACC.cpp:1984
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:171
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition: OpenACC.cpp:1470
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition: OpenACC.cpp:318
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition: OpenACC.cpp:207
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition: OpenACC.cpp:990
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition: OpenACC.cpp:2680
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1223
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition: OpenACC.cpp:1578
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
Definition: OpenACC.cpp:873
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition: OpenACC.h:67
#define ACC_DATA_ENTRY_OPS
Definition: OpenACC.h:44
#define ACC_DATA_EXIT_OPS
Definition: OpenACC.h:52
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:187
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
Definition: Builders.h:205
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:826
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:832
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:128
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:753
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:815
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:598
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:504
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition: OpenACC.cpp:3154
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition: OpenACC.cpp:3123
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition: OpenACC.cpp:3142
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition: OpenACC.cpp:3227
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition: OpenACC.cpp:3255
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition: OpenACC.cpp:3172
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition: OpenACC.cpp:3245
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition: OpenACC.cpp:3216
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition: OpenACC.cpp:3237
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition: OpenACC.cpp:3187
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition: OpenACC.cpp:3162
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:3209
static constexpr StringLiteral getDeclareAttrName()
Used to obtain the attribute name for declare.
Definition: OpenACC.h:167
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
Definition: OpenACC.cpp:3131
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition: OpenACC.cpp:3109
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:3201
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:474
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:424
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:318
This represents an operation in an abstracted form, suitable for use with the builder APIs.