MLIR  20.0.0git
OpenACC.cpp
Go to the documentation of this file.
1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
13 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
20 #include "llvm/ADT/SmallSet.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/LogicalResult.h"
23 
24 using namespace mlir;
25 using namespace acc;
26 
27 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
28 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
29 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
30 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
31 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
32 
33 namespace {
34 struct MemRefPointerLikeModel
35  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
36  MemRefType> {
37  Type getElementType(Type pointer) const {
38  return llvm::cast<MemRefType>(pointer).getElementType();
39  }
40 };
41 
42 struct LLVMPointerPointerLikeModel
43  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
44  LLVM::LLVMPointerType> {
45  Type getElementType(Type pointer) const { return Type(); }
46 };
47 } // namespace
48 
49 //===----------------------------------------------------------------------===//
50 // OpenACC operations
51 //===----------------------------------------------------------------------===//
52 
/// Dialect initialization hook: registers the operations, attributes and types
/// whose lists are generated by TableGen (the GET_*_LIST includes below), then
/// attaches the PointerLikeType external models MemRefPointerLikeModel and
/// LLVMPointerPointerLikeModel to the builtin memref and LLVM pointer types.
void OpenACCDialect::initialize() {
  // Operation classes from the generated op list.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
      >();
  // Attribute classes from the generated attribute-definition list.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
      >();
  // Type classes from the generated type-definition list.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
      >();

  // By attaching interfaces here, we make the OpenACC dialect dependent on
  // the other dialects. This is probably better than having dialects like LLVM
  // and memref be dependent on OpenACC.
  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
      *getContext());
}
74 
75 //===----------------------------------------------------------------------===//
76 // device_type support helpers
77 //===----------------------------------------------------------------------===//
78 
79 static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
80  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
81  return true;
82  return false;
83 }
84 
85 static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
86  mlir::acc::DeviceType deviceType) {
87  if (!hasDeviceTypeValues(arrayAttr))
88  return false;
89 
90  for (auto attr : *arrayAttr) {
91  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
92  if (deviceTypeAttr.getValue() == deviceType)
93  return true;
94  }
95 
96  return false;
97 }
98 
100  std::optional<mlir::ArrayAttr> deviceTypes) {
101  if (!hasDeviceTypeValues(deviceTypes))
102  return;
103 
104  p << "[";
105  llvm::interleaveComma(*deviceTypes, p,
106  [&](mlir::Attribute attr) { p << attr; });
107  p << "]";
108 }
109 
110 static std::optional<unsigned> findSegment(ArrayAttr segments,
111  mlir::acc::DeviceType deviceType) {
112  unsigned segmentIdx = 0;
113  for (auto attr : segments) {
114  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
115  if (deviceTypeAttr.getValue() == deviceType)
116  return std::make_optional(segmentIdx);
117  ++segmentIdx;
118  }
119  return std::nullopt;
120 }
121 
123 getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
125  std::optional<llvm::ArrayRef<int32_t>> segments,
126  mlir::acc::DeviceType deviceType) {
127  if (!arrayAttr)
128  return range.take_front(0);
129  if (auto pos = findSegment(*arrayAttr, deviceType)) {
130  int32_t nbOperandsBefore = 0;
131  for (unsigned i = 0; i < *pos; ++i)
132  nbOperandsBefore += (*segments)[i];
133  return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
134  }
135  return range.take_front(0);
136 }
137 
138 static mlir::Value
139 getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
141  std::optional<llvm::ArrayRef<int32_t>> segments,
142  std::optional<mlir::ArrayAttr> hasWaitDevnum,
143  mlir::acc::DeviceType deviceType) {
144  if (!hasDeviceTypeValues(deviceTypeAttr))
145  return {};
146  if (auto pos = findSegment(*deviceTypeAttr, deviceType))
147  if (hasWaitDevnum->getValue()[*pos])
148  return getValuesFromSegments(deviceTypeAttr, operands, segments,
149  deviceType)
150  .front();
151  return {};
152 }
153 
155 getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
157  std::optional<llvm::ArrayRef<int32_t>> segments,
158  std::optional<mlir::ArrayAttr> hasWaitDevnum,
159  mlir::acc::DeviceType deviceType) {
160  auto range =
161  getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
162  if (range.empty())
163  return range;
164  if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
165  if (hasWaitDevnum && *hasWaitDevnum) {
166  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
167  if (boolAttr.getValue())
168  return range.drop_front(1); // first value is devnum
169  }
170  }
171  return range;
172 }
173 
174 template <typename Op>
175 static LogicalResult checkWaitAndAsyncConflict(Op op) {
176  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
177  ++dtypeInt) {
178  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
179 
180  // The async attribute represent the async clause without value. Therefore
181  // the attribute and operand cannot appear at the same time.
182  if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
183  op.hasAsyncOnly(dtype))
184  return op.emitError("async attribute cannot appear with asyncOperand");
185 
186  // The wait attribute represent the wait clause without values. Therefore
187  // the attribute and operands cannot appear at the same time.
188  if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
189  op.hasWaitOnly(dtype))
190  return op.emitError("wait attribute cannot appear with waitOperands");
191  }
192  return success();
193 }
194 
195 static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
196  mlir::Type &varPtrType,
197  mlir::TypeAttr &varTypeAttr) {
198  if (failed(parser.parseType(varPtrType)))
199  return failure();
200  if (failed(parser.parseRParen()))
201  return failure();
202 
203  if (succeeded(parser.parseOptionalKeyword("varType"))) {
204  if (failed(parser.parseLParen()))
205  return failure();
206  mlir::Type varType;
207  if (failed(parser.parseType(varType)))
208  return failure();
209  varTypeAttr = mlir::TypeAttr::get(varType);
210  if (failed(parser.parseRParen()))
211  return failure();
212  } else {
213  // Set `varType` from the element type of the type of `varPtr`.
214  varTypeAttr = mlir::TypeAttr::get(
215  mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
216  }
217 
218  return success();
219 }
220 
222  mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
223  p.printType(varPtrType);
224  p << ")";
225 
226  // Print the `varType` only if it differs from the element type of
227  // `varPtr`'s type.
228  mlir::Type varType = varTypeAttr.getValue();
229  if (mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType() !=
230  varType) {
231  p << " varType(";
232  p.printType(varType);
233  p << ")";
234  }
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // DataBoundsOp
239 //===----------------------------------------------------------------------===//
240 LogicalResult acc::DataBoundsOp::verify() {
241  auto extent = getExtent();
242  auto upperbound = getUpperbound();
243  if (!extent && !upperbound)
244  return emitError("expected extent or upperbound.");
245  return success();
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // PrivateOp
250 //===----------------------------------------------------------------------===//
251 LogicalResult acc::PrivateOp::verify() {
252  if (getDataClause() != acc::DataClause::acc_private)
253  return emitError(
254  "data clause associated with private operation must match its intent");
255  return success();
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // FirstprivateOp
260 //===----------------------------------------------------------------------===//
261 LogicalResult acc::FirstprivateOp::verify() {
262  if (getDataClause() != acc::DataClause::acc_firstprivate)
263  return emitError("data clause associated with firstprivate operation must "
264  "match its intent");
265  return success();
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // ReductionOp
270 //===----------------------------------------------------------------------===//
271 LogicalResult acc::ReductionOp::verify() {
272  if (getDataClause() != acc::DataClause::acc_reduction)
273  return emitError("data clause associated with reduction operation must "
274  "match its intent");
275  return success();
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // DevicePtrOp
280 //===----------------------------------------------------------------------===//
281 LogicalResult acc::DevicePtrOp::verify() {
282  if (getDataClause() != acc::DataClause::acc_deviceptr)
283  return emitError("data clause associated with deviceptr operation must "
284  "match its intent");
285  return success();
286 }
287 
288 //===----------------------------------------------------------------------===//
289 // PresentOp
290 //===----------------------------------------------------------------------===//
291 LogicalResult acc::PresentOp::verify() {
292  if (getDataClause() != acc::DataClause::acc_present)
293  return emitError(
294  "data clause associated with present operation must match its intent");
295  return success();
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // CopyinOp
300 //===----------------------------------------------------------------------===//
301 LogicalResult acc::CopyinOp::verify() {
302  // Test for all clauses this operation can be decomposed from:
303  if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
304  getDataClause() != acc::DataClause::acc_copyin_readonly &&
305  getDataClause() != acc::DataClause::acc_copy &&
306  getDataClause() != acc::DataClause::acc_reduction)
307  return emitError(
308  "data clause associated with copyin operation must match its intent"
309  " or specify original clause this operation was decomposed from");
310  return success();
311 }
312 
313 bool acc::CopyinOp::isCopyinReadonly() {
314  return getDataClause() == acc::DataClause::acc_copyin_readonly;
315 }
316 
317 //===----------------------------------------------------------------------===//
318 // CreateOp
319 //===----------------------------------------------------------------------===//
320 LogicalResult acc::CreateOp::verify() {
321  // Test for all clauses this operation can be decomposed from:
322  if (getDataClause() != acc::DataClause::acc_create &&
323  getDataClause() != acc::DataClause::acc_create_zero &&
324  getDataClause() != acc::DataClause::acc_copyout &&
325  getDataClause() != acc::DataClause::acc_copyout_zero)
326  return emitError(
327  "data clause associated with create operation must match its intent"
328  " or specify original clause this operation was decomposed from");
329  return success();
330 }
331 
332 bool acc::CreateOp::isCreateZero() {
333  // The zero modifier is encoded in the data clause.
334  return getDataClause() == acc::DataClause::acc_create_zero ||
335  getDataClause() == acc::DataClause::acc_copyout_zero;
336 }
337 
338 //===----------------------------------------------------------------------===//
339 // NoCreateOp
340 //===----------------------------------------------------------------------===//
341 LogicalResult acc::NoCreateOp::verify() {
342  if (getDataClause() != acc::DataClause::acc_no_create)
343  return emitError("data clause associated with no_create operation must "
344  "match its intent");
345  return success();
346 }
347 
348 //===----------------------------------------------------------------------===//
349 // AttachOp
350 //===----------------------------------------------------------------------===//
351 LogicalResult acc::AttachOp::verify() {
352  if (getDataClause() != acc::DataClause::acc_attach)
353  return emitError(
354  "data clause associated with attach operation must match its intent");
355  return success();
356 }
357 
358 //===----------------------------------------------------------------------===//
359 // DeclareDeviceResidentOp
360 //===----------------------------------------------------------------------===//
361 
362 LogicalResult acc::DeclareDeviceResidentOp::verify() {
363  if (getDataClause() != acc::DataClause::acc_declare_device_resident)
364  return emitError("data clause associated with device_resident operation "
365  "must match its intent");
366  return success();
367 }
368 
369 //===----------------------------------------------------------------------===//
370 // DeclareLinkOp
371 //===----------------------------------------------------------------------===//
372 
373 LogicalResult acc::DeclareLinkOp::verify() {
374  if (getDataClause() != acc::DataClause::acc_declare_link)
375  return emitError(
376  "data clause associated with link operation must match its intent");
377  return success();
378 }
379 
380 //===----------------------------------------------------------------------===//
381 // CopyoutOp
382 //===----------------------------------------------------------------------===//
383 LogicalResult acc::CopyoutOp::verify() {
384  // Test for all clauses this operation can be decomposed from:
385  if (getDataClause() != acc::DataClause::acc_copyout &&
386  getDataClause() != acc::DataClause::acc_copyout_zero &&
387  getDataClause() != acc::DataClause::acc_copy &&
388  getDataClause() != acc::DataClause::acc_reduction)
389  return emitError(
390  "data clause associated with copyout operation must match its intent"
391  " or specify original clause this operation was decomposed from");
392  if (!getVarPtr() || !getAccPtr())
393  return emitError("must have both host and device pointers");
394  return success();
395 }
396 
397 bool acc::CopyoutOp::isCopyoutZero() {
398  return getDataClause() == acc::DataClause::acc_copyout_zero;
399 }
400 
401 //===----------------------------------------------------------------------===//
402 // DeleteOp
403 //===----------------------------------------------------------------------===//
404 LogicalResult acc::DeleteOp::verify() {
405  // Test for all clauses this operation can be decomposed from:
406  if (getDataClause() != acc::DataClause::acc_delete &&
407  getDataClause() != acc::DataClause::acc_create &&
408  getDataClause() != acc::DataClause::acc_create_zero &&
409  getDataClause() != acc::DataClause::acc_copyin &&
410  getDataClause() != acc::DataClause::acc_copyin_readonly &&
411  getDataClause() != acc::DataClause::acc_present &&
412  getDataClause() != acc::DataClause::acc_declare_device_resident &&
413  getDataClause() != acc::DataClause::acc_declare_link)
414  return emitError(
415  "data clause associated with delete operation must match its intent"
416  " or specify original clause this operation was decomposed from");
417  if (!getAccPtr())
418  return emitError("must have device pointer");
419  return success();
420 }
421 
422 //===----------------------------------------------------------------------===//
423 // DetachOp
424 //===----------------------------------------------------------------------===//
425 LogicalResult acc::DetachOp::verify() {
426  // Test for all clauses this operation can be decomposed from:
427  if (getDataClause() != acc::DataClause::acc_detach &&
428  getDataClause() != acc::DataClause::acc_attach)
429  return emitError(
430  "data clause associated with detach operation must match its intent"
431  " or specify original clause this operation was decomposed from");
432  if (!getAccPtr())
433  return emitError("must have device pointer");
434  return success();
435 }
436 
437 //===----------------------------------------------------------------------===//
438 // HostOp
439 //===----------------------------------------------------------------------===//
440 LogicalResult acc::UpdateHostOp::verify() {
441  // Test for all clauses this operation can be decomposed from:
442  if (getDataClause() != acc::DataClause::acc_update_host &&
443  getDataClause() != acc::DataClause::acc_update_self)
444  return emitError(
445  "data clause associated with host operation must match its intent"
446  " or specify original clause this operation was decomposed from");
447  if (!getVarPtr() || !getAccPtr())
448  return emitError("must have both host and device pointers");
449  return success();
450 }
451 
452 //===----------------------------------------------------------------------===//
453 // DeviceOp
454 //===----------------------------------------------------------------------===//
455 LogicalResult acc::UpdateDeviceOp::verify() {
456  // Test for all clauses this operation can be decomposed from:
457  if (getDataClause() != acc::DataClause::acc_update_device)
458  return emitError(
459  "data clause associated with device operation must match its intent"
460  " or specify original clause this operation was decomposed from");
461  return success();
462 }
463 
464 //===----------------------------------------------------------------------===//
465 // UseDeviceOp
466 //===----------------------------------------------------------------------===//
467 LogicalResult acc::UseDeviceOp::verify() {
468  // Test for all clauses this operation can be decomposed from:
469  if (getDataClause() != acc::DataClause::acc_use_device)
470  return emitError(
471  "data clause associated with use_device operation must match its intent"
472  " or specify original clause this operation was decomposed from");
473  return success();
474 }
475 
476 //===----------------------------------------------------------------------===//
477 // CacheOp
478 //===----------------------------------------------------------------------===//
479 LogicalResult acc::CacheOp::verify() {
480  // Test for all clauses this operation can be decomposed from:
481  if (getDataClause() != acc::DataClause::acc_cache &&
482  getDataClause() != acc::DataClause::acc_cache_readonly)
483  return emitError(
484  "data clause associated with cache operation must match its intent"
485  " or specify original clause this operation was decomposed from");
486  return success();
487 }
488 
489 template <typename StructureOp>
490 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
491  unsigned nRegions = 1) {
492 
493  SmallVector<Region *, 2> regions;
494  for (unsigned i = 0; i < nRegions; ++i)
495  regions.push_back(state.addRegion());
496 
497  for (Region *region : regions)
498  if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
499  return failure();
500 
501  return success();
502 }
503 
504 static bool isComputeOperation(Operation *op) {
505  return isa<acc::ParallelOp, acc::LoopOp>(op);
506 }
507 
508 namespace {
509 /// Pattern to remove operation without region that have constant false `ifCond`
510 /// and remove the condition from the operation if the `ifCond` is a true
511 /// constant.
512 template <typename OpTy>
513 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
515 
516  LogicalResult matchAndRewrite(OpTy op,
517  PatternRewriter &rewriter) const override {
518  // Early return if there is no condition.
519  Value ifCond = op.getIfCond();
520  if (!ifCond)
521  return failure();
522 
523  IntegerAttr constAttr;
524  if (!matchPattern(ifCond, m_Constant(&constAttr)))
525  return failure();
526  if (constAttr.getInt())
527  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
528  else
529  rewriter.eraseOp(op);
530 
531  return success();
532  }
533 };
534 
535 /// Replaces the given op with the contents of the given single-block region,
536 /// using the operands of the block terminator to replace operation results.
537 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
538  Region &region, ValueRange blockArgs = {}) {
539  assert(llvm::hasSingleElement(region) && "expected single-region block");
540  Block *block = &region.front();
541  Operation *terminator = block->getTerminator();
542  ValueRange results = terminator->getOperands();
543  rewriter.inlineBlockBefore(block, op, blockArgs);
544  rewriter.replaceOp(op, results);
545  rewriter.eraseOp(terminator);
546 }
547 
548 /// Pattern to remove operation with region that have constant false `ifCond`
549 /// and remove the condition from the operation if the `ifCond` is constant
550 /// true.
551 template <typename OpTy>
552 struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
554 
555  LogicalResult matchAndRewrite(OpTy op,
556  PatternRewriter &rewriter) const override {
557  // Early return if there is no condition.
558  Value ifCond = op.getIfCond();
559  if (!ifCond)
560  return failure();
561 
562  IntegerAttr constAttr;
563  if (!matchPattern(ifCond, m_Constant(&constAttr)))
564  return failure();
565  if (constAttr.getInt())
566  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
567  else
568  replaceOpWithRegion(rewriter, op, op.getRegion());
569 
570  return success();
571  }
572 };
573 
574 } // namespace
575 
576 //===----------------------------------------------------------------------===//
577 // PrivateRecipeOp
578 //===----------------------------------------------------------------------===//
579 
580 static LogicalResult verifyInitLikeSingleArgRegion(
581  Operation *op, Region &region, StringRef regionType, StringRef regionName,
582  Type type, bool verifyYield, bool optional = false) {
583  if (optional && region.empty())
584  return success();
585 
586  if (region.empty())
587  return op->emitOpError() << "expects non-empty " << regionName << " region";
588  Block &firstBlock = region.front();
589  if (firstBlock.getNumArguments() < 1 ||
590  firstBlock.getArgument(0).getType() != type)
591  return op->emitOpError() << "expects " << regionName
592  << " region first "
593  "argument of the "
594  << regionType << " type";
595 
596  if (verifyYield) {
597  for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
598  if (yieldOp.getOperands().size() != 1 ||
599  yieldOp.getOperands().getTypes()[0] != type)
600  return op->emitOpError() << "expects " << regionName
601  << " region to "
602  "yield a value of the "
603  << regionType << " type";
604  }
605  }
606  return success();
607 }
608 
609 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
610  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
611  "privatization", "init", getType(),
612  /*verifyYield=*/false)))
613  return failure();
615  *this, getDestroyRegion(), "privatization", "destroy", getType(),
616  /*verifyYield=*/false, /*optional=*/true)))
617  return failure();
618  return success();
619 }
620 
621 //===----------------------------------------------------------------------===//
622 // FirstprivateRecipeOp
623 //===----------------------------------------------------------------------===//
624 
625 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
626  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
627  "privatization", "init", getType(),
628  /*verifyYield=*/false)))
629  return failure();
630 
631  if (getCopyRegion().empty())
632  return emitOpError() << "expects non-empty copy region";
633 
634  Block &firstBlock = getCopyRegion().front();
635  if (firstBlock.getNumArguments() < 2 ||
636  firstBlock.getArgument(0).getType() != getType())
637  return emitOpError() << "expects copy region with two arguments of the "
638  "privatization type";
639 
640  if (getDestroyRegion().empty())
641  return success();
642 
643  if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
644  "privatization", "destroy",
645  getType(), /*verifyYield=*/false)))
646  return failure();
647 
648  return success();
649 }
650 
651 //===----------------------------------------------------------------------===//
652 // ReductionRecipeOp
653 //===----------------------------------------------------------------------===//
654 
655 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
656  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
657  "init", getType(),
658  /*verifyYield=*/false)))
659  return failure();
660 
661  if (getCombinerRegion().empty())
662  return emitOpError() << "expects non-empty combiner region";
663 
664  Block &reductionBlock = getCombinerRegion().front();
665  if (reductionBlock.getNumArguments() < 2 ||
666  reductionBlock.getArgument(0).getType() != getType() ||
667  reductionBlock.getArgument(1).getType() != getType())
668  return emitOpError() << "expects combiner region with the first two "
669  << "arguments of the reduction type";
670 
671  for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
672  if (yieldOp.getOperands().size() != 1 ||
673  yieldOp.getOperands().getTypes()[0] != getType())
674  return emitOpError() << "expects combiner region to yield a value "
675  "of the reduction type";
676  }
677 
678  return success();
679 }
680 
681 //===----------------------------------------------------------------------===//
682 // Custom parser and printer verifier for private clause
683 //===----------------------------------------------------------------------===//
684 
685 static ParseResult parseSymOperandList(
686  mlir::OpAsmParser &parser,
688  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
690  if (failed(parser.parseCommaSeparatedList([&]() {
691  if (parser.parseAttribute(attributes.emplace_back()) ||
692  parser.parseArrow() ||
693  parser.parseOperand(operands.emplace_back()) ||
694  parser.parseColonType(types.emplace_back()))
695  return failure();
696  return success();
697  })))
698  return failure();
699  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
700  attributes.end());
701  symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
702  return success();
703 }
704 
706  mlir::OperandRange operands,
707  mlir::TypeRange types,
708  std::optional<mlir::ArrayAttr> attributes) {
709  llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
710  p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
711  << std::get<1>(it).getType();
712  });
713 }
714 
715 //===----------------------------------------------------------------------===//
716 // ParallelOp
717 //===----------------------------------------------------------------------===//
718 
719 /// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
720 template <typename Op>
721 static LogicalResult checkDataOperands(Op op,
722  const mlir::ValueRange &operands) {
723  for (mlir::Value operand : operands)
724  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
725  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
726  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
727  operand.getDefiningOp()))
728  return op.emitError(
729  "expect data entry/exit operation or acc.getdeviceptr "
730  "as defining op");
731  return success();
732 }
733 
734 template <typename Op>
735 static LogicalResult
736 checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
737  mlir::OperandRange operands, llvm::StringRef operandName,
738  llvm::StringRef symbolName, bool checkOperandType = true) {
739  if (!operands.empty()) {
740  if (!attributes || attributes->size() != operands.size())
741  return op->emitOpError()
742  << "expected as many " << symbolName << " symbol reference as "
743  << operandName << " operands";
744  } else {
745  if (attributes)
746  return op->emitOpError()
747  << "unexpected " << symbolName << " symbol reference";
748  return success();
749  }
750 
752  for (auto args : llvm::zip(operands, *attributes)) {
753  mlir::Value operand = std::get<0>(args);
754 
755  if (!set.insert(operand).second)
756  return op->emitOpError()
757  << operandName << " operand appears more than once";
758 
759  mlir::Type varType = operand.getType();
760  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
761  auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
762  if (!decl)
763  return op->emitOpError()
764  << "expected symbol reference " << symbolRef << " to point to a "
765  << operandName << " declaration";
766 
767  if (checkOperandType && decl.getType() && decl.getType() != varType)
768  return op->emitOpError() << "expected " << operandName << " (" << varType
769  << ") to be the same type as " << operandName
770  << " declaration (" << decl.getType() << ")";
771  }
772 
773  return success();
774 }
775 
776 unsigned ParallelOp::getNumDataOperands() {
777  return getReductionOperands().size() + getPrivateOperands().size() +
778  getFirstprivateOperands().size() + getDataClauseOperands().size();
779 }
780 
781 Value ParallelOp::getDataOperand(unsigned i) {
782  unsigned numOptional = getAsyncOperands().size();
783  numOptional += getNumGangs().size();
784  numOptional += getNumWorkers().size();
785  numOptional += getVectorLength().size();
786  numOptional += getIfCond() ? 1 : 0;
787  numOptional += getSelfCond() ? 1 : 0;
788  return getOperand(getWaitOperands().size() + numOptional + i);
789 }
790 
791 template <typename Op>
792 static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
793  ArrayAttr deviceTypes,
794  llvm::StringRef keyword) {
795  if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
796  return op.emitOpError() << keyword << " operands count must match "
797  << keyword << " device_type count";
798  return success();
799 }
800 
801 template <typename Op>
803  Op op, OperandRange operands, DenseI32ArrayAttr segments,
804  ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
805  std::size_t numOperandsInSegments = 0;
806  std::size_t nbOfSegments = 0;
807 
808  if (segments) {
809  for (auto segCount : segments.asArrayRef()) {
810  if (maxInSegment != 0 && segCount > maxInSegment)
811  return op.emitOpError() << keyword << " expects a maximum of "
812  << maxInSegment << " values per segment";
813  numOperandsInSegments += segCount;
814  ++nbOfSegments;
815  }
816  }
817 
818  if ((numOperandsInSegments != operands.size()) ||
819  (!deviceTypes && !operands.empty()))
820  return op.emitOpError()
821  << keyword << " operand count does not match count in segments";
822  if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
823  return op.emitOpError()
824  << keyword << " segment count does not match device_type count";
825  return success();
826 }
827 
828 LogicalResult acc::ParallelOp::verify() {
829  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
830  *this, getPrivatizations(), getPrivateOperands(), "private",
831  "privatizations", /*checkOperandType=*/false)))
832  return failure();
833  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
834  *this, getFirstprivatizations(), getFirstprivateOperands(),
835  "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
836  return failure();
837  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
838  *this, getReductionRecipes(), getReductionOperands(), "reduction",
839  "reductions", false)))
840  return failure();
841 
843  *this, getNumGangs(), getNumGangsSegmentsAttr(),
844  getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
845  return failure();
846 
848  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
849  getWaitOperandsDeviceTypeAttr(), "wait")))
850  return failure();
851 
852  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
853  getNumWorkersDeviceTypeAttr(),
854  "num_workers")))
855  return failure();
856 
857  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
858  getVectorLengthDeviceTypeAttr(),
859  "vector_length")))
860  return failure();
861 
862  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
863  getAsyncOperandsDeviceTypeAttr(),
864  "async")))
865  return failure();
866 
867  if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
868  return failure();
869 
870  return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
871 }
872 
873 static mlir::Value
874 getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
876  mlir::acc::DeviceType deviceType) {
877  if (!arrayAttr)
878  return {};
879  if (auto pos = findSegment(*arrayAttr, deviceType))
880  return range[*pos];
881  return {};
882 }
883 
884 bool acc::ParallelOp::hasAsyncOnly() {
885  return hasAsyncOnly(mlir::acc::DeviceType::None);
886 }
887 
888 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
889  return hasDeviceType(getAsyncOnly(), deviceType);
890 }
891 
892 mlir::Value acc::ParallelOp::getAsyncValue() {
893  return getAsyncValue(mlir::acc::DeviceType::None);
894 }
895 
896 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
898  getAsyncOperands(), deviceType);
899 }
900 
901 mlir::Value acc::ParallelOp::getNumWorkersValue() {
902  return getNumWorkersValue(mlir::acc::DeviceType::None);
903 }
904 
906 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
907  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
908  deviceType);
909 }
910 
911 mlir::Value acc::ParallelOp::getVectorLengthValue() {
912  return getVectorLengthValue(mlir::acc::DeviceType::None);
913 }
914 
916 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
917  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
918  getVectorLength(), deviceType);
919 }
920 
921 mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
922  return getNumGangsValues(mlir::acc::DeviceType::None);
923 }
924 
926 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
927  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
928  getNumGangsSegments(), deviceType);
929 }
930 
931 bool acc::ParallelOp::hasWaitOnly() {
932  return hasWaitOnly(mlir::acc::DeviceType::None);
933 }
934 
935 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
936  return hasDeviceType(getWaitOnly(), deviceType);
937 }
938 
939 mlir::Operation::operand_range ParallelOp::getWaitValues() {
940  return getWaitValues(mlir::acc::DeviceType::None);
941 }
942 
944 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
946  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
947  getHasWaitDevnum(), deviceType);
948 }
949 
950 mlir::Value ParallelOp::getWaitDevnum() {
951  return getWaitDevnum(mlir::acc::DeviceType::None);
952 }
953 
954 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
955  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
956  getWaitOperandsSegments(), getHasWaitDevnum(),
957  deviceType);
958 }
959 
960 void ParallelOp::build(mlir::OpBuilder &odsBuilder,
961  mlir::OperationState &odsState,
962  mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
963  mlir::ValueRange vectorLength,
964  mlir::ValueRange asyncOperands,
965  mlir::ValueRange waitOperands, mlir::Value ifCond,
966  mlir::Value selfCond, mlir::ValueRange reductionOperands,
967  mlir::ValueRange gangPrivateOperands,
968  mlir::ValueRange gangFirstPrivateOperands,
969  mlir::ValueRange dataClauseOperands) {
970 
971  ParallelOp::build(
972  odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
973  /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
974  /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
975  /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
976  /*numGangsDeviceType=*/nullptr, numWorkers,
977  /*numWorkersDeviceType=*/nullptr, vectorLength,
978  /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
979  /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
980  gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
981  /*firstprivatizations=*/nullptr, dataClauseOperands,
982  /*defaultAttr=*/nullptr, /*combined=*/nullptr);
983 }
984 
985 static ParseResult parseNumGangs(
986  mlir::OpAsmParser &parser,
988  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
989  mlir::DenseI32ArrayAttr &segments) {
992 
993  do {
994  if (failed(parser.parseLBrace()))
995  return failure();
996 
997  int32_t crtOperandsSize = operands.size();
998  if (failed(parser.parseCommaSeparatedList(
1000  if (parser.parseOperand(operands.emplace_back()) ||
1001  parser.parseColonType(types.emplace_back()))
1002  return failure();
1003  return success();
1004  })))
1005  return failure();
1006  seg.push_back(operands.size() - crtOperandsSize);
1007 
1008  if (failed(parser.parseRBrace()))
1009  return failure();
1010 
1011  if (succeeded(parser.parseOptionalLSquare())) {
1012  if (parser.parseAttribute(attributes.emplace_back()) ||
1013  parser.parseRSquare())
1014  return failure();
1015  } else {
1016  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1018  }
1019  } while (succeeded(parser.parseOptionalComma()));
1020 
1021  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1022  attributes.end());
1023  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1024  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1025 
1026  return success();
1027 }
1028 
1030  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1031  if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
1032  p << " [" << attr << "]";
1033 }
1034 
1036  mlir::OperandRange operands, mlir::TypeRange types,
1037  std::optional<mlir::ArrayAttr> deviceTypes,
1038  std::optional<mlir::DenseI32ArrayAttr> segments) {
1039  unsigned opIdx = 0;
1040  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1041  p << "{";
1042  llvm::interleaveComma(
1043  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1044  p << operands[opIdx] << " : " << operands[opIdx].getType();
1045  ++opIdx;
1046  });
1047  p << "}";
1048  printSingleDeviceType(p, it.value());
1049  });
1050 }
1051 
1053  mlir::OpAsmParser &parser,
1055  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1056  mlir::DenseI32ArrayAttr &segments) {
1059 
1060  do {
1061  if (failed(parser.parseLBrace()))
1062  return failure();
1063 
1064  int32_t crtOperandsSize = operands.size();
1065 
1066  if (failed(parser.parseCommaSeparatedList(
1068  if (parser.parseOperand(operands.emplace_back()) ||
1069  parser.parseColonType(types.emplace_back()))
1070  return failure();
1071  return success();
1072  })))
1073  return failure();
1074 
1075  seg.push_back(operands.size() - crtOperandsSize);
1076 
1077  if (failed(parser.parseRBrace()))
1078  return failure();
1079 
1080  if (succeeded(parser.parseOptionalLSquare())) {
1081  if (parser.parseAttribute(attributes.emplace_back()) ||
1082  parser.parseRSquare())
1083  return failure();
1084  } else {
1085  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1087  }
1088  } while (succeeded(parser.parseOptionalComma()));
1089 
1090  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1091  attributes.end());
1092  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1093  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1094 
1095  return success();
1096 }
1097 
1100  mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1101  std::optional<mlir::DenseI32ArrayAttr> segments) {
1102  unsigned opIdx = 0;
1103  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1104  p << "{";
1105  llvm::interleaveComma(
1106  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1107  p << operands[opIdx] << " : " << operands[opIdx].getType();
1108  ++opIdx;
1109  });
1110  p << "}";
1111  printSingleDeviceType(p, it.value());
1112  });
1113 }
1114 
1115 static ParseResult parseWaitClause(
1116  mlir::OpAsmParser &parser,
1118  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1119  mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
1120  mlir::ArrayAttr &keywordOnly) {
1121  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
1123 
1124  bool needCommaBeforeOperands = false;
1125 
1126  // Keyword only
1127  if (failed(parser.parseOptionalLParen())) {
1128  keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1130  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1131  return success();
1132  }
1133 
1134  // Parse keyword only attributes
1135  if (succeeded(parser.parseOptionalLSquare())) {
1136  if (failed(parser.parseCommaSeparatedList([&]() {
1137  if (parser.parseAttribute(keywordAttrs.emplace_back()))
1138  return failure();
1139  return success();
1140  })))
1141  return failure();
1142  if (parser.parseRSquare())
1143  return failure();
1144  needCommaBeforeOperands = true;
1145  }
1146 
1147  if (needCommaBeforeOperands && failed(parser.parseComma()))
1148  return failure();
1149 
1150  do {
1151  if (failed(parser.parseLBrace()))
1152  return failure();
1153 
1154  int32_t crtOperandsSize = operands.size();
1155 
1156  if (succeeded(parser.parseOptionalKeyword("devnum"))) {
1157  if (failed(parser.parseColon()))
1158  return failure();
1159  devnum.push_back(BoolAttr::get(parser.getContext(), true));
1160  } else {
1161  devnum.push_back(BoolAttr::get(parser.getContext(), false));
1162  }
1163 
1164  if (failed(parser.parseCommaSeparatedList(
1166  if (parser.parseOperand(operands.emplace_back()) ||
1167  parser.parseColonType(types.emplace_back()))
1168  return failure();
1169  return success();
1170  })))
1171  return failure();
1172 
1173  seg.push_back(operands.size() - crtOperandsSize);
1174 
1175  if (failed(parser.parseRBrace()))
1176  return failure();
1177 
1178  if (succeeded(parser.parseOptionalLSquare())) {
1179  if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
1180  parser.parseRSquare())
1181  return failure();
1182  } else {
1183  deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1185  }
1186  } while (succeeded(parser.parseOptionalComma()));
1187 
1188  if (failed(parser.parseRParen()))
1189  return failure();
1190 
1191  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
1192  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1193  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1194  hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
1195 
1196  return success();
1197 }
1198 
1199 static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
1200  if (!hasDeviceTypeValues(attrs))
1201  return false;
1202  if (attrs->size() != 1)
1203  return false;
1204  if (auto deviceTypeAttr =
1205  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1206  return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
1207  return false;
1208 }
1209 
1211  mlir::OperandRange operands, mlir::TypeRange types,
1212  std::optional<mlir::ArrayAttr> deviceTypes,
1213  std::optional<mlir::DenseI32ArrayAttr> segments,
1214  std::optional<mlir::ArrayAttr> hasDevNum,
1215  std::optional<mlir::ArrayAttr> keywordOnly) {
1216 
1217  if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
1218  return;
1219 
1220  p << "(";
1221 
1222  printDeviceTypes(p, keywordOnly);
1223  if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
1224  p << ", ";
1225 
1226  unsigned opIdx = 0;
1227  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1228  p << "{";
1229  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1230  if (boolAttr && boolAttr.getValue())
1231  p << "devnum: ";
1232  llvm::interleaveComma(
1233  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1234  p << operands[opIdx] << " : " << operands[opIdx].getType();
1235  ++opIdx;
1236  });
1237  p << "}";
1238  printSingleDeviceType(p, it.value());
1239  });
1240 
1241  p << ")";
1242 }
1243 
1244 static ParseResult parseDeviceTypeOperands(
1245  mlir::OpAsmParser &parser,
1247  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
1249  if (failed(parser.parseCommaSeparatedList([&]() {
1250  if (parser.parseOperand(operands.emplace_back()) ||
1251  parser.parseColonType(types.emplace_back()))
1252  return failure();
1253  if (succeeded(parser.parseOptionalLSquare())) {
1254  if (parser.parseAttribute(attributes.emplace_back()) ||
1255  parser.parseRSquare())
1256  return failure();
1257  } else {
1258  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1259  parser.getContext(), mlir::acc::DeviceType::None));
1260  }
1261  return success();
1262  })))
1263  return failure();
1264  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1265  attributes.end());
1266  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1267  return success();
1268 }
1269 
1270 static void
1272  mlir::OperandRange operands, mlir::TypeRange types,
1273  std::optional<mlir::ArrayAttr> deviceTypes) {
1274  if (!hasDeviceTypeValues(deviceTypes))
1275  return;
1276  llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
1277  p << std::get<1>(it) << " : " << std::get<1>(it).getType();
1278  printSingleDeviceType(p, std::get<0>(it));
1279  });
1280 }
1281 
1283  mlir::OpAsmParser &parser,
1285  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1286  mlir::ArrayAttr &keywordOnlyDeviceType) {
1287 
1288  llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
1289  bool needCommaBeforeOperands = false;
1290 
1291  if (failed(parser.parseOptionalLParen())) {
1292  // Keyword only
1293  keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1295  keywordOnlyDeviceType =
1296  ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
1297  return success();
1298  }
1299 
1300  // Parse keyword only attributes
1301  if (succeeded(parser.parseOptionalLSquare())) {
1302  // Parse keyword only attributes
1303  if (failed(parser.parseCommaSeparatedList([&]() {
1304  if (parser.parseAttribute(
1305  keywordOnlyDeviceTypeAttributes.emplace_back()))
1306  return failure();
1307  return success();
1308  })))
1309  return failure();
1310  if (parser.parseRSquare())
1311  return failure();
1312  needCommaBeforeOperands = true;
1313  }
1314 
1315  if (needCommaBeforeOperands && failed(parser.parseComma()))
1316  return failure();
1317 
1319  if (failed(parser.parseCommaSeparatedList([&]() {
1320  if (parser.parseOperand(operands.emplace_back()) ||
1321  parser.parseColonType(types.emplace_back()))
1322  return failure();
1323  if (succeeded(parser.parseOptionalLSquare())) {
1324  if (parser.parseAttribute(attributes.emplace_back()) ||
1325  parser.parseRSquare())
1326  return failure();
1327  } else {
1328  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1329  parser.getContext(), mlir::acc::DeviceType::None));
1330  }
1331  return success();
1332  })))
1333  return failure();
1334 
1335  if (failed(parser.parseRParen()))
1336  return failure();
1337 
1338  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1339  attributes.end());
1340  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1341  return success();
1342 }
1343 
1346  mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1347  std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1348 
1349  if (operands.begin() == operands.end() &&
1350  hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
1351  return;
1352  }
1353 
1354  p << "(";
1355  printDeviceTypes(p, keywordOnlyDeviceTypes);
1356  if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
1357  hasDeviceTypeValues(deviceTypes))
1358  p << ", ";
1359  printDeviceTypeOperands(p, op, operands, types, deviceTypes);
1360  p << ")";
1361 }
1362 
1363 static ParseResult
1365  mlir::acc::CombinedConstructsTypeAttr &attr) {
1366  if (succeeded(parser.parseOptionalKeyword("combined"))) {
1367  if (parser.parseLParen())
1368  return failure();
1369  if (succeeded(parser.parseOptionalKeyword("kernels"))) {
1371  parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1372  } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
1374  parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1375  } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
1377  parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1378  } else {
1379  parser.emitError(parser.getCurrentLocation(),
1380  "expected compute construct name");
1381  return failure();
1382  }
1383  if (parser.parseRParen())
1384  return failure();
1385  }
1386  return success();
1387 }
1388 
1389 static void
1391  mlir::acc::CombinedConstructsTypeAttr attr) {
1392  if (attr) {
1393  switch (attr.getValue()) {
1394  case mlir::acc::CombinedConstructsType::KernelsLoop:
1395  p << "combined(kernels)";
1396  break;
1397  case mlir::acc::CombinedConstructsType::ParallelLoop:
1398  p << "combined(parallel)";
1399  break;
1400  case mlir::acc::CombinedConstructsType::SerialLoop:
1401  p << "combined(serial)";
1402  break;
1403  };
1404  }
1405 }
1406 
1407 //===----------------------------------------------------------------------===//
1408 // SerialOp
1409 //===----------------------------------------------------------------------===//
1410 
1411 unsigned SerialOp::getNumDataOperands() {
1412  return getReductionOperands().size() + getPrivateOperands().size() +
1413  getFirstprivateOperands().size() + getDataClauseOperands().size();
1414 }
1415 
1416 Value SerialOp::getDataOperand(unsigned i) {
1417  unsigned numOptional = getAsyncOperands().size();
1418  numOptional += getIfCond() ? 1 : 0;
1419  numOptional += getSelfCond() ? 1 : 0;
1420  return getOperand(getWaitOperands().size() + numOptional + i);
1421 }
1422 
1423 bool acc::SerialOp::hasAsyncOnly() {
1424  return hasAsyncOnly(mlir::acc::DeviceType::None);
1425 }
1426 
1427 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1428  return hasDeviceType(getAsyncOnly(), deviceType);
1429 }
1430 
1431 mlir::Value acc::SerialOp::getAsyncValue() {
1432  return getAsyncValue(mlir::acc::DeviceType::None);
1433 }
1434 
1435 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1437  getAsyncOperands(), deviceType);
1438 }
1439 
1440 bool acc::SerialOp::hasWaitOnly() {
1441  return hasWaitOnly(mlir::acc::DeviceType::None);
1442 }
1443 
1444 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1445  return hasDeviceType(getWaitOnly(), deviceType);
1446 }
1447 
1448 mlir::Operation::operand_range SerialOp::getWaitValues() {
1449  return getWaitValues(mlir::acc::DeviceType::None);
1450 }
1451 
1453 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1455  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1456  getHasWaitDevnum(), deviceType);
1457 }
1458 
1459 mlir::Value SerialOp::getWaitDevnum() {
1460  return getWaitDevnum(mlir::acc::DeviceType::None);
1461 }
1462 
1463 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1464  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1465  getWaitOperandsSegments(), getHasWaitDevnum(),
1466  deviceType);
1467 }
1468 
1469 LogicalResult acc::SerialOp::verify() {
1470  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1471  *this, getPrivatizations(), getPrivateOperands(), "private",
1472  "privatizations", /*checkOperandType=*/false)))
1473  return failure();
1474  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1475  *this, getFirstprivatizations(), getFirstprivateOperands(),
1476  "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
1477  return failure();
1478  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1479  *this, getReductionRecipes(), getReductionOperands(), "reduction",
1480  "reductions", false)))
1481  return failure();
1482 
1484  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1485  getWaitOperandsDeviceTypeAttr(), "wait")))
1486  return failure();
1487 
1488  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1489  getAsyncOperandsDeviceTypeAttr(),
1490  "async")))
1491  return failure();
1492 
1493  if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
1494  return failure();
1495 
1496  return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
1497 }
1498 
1499 //===----------------------------------------------------------------------===//
1500 // KernelsOp
1501 //===----------------------------------------------------------------------===//
1502 
1503 unsigned KernelsOp::getNumDataOperands() {
1504  return getDataClauseOperands().size();
1505 }
1506 
1507 Value KernelsOp::getDataOperand(unsigned i) {
1508  unsigned numOptional = getAsyncOperands().size();
1509  numOptional += getWaitOperands().size();
1510  numOptional += getNumGangs().size();
1511  numOptional += getNumWorkers().size();
1512  numOptional += getVectorLength().size();
1513  numOptional += getIfCond() ? 1 : 0;
1514  numOptional += getSelfCond() ? 1 : 0;
1515  return getOperand(numOptional + i);
1516 }
1517 
1518 bool acc::KernelsOp::hasAsyncOnly() {
1519  return hasAsyncOnly(mlir::acc::DeviceType::None);
1520 }
1521 
1522 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1523  return hasDeviceType(getAsyncOnly(), deviceType);
1524 }
1525 
1526 mlir::Value acc::KernelsOp::getAsyncValue() {
1527  return getAsyncValue(mlir::acc::DeviceType::None);
1528 }
1529 
1530 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1532  getAsyncOperands(), deviceType);
1533 }
1534 
1535 mlir::Value acc::KernelsOp::getNumWorkersValue() {
1536  return getNumWorkersValue(mlir::acc::DeviceType::None);
1537 }
1538 
1540 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1541  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1542  deviceType);
1543 }
1544 
1545 mlir::Value acc::KernelsOp::getVectorLengthValue() {
1546  return getVectorLengthValue(mlir::acc::DeviceType::None);
1547 }
1548 
1550 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1551  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1552  getVectorLength(), deviceType);
1553 }
1554 
1555 mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
1556  return getNumGangsValues(mlir::acc::DeviceType::None);
1557 }
1558 
1560 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1561  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1562  getNumGangsSegments(), deviceType);
1563 }
1564 
1565 bool acc::KernelsOp::hasWaitOnly() {
1566  return hasWaitOnly(mlir::acc::DeviceType::None);
1567 }
1568 
1569 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1570  return hasDeviceType(getWaitOnly(), deviceType);
1571 }
1572 
1573 mlir::Operation::operand_range KernelsOp::getWaitValues() {
1574  return getWaitValues(mlir::acc::DeviceType::None);
1575 }
1576 
1578 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1580  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1581  getHasWaitDevnum(), deviceType);
1582 }
1583 
1584 mlir::Value KernelsOp::getWaitDevnum() {
1585  return getWaitDevnum(mlir::acc::DeviceType::None);
1586 }
1587 
1588 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1589  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1590  getWaitOperandsSegments(), getHasWaitDevnum(),
1591  deviceType);
1592 }
1593 
1594 LogicalResult acc::KernelsOp::verify() {
1596  *this, getNumGangs(), getNumGangsSegmentsAttr(),
1597  getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1598  return failure();
1599 
1601  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1602  getWaitOperandsDeviceTypeAttr(), "wait")))
1603  return failure();
1604 
1605  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1606  getNumWorkersDeviceTypeAttr(),
1607  "num_workers")))
1608  return failure();
1609 
1610  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1611  getVectorLengthDeviceTypeAttr(),
1612  "vector_length")))
1613  return failure();
1614 
1615  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1616  getAsyncOperandsDeviceTypeAttr(),
1617  "async")))
1618  return failure();
1619 
1620  if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
1621  return failure();
1622 
1623  return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
1624 }
1625 
1626 //===----------------------------------------------------------------------===//
1627 // HostDataOp
1628 //===----------------------------------------------------------------------===//
1629 
1630 LogicalResult acc::HostDataOp::verify() {
1631  if (getDataClauseOperands().empty())
1632  return emitError("at least one operand must appear on the host_data "
1633  "operation");
1634 
1635  for (mlir::Value operand : getDataClauseOperands())
1636  if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
1637  return emitError("expect data entry operation as defining op");
1638  return success();
1639 }
1640 
1641 void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
1642  MLIRContext *context) {
1643  results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
1644 }
1645 
1646 //===----------------------------------------------------------------------===//
1647 // LoopOp
1648 //===----------------------------------------------------------------------===//
1649 
1650 static ParseResult parseGangValue(
1651  OpAsmParser &parser, llvm::StringRef keyword,
1654  llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
1655  bool &needCommaBetweenValues, bool &newValue) {
1656  if (succeeded(parser.parseOptionalKeyword(keyword))) {
1657  if (parser.parseEqual())
1658  return failure();
1659  if (parser.parseOperand(operands.emplace_back()) ||
1660  parser.parseColonType(types.emplace_back()))
1661  return failure();
1662  attributes.push_back(gangArgType);
1663  needCommaBetweenValues = true;
1664  newValue = true;
1665  }
1666  return success();
1667 }
1668 
1669 static ParseResult parseGangClause(
1670  OpAsmParser &parser,
1672  llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
1673  mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
1674  mlir::ArrayAttr &gangOnlyDeviceType) {
1675  llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
1676  llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
1677  llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
1679  bool needCommaBetweenValues = false;
1680  bool needCommaBeforeOperands = false;
1681 
1682  if (failed(parser.parseOptionalLParen())) {
1683  // Gang only keyword
1684  gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1686  gangOnlyDeviceType =
1687  ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
1688  return success();
1689  }
1690 
1691  // Parse gang only attributes
1692  if (succeeded(parser.parseOptionalLSquare())) {
1693  // Parse gang only attributes
1694  if (failed(parser.parseCommaSeparatedList([&]() {
1695  if (parser.parseAttribute(
1696  gangOnlyDeviceTypeAttributes.emplace_back()))
1697  return failure();
1698  return success();
1699  })))
1700  return failure();
1701  if (parser.parseRSquare())
1702  return failure();
1703  needCommaBeforeOperands = true;
1704  }
1705 
1706  auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
1707  mlir::acc::GangArgType::Num);
1708  auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
1709  mlir::acc::GangArgType::Dim);
1710  auto argStatic = mlir::acc::GangArgTypeAttr::get(
1711  parser.getContext(), mlir::acc::GangArgType::Static);
1712 
1713  do {
1714  if (needCommaBeforeOperands) {
1715  needCommaBeforeOperands = false;
1716  continue;
1717  }
1718 
1719  if (failed(parser.parseLBrace()))
1720  return failure();
1721 
1722  int32_t crtOperandsSize = gangOperands.size();
1723  while (true) {
1724  bool newValue = false;
1725  bool needValue = false;
1726  if (needCommaBetweenValues) {
1727  if (succeeded(parser.parseOptionalComma()))
1728  needValue = true; // expect a new value after comma.
1729  else
1730  break;
1731  }
1732 
1733  if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
1734  gangOperands, gangOperandsType,
1735  gangArgTypeAttributes, argNum,
1736  needCommaBetweenValues, newValue)))
1737  return failure();
1738  if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
1739  gangOperands, gangOperandsType,
1740  gangArgTypeAttributes, argDim,
1741  needCommaBetweenValues, newValue)))
1742  return failure();
1743  if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
1744  gangOperands, gangOperandsType,
1745  gangArgTypeAttributes, argStatic,
1746  needCommaBetweenValues, newValue)))
1747  return failure();
1748 
1749  if (!newValue && needValue) {
1750  parser.emitError(parser.getCurrentLocation(),
1751  "new value expected after comma");
1752  return failure();
1753  }
1754 
1755  if (!newValue)
1756  break;
1757  }
1758 
1759  if (gangOperands.empty())
1760  return parser.emitError(
1761  parser.getCurrentLocation(),
1762  "expect at least one of num, dim or static values");
1763 
1764  if (failed(parser.parseRBrace()))
1765  return failure();
1766 
1767  if (succeeded(parser.parseOptionalLSquare())) {
1768  if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
1769  parser.parseRSquare())
1770  return failure();
1771  } else {
1772  deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1774  }
1775 
1776  seg.push_back(gangOperands.size() - crtOperandsSize);
1777 
1778  } while (succeeded(parser.parseOptionalComma()));
1779 
1780  if (failed(parser.parseRParen()))
1781  return failure();
1782 
1783  llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
1784  gangArgTypeAttributes.end());
1785  gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
1786  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
1787 
1789  gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
1790  gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
1791 
1792  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1793  return success();
1794 }
1795 
1797  mlir::OperandRange operands, mlir::TypeRange types,
1798  std::optional<mlir::ArrayAttr> gangArgTypes,
1799  std::optional<mlir::ArrayAttr> deviceTypes,
1800  std::optional<mlir::DenseI32ArrayAttr> segments,
1801  std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
// Prints the gang clause.
// - Prints nothing when there are no operands and
//   hasOnlyDeviceTypeNone(gangOnlyDeviceTypes) holds.
// - Otherwise prints `(`, then the keyword-only device types via
//   printDeviceTypes, then one `{kw=%v : ty, ...}` group per entry of
//   `deviceTypes`, each followed by printSingleDeviceType, then `)`.
// - Group k holds (*segments)[k] operands.
// - `opIdx` walks `operands` and `gangArgTypes` in lockstep across all
//   groups; the keyword printed is num/dim/static per the GangArgType.
// NOTE(review): the dyn_cast result is dereferenced without a null check, and
// the inner lambda parameter `it` shadows the outer one.
1802 
1803  if (operands.begin() == operands.end() &&
1804  hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
1805  return;
1806  }
1807 
1808  p << "(";
1809 
1810  printDeviceTypes(p, gangOnlyDeviceTypes);
1811 
1812  if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
1813  hasDeviceTypeValues(deviceTypes))
1814  p << ", ";
1815 
1816  if (hasDeviceTypeValues(deviceTypes)) {
1817  unsigned opIdx = 0;
1818  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1819  p << "{";
1820  llvm::interleaveComma(
1821  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1822  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
1823  (*gangArgTypes)[opIdx]);
1824  if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
1825  p << LoopOp::getGangNumKeyword();
1826  else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
1827  p << LoopOp::getGangDimKeyword();
1828  else if (gangArgTypeAttr.getValue() ==
1829  mlir::acc::GangArgType::Static)
1830  p << LoopOp::getGangStaticKeyword();
1831  p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
1832  ++opIdx;
1833  });
1834  p << "}";
1835  printSingleDeviceType(p, it.value());
1836  });
1837  }
1838  p << ")";
1839 }
1840 
1842  std::optional<mlir::ArrayAttr> segments,
1843  llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
// Inserts each device type of `segments` into the caller-owned `deviceTypes`
// set and returns true at the first one that was already in the set.
// Because the set is shared across calls, a device type repeated across
// several attributes is also detected. An absent attribute yields false.
// NOTE(review): the dyn_cast result is used without a null check; elements
// are assumed to be DeviceTypeAttr.
1844  if (!segments)
1845  return false;
1846  for (auto attr : *segments) {
1847  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1848  if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
1849  return true;
1850  }
1851  return false;
1852 }
1853 
1854 /// Check for duplicates in the DeviceType array attribute.
1855 LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
1856  llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
1857  if (!deviceTypes)
1858  return success();
1859  for (auto attr : deviceTypes) {
1860  auto deviceTypeAttr =
1861  mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
1862  if (!deviceTypeAttr)
1863  return failure();
1864  if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
1865  return failure();
1866  }
1867  return success();
1868 }
1869 
// Verifies acc.loop:
// - inclusiveUpperbound size vs. upperbound size;
// - per-clause device_type duplication and count consistency for collapse,
//   gang, worker, vector and tile;
// - mutual exclusion of auto/independent/seq;
// - seq being incompatible with gang/worker/vector for the same device type;
// - private/reduction operands against their recipe lists
//   (checkSymOperandList);
// - the allowed `combined` values;
// - a non-empty region.
1870 LogicalResult acc::LoopOp::verify() {
1871  if (!getUpperbound().empty() && getInclusiveUpperbound() &&
1872  (getUpperbound().size() != getInclusiveUpperbound()->size()))
1873  return emitError() << "inclusiveUpperbound size is expected to be the same"
1874  << " as upperbound size";
1875 
1876  // Check collapse
1877  if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
1878  return emitOpError() << "collapse device_type attr must be define when"
1879  << " collapse attr is present";
1880 
1881  if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
1882  getCollapseAttr().getValue().size() !=
1883  getCollapseDeviceTypeAttr().getValue().size())
1884  return emitOpError() << "collapse attribute count must match collapse"
1885  << " device_type count";
1886  if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
1887  return emitOpError()
1888  << "duplicate device_type found in collapseDeviceType attribute";
1889 
1890  // Check gang
1891  if (!getGangOperands().empty()) {
1892  if (!getGangOperandsArgType())
1893  return emitOpError() << "gangOperandsArgType attribute must be defined"
1894  << " when gang operands are present";
1895 
1896  if (getGangOperands().size() !=
1897  getGangOperandsArgTypeAttr().getValue().size())
1898  return emitOpError() << "gangOperandsArgType attribute count must match"
1899  << " gangOperands count";
1900  }
1901  if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
1902  return emitOpError() << "duplicate device_type found in gang attribute";
1903 
// Cross-check gang operands against gangOperandsSegments and
// gangOperandsDeviceType.
1905  *this, getGangOperands(), getGangOperandsSegmentsAttr(),
1906  getGangOperandsDeviceTypeAttr(), "gang")))
1907  return failure();
1908 
1909  // Check worker
1910  if (failed(checkDeviceTypes(getWorkerAttr())))
1911  return emitOpError() << "duplicate device_type found in worker attribute";
1912  if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
1913  return emitOpError() << "duplicate device_type found in "
1914  "workerNumOperandsDeviceType attribute";
1915  if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
1916  getWorkerNumOperandsDeviceTypeAttr(),
1917  "worker")))
1918  return failure();
1919 
1920  // Check vector
1921  if (failed(checkDeviceTypes(getVectorAttr())))
1922  return emitOpError() << "duplicate device_type found in vector attribute";
1923  if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
1924  return emitOpError() << "duplicate device_type found in "
1925  "vectorOperandsDeviceType attribute";
1926  if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
1927  getVectorOperandsDeviceTypeAttr(),
1928  "vector")))
1929  return failure();
1930 
// Cross-check tile operands against tileOperandsSegments and
// tileOperandsDeviceType.
1932  *this, getTileOperands(), getTileOperandsSegmentsAttr(),
1933  getTileOperandsDeviceTypeAttr(), "tile")))
1934  return failure();
1935 
1936  // auto, independent and seq attribute are mutually exclusive.
// One set is shared by the three calls, so a device type listed in more than
// one of these attributes (or twice in one) is reported.
// NOTE(review): the message quotes only "auto"; the other two names are
// unquoted. Changing it would alter the diagnostic text.
1937  llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
1938  if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
1939  hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
1940  hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
1941  return emitError() << "only one of \"" << acc::LoopOp::getAutoAttrStrName()
1942  << "\", " << getIndependentAttrName() << ", "
1943  << getSeqAttrName()
1944  << " can be present at the same time";
1945  }
1946 
1947  // Gang, worker and vector are incompatible with seq.
// For every device type carrying seq, any gang/worker/vector keyword or
// operand (of any gang kind) for that same device type is rejected.
1948  if (getSeqAttr()) {
1949  for (auto attr : getSeqAttr()) {
1950  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1951  if (hasVector(deviceTypeAttr.getValue()) ||
1952  getVectorValue(deviceTypeAttr.getValue()) ||
1953  hasWorker(deviceTypeAttr.getValue()) ||
1954  getWorkerValue(deviceTypeAttr.getValue()) ||
1955  hasGang(deviceTypeAttr.getValue()) ||
1956  getGangValue(mlir::acc::GangArgType::Num,
1957  deviceTypeAttr.getValue()) ||
1958  getGangValue(mlir::acc::GangArgType::Dim,
1959  deviceTypeAttr.getValue()) ||
1960  getGangValue(mlir::acc::GangArgType::Static,
1961  deviceTypeAttr.getValue()))
1962  return emitError()
1963  << "gang, worker or vector cannot appear with the seq attr";
1964  }
1965  }
1966 
1967  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1968  *this, getPrivatizations(), getPrivateOperands(), "private",
1969  "privatizations", false)))
1970  return failure();
1971 
1972  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1973  *this, getReductionRecipes(), getReductionOperands(), "reduction",
1974  "reductions", false)))
1975  return failure();
1976 
// Only the loop-flavoured combined constructs may be recorded on acc.loop.
1977  if (getCombined().has_value() &&
1978  (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
1979  getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
1980  getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
1981  return emitError("unexpected combined constructs attribute");
1982  }
1983 
1984  // Check non-empty body().
1985  if (getRegion().empty())
1986  return emitError("expected non-empty body.");
1987 
1988  return success();
1989 }
1990 
1991 unsigned LoopOp::getNumDataOperands() {
1992  return getReductionOperands().size() + getPrivateOperands().size();
1993 }
1994 
1995 Value LoopOp::getDataOperand(unsigned i) {
1996  unsigned numOptional =
1997  getLowerbound().size() + getUpperbound().size() + getStep().size();
1998  numOptional += getGangOperands().size();
1999  numOptional += getVectorOperands().size();
2000  numOptional += getWorkerNumOperands().size();
2001  numOptional += getTileOperands().size();
2002  numOptional += getCacheOperands().size();
2003  return getOperand(numOptional + i);
2004 }
2005 
2006 bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
2007 
2008 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
2009  return hasDeviceType(getAuto_(), deviceType);
2010 }
2011 
2012 bool LoopOp::hasIndependent() {
2013  return hasIndependent(mlir::acc::DeviceType::None);
2014 }
2015 
2016 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
2017  return hasDeviceType(getIndependent(), deviceType);
2018 }
2019 
2020 bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
2021 
2022 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
2023  return hasDeviceType(getSeq(), deviceType);
2024 }
2025 
2026 mlir::Value LoopOp::getVectorValue() {
2027  return getVectorValue(mlir::acc::DeviceType::None);
2028 }
2029 
2030 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
2031  return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
2032  getVectorOperands(), deviceType);
2033 }
2034 
2035 bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
2036 
2037 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
2038  return hasDeviceType(getVector(), deviceType);
2039 }
2040 
2041 mlir::Value LoopOp::getWorkerValue() {
2042  return getWorkerValue(mlir::acc::DeviceType::None);
2043 }
2044 
2045 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
2046  return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
2047  getWorkerNumOperands(), deviceType);
2048 }
2049 
2050 bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2051 
2052 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2053  return hasDeviceType(getWorker(), deviceType);
2054 }
2055 
2056 mlir::Operation::operand_range LoopOp::getTileValues() {
2057  return getTileValues(mlir::acc::DeviceType::None);
2058 }
2059 
2061 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
// Returns the tile operands of the segment for `deviceType`, selected by
// getValuesFromSegments from tileOperandsDeviceType and tileOperandsSegments.
2062  return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
2063  getTileOperandsSegments(), deviceType);
2064 }
2065 
2066 std::optional<int64_t> LoopOp::getCollapseValue() {
2067  return getCollapseValue(mlir::acc::DeviceType::None);
2068 }
2069 
2070 std::optional<int64_t>
2071 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2072  if (!getCollapseAttr())
2073  return std::nullopt;
2074  if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2075  auto intAttr =
2076  mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2077  return intAttr.getValue().getZExtValue();
2078  }
2079  return std::nullopt;
2080 }
2081 
2082 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2083  return getGangValue(gangArgType, mlir::acc::DeviceType::None);
2084 }
2085 
// Returns the gang operand of kind `gangArgType` in the segment for
// `deviceType`. Returns a null Value when:
// - there are no gang operands;
// - findSegment finds no segment for `deviceType`; or
// - no operand in that segment has the requested kind.
// The segment's start index is the sum of the preceding gangOperandsSegments
// sizes. gangOperandsArgType is indexed in parallel with gangOperands.
// NOTE(review): the dyn_cast result is dereferenced unchecked; cast<> would
// state the invariant directly.
2086 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2087  mlir::acc::DeviceType deviceType) {
2088  if (getGangOperands().empty())
2089  return {};
2090  if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
2091  int32_t nbOperandsBefore = 0;
2092  for (unsigned i = 0; i < *pos; ++i)
2093  nbOperandsBefore += (*getGangOperandsSegments())[i];
2095  getGangOperands()
2096  .drop_front(nbOperandsBefore)
2097  .take_front((*getGangOperandsSegments())[*pos]);
2098 
2099  int32_t argTypeIdx = nbOperandsBefore;
2100  for (auto value : values) {
2101  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2102  (*getGangOperandsArgType())[argTypeIdx]);
2103  if (gangArgTypeAttr.getValue() == gangArgType)
2104  return value;
2105  ++argTypeIdx;
2106  }
2107  }
2108  return {};
2109 }
2110 
2111 bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2112 
2113 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2114  return hasDeviceType(getGang(), deviceType);
2115 }
2116 
2117 llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
2118  return {&getRegion()};
2119 }
2120 
2121 /// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
2122 /// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
2123 /// `(` ssa-id-and-type-list `)`
2124 /// region
/// The `control(...)` prefix is optional. Without it, no induction variables
/// are parsed and the region is parsed with an empty argument list.
/// With it, the lowerbound, upperbound and step lists are each parsed with
/// exactly as many operands as there are induction variables. The induction
/// variables become the region's entry arguments.
2125 ParseResult
2128  SmallVectorImpl<Type> &lowerboundType,
2130  SmallVectorImpl<Type> &upperboundType,
2132  SmallVectorImpl<Type> &stepType) {
2133 
2134  SmallVector<OpAsmParser::Argument> inductionVars;
2135  if (succeeded(
2136  parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
2137  if (parser.parseLParen() ||
2138  parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
2139  /*allowType=*/true) ||
2140  parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
2141  parser.parseOperandList(lowerbound, inductionVars.size(),
2143  parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
2144  parser.parseKeyword("to") || parser.parseLParen() ||
2145  parser.parseOperandList(upperbound, inductionVars.size(),
2147  parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
2148  parser.parseKeyword("step") || parser.parseLParen() ||
2149  parser.parseOperandList(step, inductionVars.size(),
2151  parser.parseColonTypeList(stepType) || parser.parseRParen())
2152  return failure();
2153  }
2154  return parser.parseRegion(region, inductionVars);
2155 }
2156 
2158  ValueRange lowerbound, TypeRange lowerboundType,
2159  ValueRange upperbound, TypeRange upperboundType,
2160  ValueRange steps, TypeRange stepType) {
// Prints the loop-control header only when the region's entry block has
// arguments:
//   `control(%iv : ty, ...) = (lbs : tys) to (ubs : tys) step (steps : tys) `
// The region is then printed without its entry block arguments, since they
// already appear in the header. Assumes the region has at least one block
// (region.front()).
// NOTE(review): `") " << " step ("` emits two spaces before `step`; changing
// it would alter the printed form.
2161  ValueRange regionArgs = region.front().getArguments();
2162  if (!regionArgs.empty()) {
2163  p << acc::LoopOp::getControlKeyword() << "(";
2164  llvm::interleaveComma(regionArgs, p,
2165  [&p](Value v) { p << v << " : " << v.getType(); });
2166  p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
2167  << upperbound << " : " << upperboundType << ") " << " step (" << steps
2168  << " : " << stepType << ") ";
2169  }
2170  p.printRegion(region, /*printEntryBlockArgs=*/false);
2171 }
2172 
2173 //===----------------------------------------------------------------------===//
2174 // DataOp
2175 //===----------------------------------------------------------------------===//
2176 
2177 LogicalResult acc::DataOp::verify() {
2178  // 2.6.5. Data Construct restriction
2179  // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
2180  // attach, or default clause must appear on a data construct.
2181  if (getOperands().empty() && !getDefaultAttr())
2182  return emitError("at least one operand or the default attribute "
2183  "must appear on the data operation");
2184 
2185  for (mlir::Value operand : getDataClauseOperands())
2186  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2187  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2188  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2189  operand.getDefiningOp()))
2190  return emitError("expect data entry/exit operation or acc.getdeviceptr "
2191  "as defining op");
2192 
2193  if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this)))
2194  return failure();
2195 
2196  return success();
2197 }
2198 
2199 unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
2200 
2201 Value DataOp::getDataOperand(unsigned i) {
2202  unsigned numOptional = getIfCond() ? 1 : 0;
2203  numOptional += getAsyncOperands().size() ? 1 : 0;
2204  numOptional += getWaitOperands().size();
2205  return getOperand(numOptional + i);
2206 }
2207 
2208 bool acc::DataOp::hasAsyncOnly() {
2209  return hasAsyncOnly(mlir::acc::DeviceType::None);
2210 }
2211 
2212 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2213  return hasDeviceType(getAsyncOnly(), deviceType);
2214 }
2215 
2216 mlir::Value DataOp::getAsyncValue() {
2217  return getAsyncValue(mlir::acc::DeviceType::None);
2218 }
2219 
// Returns the async operand selected for `deviceType` from getAsyncOperands().
2220 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2222  getAsyncOperands(), deviceType);
2223 }
2224 
2225 bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
2226 
2227 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2228  return hasDeviceType(getWaitOnly(), deviceType);
2229 }
2230 
2231 mlir::Operation::operand_range DataOp::getWaitValues() {
2232  return getWaitValues(mlir::acc::DeviceType::None);
2233 }
2234 
2236 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
// Returns the wait operands of the segment for `deviceType`, computed from
// waitOperandsDeviceType, waitOperands, waitOperandsSegments and
// hasWaitDevnum.
// NOTE(review): presumably the devnum operand, when present, is excluded from
// the returned range — confirm against the helper.
2238  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2239  getHasWaitDevnum(), deviceType);
2240 }
2241 
2242 mlir::Value DataOp::getWaitDevnum() {
2243  return getWaitDevnum(mlir::acc::DeviceType::None);
2244 }
2245 
2246 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2247  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2248  getWaitOperandsSegments(), getHasWaitDevnum(),
2249  deviceType);
2250 }
2251 
2252 //===----------------------------------------------------------------------===//
2253 // ExitDataOp
2254 //===----------------------------------------------------------------------===//
2255 
2256 LogicalResult acc::ExitDataOp::verify() {
2257  // 2.6.6. Data Exit Directive restriction
2258  // At least one copyout, delete, or detach clause must appear on an exit data
2259  // directive.
2260  if (getDataClauseOperands().empty())
2261  return emitError("at least one operand must be present in dataOperands on "
2262  "the exit data operation");
2263 
2264  // The async attribute represent the async clause without value. Therefore the
2265  // attribute and operand cannot appear at the same time.
2266  if (getAsyncOperand() && getAsync())
2267  return emitError("async attribute cannot appear with asyncOperand");
2268 
2269  // The wait attribute represent the wait clause without values. Therefore the
2270  // attribute and operands cannot appear at the same time.
2271  if (!getWaitOperands().empty() && getWait())
2272  return emitError("wait attribute cannot appear with waitOperands");
2273 
2274  if (getWaitDevnum() && getWaitOperands().empty())
2275  return emitError("wait_devnum cannot appear without waitOperands");
2276 
2277  return success();
2278 }
2279 
2280 unsigned ExitDataOp::getNumDataOperands() {
2281  return getDataClauseOperands().size();
2282 }
2283 
2284 Value ExitDataOp::getDataOperand(unsigned i) {
2285  unsigned numOptional = getIfCond() ? 1 : 0;
2286  numOptional += getAsyncOperand() ? 1 : 0;
2287  numOptional += getWaitDevnum() ? 1 : 0;
2288  return getOperand(getWaitOperands().size() + numOptional + i);
2289 }
2290 
2291 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2292  MLIRContext *context) {
2293  results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
2294 }
2295 
2296 //===----------------------------------------------------------------------===//
2297 // EnterDataOp
2298 //===----------------------------------------------------------------------===//
2299 
2300 LogicalResult acc::EnterDataOp::verify() {
2301  // 2.6.6. Data Enter Directive restriction
2302  // At least one copyin, create, or attach clause must appear on an enter data
2303  // directive.
2304  if (getDataClauseOperands().empty())
2305  return emitError("at least one operand must be present in dataOperands on "
2306  "the enter data operation");
2307 
2308  // The async attribute represent the async clause without value. Therefore the
2309  // attribute and operand cannot appear at the same time.
2310  if (getAsyncOperand() && getAsync())
2311  return emitError("async attribute cannot appear with asyncOperand");
2312 
2313  // The wait attribute represent the wait clause without values. Therefore the
2314  // attribute and operands cannot appear at the same time.
2315  if (!getWaitOperands().empty() && getWait())
2316  return emitError("wait attribute cannot appear with waitOperands");
2317 
2318  if (getWaitDevnum() && getWaitOperands().empty())
2319  return emitError("wait_devnum cannot appear without waitOperands");
2320 
2321  for (mlir::Value operand : getDataClauseOperands())
2322  if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
2323  operand.getDefiningOp()))
2324  return emitError("expect data entry operation as defining op");
2325 
2326  return success();
2327 }
2328 
2329 unsigned EnterDataOp::getNumDataOperands() {
2330  return getDataClauseOperands().size();
2331 }
2332 
2333 Value EnterDataOp::getDataOperand(unsigned i) {
2334  unsigned numOptional = getIfCond() ? 1 : 0;
2335  numOptional += getAsyncOperand() ? 1 : 0;
2336  numOptional += getWaitDevnum() ? 1 : 0;
2337  return getOperand(getWaitOperands().size() + numOptional + i);
2338 }
2339 
2340 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2341  MLIRContext *context) {
2342  results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
2343 }
2344 
2345 //===----------------------------------------------------------------------===//
2346 // AtomicReadOp
2347 //===----------------------------------------------------------------------===//
2348 
2349 LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
2350 
2351 //===----------------------------------------------------------------------===//
2352 // AtomicWriteOp
2353 //===----------------------------------------------------------------------===//
2354 
2355 LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
2356 
2357 //===----------------------------------------------------------------------===//
2358 // AtomicUpdateOp
2359 //===----------------------------------------------------------------------===//
2360 
2361 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2362  PatternRewriter &rewriter) {
2363  if (op.isNoOp()) {
2364  rewriter.eraseOp(op);
2365  return success();
2366  }
2367 
2368  if (Value writeVal = op.getWriteOpVal()) {
2369  rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
2370  return success();
2371  }
2372 
2373  return failure();
2374 }
2375 
2376 LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
2377 
2378 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
2379 
2380 //===----------------------------------------------------------------------===//
2381 // AtomicCaptureOp
2382 //===----------------------------------------------------------------------===//
2383 
2384 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2385  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2386  return op;
2387  return dyn_cast<AtomicReadOp>(getSecondOp());
2388 }
2389 
2390 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2391  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2392  return op;
2393  return dyn_cast<AtomicWriteOp>(getSecondOp());
2394 }
2395 
2396 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2397  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2398  return op;
2399  return dyn_cast<AtomicUpdateOp>(getSecondOp());
2400 }
2401 
2402 LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
2403 
2404 //===----------------------------------------------------------------------===//
2405 // DeclareEnterOp
2406 //===----------------------------------------------------------------------===//
2407 
// Verifies the data operands of a declare-style op `op`:
// - at least one operand is required unless `requireAtLeastOneOperand` is
//   false;
// - each operand must be produced by one of the listed data entry ops or
//   acc.getdeviceptr.
// When the operand's varPtr has a defining op, that op must also:
// - carry a DeclareAttr whose data clause matches the operand's data clause;
// - when the DeclareAttr is implicit, match the implicit flag of the
//   operand's defining op.
// NOTE(review): isa<> asserts on a null pointer, so a block-argument operand
// (no defining op) would crash here instead of producing the diagnostic;
// isa_and_nonnull would avoid that.
2408 template <typename Op>
2409 static LogicalResult
2411  bool requireAtLeastOneOperand = true) {
2412  if (operands.empty() && requireAtLeastOneOperand)
2413  return emitError(
2414  op->getLoc(),
2415  "at least one operand must appear on the declare operation");
2416 
2417  for (mlir::Value operand : operands) {
2418  if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2419  acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
2420  acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
2421  operand.getDefiningOp()))
2422  return op.emitError(
2423  "expect valid declare data entry operation or acc.getdeviceptr "
2424  "as defining op");
2425 
2426  mlir::Value varPtr{getVarPtr(operand.getDefiningOp())};
2427  assert(varPtr && "declare operands can only be data entry operations which "
2428  "must have varPtr");
2429  std::optional<mlir::acc::DataClause> dataClauseOptional{
2430  getDataClause(operand.getDefiningOp())};
2431  assert(dataClauseOptional.has_value() &&
2432  "declare operands can only be data entry operations which must have "
2433  "dataClause");
2434 
2435  // If varPtr has no defining op - there is nothing to check further.
2436  if (!varPtr.getDefiningOp())
2437  continue;
2438 
2439  // Check that the varPtr has a declare attribute.
2440  auto declareAttribute{
2441  varPtr.getDefiningOp()->getAttr(mlir::acc::getDeclareAttrName())};
2442  if (!declareAttribute)
2443  return op.emitError(
2444  "expect declare attribute on variable in declare operation");
2445 
2446  auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute);
2447  if (declAttr.getDataClause().getValue() != dataClauseOptional.value())
2448  return op.emitError(
2449  "expect matching declare attribute on variable in declare operation");
2450 
2451  // If the variable is marked with implicit attribute, the matching declare
2452  // data action must also be marked implicit. The reverse is not checked
2453  // since implicit data action may be inserted to do actions like updating
2454  // device copy, in which case the variable is not necessarily implicitly
2455  // declare'd.
2456  if (declAttr.getImplicit() &&
2457  declAttr.getImplicit() != acc::getImplicitFlag(operand.getDefiningOp()))
2458  return op.emitError(
2459  "implicitness must match between declare op and flag on variable");
2460  }
2461 
2462  return success();
2463 }
2464 
2465 LogicalResult acc::DeclareEnterOp::verify() {
2466  return checkDeclareOperands(*this, this->getDataClauseOperands());
2467 }
2468 
2469 //===----------------------------------------------------------------------===//
2470 // DeclareExitOp
2471 //===----------------------------------------------------------------------===//
2472 
2473 LogicalResult acc::DeclareExitOp::verify() {
2474  if (getToken())
2475  return checkDeclareOperands(*this, this->getDataClauseOperands(),
2476  /*requireAtLeastOneOperand=*/false);
2477  return checkDeclareOperands(*this, this->getDataClauseOperands());
2478 }
2479 
2480 //===----------------------------------------------------------------------===//
2481 // DeclareOp
2482 //===----------------------------------------------------------------------===//
2483 
2484 LogicalResult acc::DeclareOp::verify() {
2485  return checkDeclareOperands(*this, this->getDataClauseOperands());
2486 }
2487 
2488 //===----------------------------------------------------------------------===//
2489 // RoutineOp
2490 //===----------------------------------------------------------------------===//
2491 
2492 static unsigned getParallelismForDeviceType(acc::RoutineOp op,
2493  acc::DeviceType dtype) {
2494  unsigned parallelism = 0;
2495  parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
2496  parallelism += op.hasWorker(dtype) ? 1 : 0;
2497  parallelism += op.hasVector(dtype) ? 1 : 0;
2498  parallelism += op.hasSeq(dtype) ? 1 : 0;
2499  return parallelism;
2500 }
2501 
2502 LogicalResult acc::RoutineOp::verify() {
2503  unsigned baseParallelism =
2505 
2506  if (baseParallelism > 1)
2507  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2508  "be present at the same time";
2509 
2510  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
2511  ++dtypeInt) {
2512  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
2513  if (dtype == acc::DeviceType::None)
2514  continue;
2515  unsigned parallelism = getParallelismForDeviceType(*this, dtype);
2516 
2517  if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
2518  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2519  "be present at the same time";
2520  }
2521 
2522  return success();
2523 }
2524 
2525 static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
2526  mlir::ArrayAttr &deviceTypes) {
2527  llvm::SmallVector<mlir::Attribute> bindNameAttrs;
2528  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
2529 
2530  if (failed(parser.parseCommaSeparatedList([&]() {
2531  if (parser.parseAttribute(bindNameAttrs.emplace_back()))
2532  return failure();
2533  if (failed(parser.parseOptionalLSquare())) {
2534  deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2535  parser.getContext(), mlir::acc::DeviceType::None));
2536  } else {
2537  if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2538  parser.parseRSquare())
2539  return failure();
2540  }
2541  return success();
2542  })))
2543  return failure();
2544 
2545  bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
2546  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2547 
2548  return success();
2549 }
2550 
2552  std::optional<mlir::ArrayAttr> bindName,
2553  std::optional<mlir::ArrayAttr> deviceTypes) {
// Prints the bind names zipped with their device types, comma-separated: each
// name followed by printSingleDeviceType for its paired device type.
// Dereferences both optionals unconditionally.
// NOTE(review): this assumes both attributes are present and equally sized
// when called — confirm that the assembly format guarantees it.
2554  llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
2555  [&](const auto &pair) {
2556  p << std::get<0>(pair);
2557  printSingleDeviceType(p, std::get<1>(pair));
2558  });
2559 }
2560 
// Parses the acc.routine gang clause in one of these forms:
// - bare `gang`: returns after appending one DeviceTypeAttr to `gang`;
// - `gang(`: an optional `[dt, ...]` keyword-only device-type list, then
//   (after a comma, when that list is present) a non-empty comma-separated
//   list of `dim: attr ([dt])?` entries, then `)`.
// A dim entry without a device type is recorded with DeviceType::None.
// `gangDim` and `gangDimDeviceTypes` are parallel arrays.
// NOTE(review): after a keyword-only list a comma and at least one dim entry
// are mandatory, so `gang([dt])` appears to be rejected — confirm it
// round-trips with the printer.
2561 static ParseResult parseRoutineGangClause(OpAsmParser &parser,
2562  mlir::ArrayAttr &gang,
2563  mlir::ArrayAttr &gangDim,
2564  mlir::ArrayAttr &gangDimDeviceTypes) {
2565 
2566  llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
2567  gangDimDeviceTypeAttrs;
2568  bool needCommaBeforeOperands = false;
2569 
2570  // Gang keyword only
2571  if (failed(parser.parseOptionalLParen())) {
2572  gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2574  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2575  return success();
2576  }
2577 
2578  // Parse keyword only attributes
2579  if (succeeded(parser.parseOptionalLSquare())) {
2580  if (failed(parser.parseCommaSeparatedList([&]() {
2581  if (parser.parseAttribute(gangAttrs.emplace_back()))
2582  return failure();
2583  return success();
2584  })))
2585  return failure();
2586  if (parser.parseRSquare())
2587  return failure();
2588  needCommaBeforeOperands = true;
2589  }
2590 
2591  if (needCommaBeforeOperands && failed(parser.parseComma()))
2592  return failure();
2593 
2594  if (failed(parser.parseCommaSeparatedList([&]() {
2595  if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
2596  parser.parseColon() ||
2597  parser.parseAttribute(gangDimAttrs.emplace_back()))
2598  return failure();
2599  if (succeeded(parser.parseOptionalLSquare())) {
2600  if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
2601  parser.parseRSquare())
2602  return failure();
2603  } else {
2604  gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2605  parser.getContext(), mlir::acc::DeviceType::None));
2606  }
2607  return success();
2608  })))
2609  return failure();
2610 
2611  if (failed(parser.parseRParen()))
2612  return failure();
2613 
2614  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2615  gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
2616  gangDimDeviceTypes =
2617  ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
2618 
2619  return success();
2620 }
2621 
2623  std::optional<mlir::ArrayAttr> gang,
2624  std::optional<mlir::ArrayAttr> gangDim,
2625  std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
// Prints nothing after `gang` when there are no dim entries and the only
// keyword-only device type is None.
// Otherwise prints `(`, the keyword-only device types via printDeviceTypes,
// a `, ` separator when both parts exist, and then the
// `dim: attr [dt]` entries zipped from gangDim/gangDimDeviceTypes, then `)`.
// NOTE(review): the dyn_cast of (*gang)[0] is dereferenced without a null
// check.
2626 
2627  if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
2628  gang->size() == 1) {
2629  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
2630  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
2631  return;
2632  }
2633 
2634  p << "(";
2635 
2636  printDeviceTypes(p, gang);
2637 
2638  if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
2639  p << ", ";
2640 
2641  if (hasDeviceTypeValues(gangDimDeviceTypes))
2642  llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
2643  [&](const auto &pair) {
2644  p << acc::RoutineOp::getGangDimKeyword() << ": ";
2645  p << std::get<0>(pair);
2646  printSingleDeviceType(p, std::get<1>(pair));
2647  });
2648 
2649  p << ")";
2650 }
2651 
// Parses an optional `([dt, ...])` device-type list following a keyword.
// Without `(` a single default entry is produced (presumably
// DeviceType::None, like the sibling parsers — TODO confirm).
// With `([...])` every listed attribute is collected.
// NOTE(review): if `(` is not followed by `[`, this returns success with an
// empty array and leaves `)` unconsumed, so the error surfaces later in an
// unrelated place.
2652 static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
2653  mlir::ArrayAttr &deviceTypes) {
2655  // Keyword only
2656  if (failed(parser.parseOptionalLParen())) {
2657  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2659  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
2660  return success();
2661  }
2662 
2663  // Parse device type attributes
2664  if (succeeded(parser.parseOptionalLSquare())) {
2665  if (failed(parser.parseCommaSeparatedList([&]() {
2666  if (parser.parseAttribute(attributes.emplace_back()))
2667  return failure();
2668  return success();
2669  })))
2670  return failure();
2671  if (parser.parseRSquare() || parser.parseRParen())
2672  return failure();
2673  }
2674  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
2675  return success();
2676 }
2677 
2678 static void
2680  std::optional<mlir::ArrayAttr> deviceTypes) {
2681 
2682  if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
2683  auto deviceTypeAttr =
2684  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
2685  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
2686  return;
2687  }
2688 
2689  if (!hasDeviceTypeValues(deviceTypes))
2690  return;
2691 
2692  p << "([";
2693  llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
2694  auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2695  p << dTypeAttr;
2696  });
2697  p << "])";
2698 }
2699 
2700 bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2701 
2702 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
2703  return hasDeviceType(getWorker(), deviceType);
2704 }
2705 
2706 bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
2707 
2708 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
2709  return hasDeviceType(getVector(), deviceType);
2710 }
2711 
2712 bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
2713 
2714 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
2715  return hasDeviceType(getSeq(), deviceType);
2716 }
2717 
2718 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
2719  return getBindNameValue(mlir::acc::DeviceType::None);
2720 }
2721 
2722 std::optional<llvm::StringRef>
2723 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
2724  if (!hasDeviceTypeValues(getBindNameDeviceType()))
2725  return std::nullopt;
2726  if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
2727  auto attr = (*getBindName())[*pos];
2728  auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
2729  return stringAttr.getValue();
2730  }
2731  return std::nullopt;
2732 }
2733 
2734 bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2735 
2736 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
2737  return hasDeviceType(getGang(), deviceType);
2738 }
2739 
2740 std::optional<int64_t> RoutineOp::getGangDimValue() {
2741  return getGangDimValue(mlir::acc::DeviceType::None);
2742 }
2743 
2744 std::optional<int64_t>
2745 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
2746  if (!hasDeviceTypeValues(getGangDimDeviceType()))
2747  return std::nullopt;
2748  if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
2749  auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
2750  return intAttr.getInt();
2751  }
2752  return std::nullopt;
2753 }
2754 
2755 //===----------------------------------------------------------------------===//
2756 // InitOp
2757 //===----------------------------------------------------------------------===//
2758 
2759 LogicalResult acc::InitOp::verify() {
2760  Operation *currOp = *this;
2761  while ((currOp = currOp->getParentOp()))
2762  if (isComputeOperation(currOp))
2763  return emitOpError("cannot be nested in a compute operation");
2764  return success();
2765 }
2766 
2767 //===----------------------------------------------------------------------===//
2768 // ShutdownOp
2769 //===----------------------------------------------------------------------===//
2770 
2771 LogicalResult acc::ShutdownOp::verify() {
2772  Operation *currOp = *this;
2773  while ((currOp = currOp->getParentOp()))
2774  if (isComputeOperation(currOp))
2775  return emitOpError("cannot be nested in a compute operation");
2776  return success();
2777 }
2778 
2779 //===----------------------------------------------------------------------===//
2780 // SetOp
2781 //===----------------------------------------------------------------------===//
2782 
2783 LogicalResult acc::SetOp::verify() {
2784  Operation *currOp = *this;
2785  while ((currOp = currOp->getParentOp()))
2786  if (isComputeOperation(currOp))
2787  return emitOpError("cannot be nested in a compute operation");
2788  if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
2789  return emitOpError("at least one default_async, device_num, or device_type "
2790  "operand must appear");
2791  return success();
2792 }
2793 
2794 //===----------------------------------------------------------------------===//
2795 // UpdateOp
2796 //===----------------------------------------------------------------------===//
2797 
2798 LogicalResult acc::UpdateOp::verify() {
2799  // At least one of host or device should have a value.
2800  if (getDataClauseOperands().empty())
2801  return emitError("at least one value must be present in dataOperands");
2802 
2803  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
2804  getAsyncOperandsDeviceTypeAttr(),
2805  "async")))
2806  return failure();
2807 
2809  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2810  getWaitOperandsDeviceTypeAttr(), "wait")))
2811  return failure();
2812 
2813  if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this)))
2814  return failure();
2815 
2816  for (mlir::Value operand : getDataClauseOperands())
2817  if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
2818  operand.getDefiningOp()))
2819  return emitError("expect data entry/exit operation or acc.getdeviceptr "
2820  "as defining op");
2821 
2822  return success();
2823 }
2824 
2825 unsigned UpdateOp::getNumDataOperands() {
2826  return getDataClauseOperands().size();
2827 }
2828 
2829 Value UpdateOp::getDataOperand(unsigned i) {
2830  unsigned numOptional = getAsyncOperands().size();
2831  numOptional += getIfCond() ? 1 : 0;
2832  return getOperand(getWaitOperands().size() + numOptional + i);
2833 }
2834 
2835 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
2836  MLIRContext *context) {
2837  results.add<RemoveConstantIfCondition<UpdateOp>>(context);
2838 }
2839 
2840 bool UpdateOp::hasAsyncOnly() {
2841  return hasAsyncOnly(mlir::acc::DeviceType::None);
2842 }
2843 
2844 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2845  return hasDeviceType(getAsync(), deviceType);
2846 }
2847 
2848 mlir::Value UpdateOp::getAsyncValue() {
2849  return getAsyncValue(mlir::acc::DeviceType::None);
2850 }
2851 
2852 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2854  return {};
2855 
2856  if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
2857  return getAsyncOperands()[*pos];
2858 
2859  return {};
2860 }
2861 
2862 bool UpdateOp::hasWaitOnly() {
2863  return hasWaitOnly(mlir::acc::DeviceType::None);
2864 }
2865 
2866 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2867  return hasDeviceType(getWaitOnly(), deviceType);
2868 }
2869 
2870 mlir::Operation::operand_range UpdateOp::getWaitValues() {
2871  return getWaitValues(mlir::acc::DeviceType::None);
2872 }
2873 
2875 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2877  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2878  getHasWaitDevnum(), deviceType);
2879 }
2880 
2881 mlir::Value UpdateOp::getWaitDevnum() {
2882  return getWaitDevnum(mlir::acc::DeviceType::None);
2883 }
2884 
2885 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2886  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2887  getWaitOperandsSegments(), getHasWaitDevnum(),
2888  deviceType);
2889 }
2890 
2891 //===----------------------------------------------------------------------===//
2892 // WaitOp
2893 //===----------------------------------------------------------------------===//
2894 
2895 LogicalResult acc::WaitOp::verify() {
2896  // The async attribute represent the async clause without value. Therefore the
2897  // attribute and operand cannot appear at the same time.
2898  if (getAsyncOperand() && getAsync())
2899  return emitError("async attribute cannot appear with asyncOperand");
2900 
2901  if (getWaitDevnum() && getWaitOperands().empty())
2902  return emitError("wait_devnum cannot appear without waitOperands");
2903 
2904  return success();
2905 }
2906 
2907 #define GET_OP_CLASSES
2908 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
2909 
2910 #define GET_ATTRDEF_CLASSES
2911 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
2912 
2913 #define GET_TYPEDEF_CLASSES
2914 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
2915 
2916 //===----------------------------------------------------------------------===//
2917 // acc dialect utilities
2918 //===----------------------------------------------------------------------===//
2919 
2921  auto varPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
2922  .Case<ACC_DATA_ENTRY_OPS>(
2923  [&](auto entry) { return entry.getVarPtr(); })
2924  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
2925  [&](auto exit) { return exit.getVarPtr(); })
2926  .Default([&](mlir::Operation *) { return mlir::Value(); })};
2927  return varPtr;
2928 }
2929 
2931  auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
2933  [&](auto dataClause) { return dataClause.getAccPtr(); })
2934  .Default([&](mlir::Operation *) { return mlir::Value(); })};
2935  return accPtr;
2936 }
2937 
2939  auto varPtrPtr{
2941  .Case<ACC_DATA_ENTRY_OPS>(
2942  [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
2943  .Default([&](mlir::Operation *) { return mlir::Value(); })};
2944  return varPtrPtr;
2945 }
2946 
2951  accDataClauseOp)
2952  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
2954  dataClause.getBounds().begin(), dataClause.getBounds().end());
2955  })
2956  .Default([&](mlir::Operation *) {
2958  })};
2959  return bounds;
2960 }
2961 
2965  accDataClauseOp)
2966  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
2968  dataClause.getAsyncOperands().begin(),
2969  dataClause.getAsyncOperands().end());
2970  })
2971  .Default([&](mlir::Operation *) {
2973  });
2974 }
2975 
2976 mlir::ArrayAttr
2979  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
2980  return dataClause.getAsyncOperandsDeviceTypeAttr();
2981  })
2982  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
2983 }
2984 
2985 mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
2988  [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
2989  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
2990 }
2991 
2992 std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
2993  auto name{
2995  .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
2996  .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
2997  return {};
2998  })};
2999  return name;
3000 }
3001 
3002 std::optional<mlir::acc::DataClause>
3004  auto dataClause{
3006  accDataEntryOp)
3007  .Case<ACC_DATA_ENTRY_OPS>(
3008  [&](auto entry) { return entry.getDataClause(); })
3009  .Default([&](mlir::Operation *) { return std::nullopt; })};
3010  return dataClause;
3011 }
3012 
3014  auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
3015  .Case<ACC_DATA_ENTRY_OPS>(
3016  [&](auto entry) { return entry.getImplicit(); })
3017  .Default([&](mlir::Operation *) { return false; })};
3018  return implicit;
3019 }
3020 
3022  auto dataOperands{
3025  [&](auto entry) { return entry.getDataClauseOperands(); })
3026  .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
3027  return dataOperands;
3028 }
3029 
3032  auto dataOperands{
3035  [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
3036  .Default([&](mlir::Operation *) { return nullptr; })};
3037  return dataOperands;
3038 }
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:112
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2220
@ None
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition: OpenACC.cpp:2622
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition: OpenACC.cpp:490
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition: OpenACC.cpp:1841
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition: OpenACC.cpp:792
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition: OpenACC.cpp:1855
static bool isComputeOperation(Operation *op)
Definition: OpenACC.cpp:504
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition: OpenACC.cpp:1199
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:2525
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition: OpenACC.cpp:1210
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition: OpenACC.cpp:1115
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition: OpenACC.cpp:79
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:2679
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition: OpenACC.cpp:1650
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition: OpenACC.cpp:1364
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition: OpenACC.cpp:2410
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:99
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition: OpenACC.cpp:2126
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition: OpenACC.cpp:721
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:1244
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:874
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:123
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:985
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition: OpenACC.cpp:2157
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:2652
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition: OpenACC.cpp:2561
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1098
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:1271
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindName, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:2551
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:1052
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
Definition: OpenACC.cpp:705
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:155
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition: OpenACC.cpp:221
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition: OpenACC.cpp:1669
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition: OpenACC.cpp:580
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition: OpenACC.cpp:1029
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:110
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
Definition: OpenACC.cpp:736
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition: OpenACC.cpp:1344
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:85
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition: OpenACC.cpp:1796
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:139
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition: OpenACC.cpp:1282
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition: OpenACC.cpp:195
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition: OpenACC.cpp:175
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition: OpenACC.cpp:802
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition: OpenACC.cpp:2492
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1035
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition: OpenACC.cpp:1390
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
Definition: OpenACC.cpp:685
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition: OpenACC.h:67
#define ACC_DATA_ENTRY_OPS
Definition: OpenACC.h:43
#define ACC_DATA_EXIT_OPS
Definition: OpenACC.h:51
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:216
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:826
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:832
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:125
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
mlir::Value getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtr from a data clause operation.
Definition: OpenACC.cpp:2920
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition: OpenACC.cpp:3003
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition: OpenACC.cpp:3031
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition: OpenACC.cpp:2948
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition: OpenACC.cpp:3021
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition: OpenACC.cpp:2992
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition: OpenACC.cpp:3013
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition: OpenACC.cpp:2963
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition: OpenACC.cpp:2938
mlir::Value getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accPtr from a data clause operation.
Definition: OpenACC.cpp:2930
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:2985
static constexpr StringLiteral getDeclareAttrName()
Used to obtain the attribute name for declare.
Definition: OpenACC.h:140
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:2977
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.