MLIR  20.0.0git
OpenACC.cpp
Go to the documentation of this file.
1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
19 #include "llvm/ADT/SmallSet.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 
22 using namespace mlir;
23 using namespace acc;
24 
25 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
26 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
27 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
28 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
29 
30 namespace {
31 struct MemRefPointerLikeModel
32  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
33  MemRefType> {
34  Type getElementType(Type pointer) const {
35  return llvm::cast<MemRefType>(pointer).getElementType();
36  }
37 };
38 
39 struct LLVMPointerPointerLikeModel
40  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
41  LLVM::LLVMPointerType> {
42  Type getElementType(Type pointer) const { return Type(); }
43 };
44 } // namespace
45 
46 //===----------------------------------------------------------------------===//
47 // OpenACC operations
48 //===----------------------------------------------------------------------===//
49 
50 void OpenACCDialect::initialize() {
51  addOperations<
52 #define GET_OP_LIST
53 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
54  >();
55  addAttributes<
56 #define GET_ATTRDEF_LIST
57 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
58  >();
59  addTypes<
60 #define GET_TYPEDEF_LIST
61 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
62  >();
63 
64  // By attaching interfaces here, we make the OpenACC dialect dependent on
65  // the other dialects. This is probably better than having dialects like LLVM
66  // and memref be dependent on OpenACC.
67  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
68  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
69  *getContext());
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // device_type support helpers
74 //===----------------------------------------------------------------------===//
75 
76 static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
77  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
78  return true;
79  return false;
80 }
81 
82 static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
83  mlir::acc::DeviceType deviceType) {
84  if (!hasDeviceTypeValues(arrayAttr))
85  return false;
86 
87  for (auto attr : *arrayAttr) {
88  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
89  if (deviceTypeAttr.getValue() == deviceType)
90  return true;
91  }
92 
93  return false;
94 }
95 
97  std::optional<mlir::ArrayAttr> deviceTypes) {
98  if (!hasDeviceTypeValues(deviceTypes))
99  return;
100 
101  p << "[";
102  llvm::interleaveComma(*deviceTypes, p,
103  [&](mlir::Attribute attr) { p << attr; });
104  p << "]";
105 }
106 
107 static std::optional<unsigned> findSegment(ArrayAttr segments,
108  mlir::acc::DeviceType deviceType) {
109  unsigned segmentIdx = 0;
110  for (auto attr : segments) {
111  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
112  if (deviceTypeAttr.getValue() == deviceType)
113  return std::make_optional(segmentIdx);
114  ++segmentIdx;
115  }
116  return std::nullopt;
117 }
118 
120 getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
122  std::optional<llvm::ArrayRef<int32_t>> segments,
123  mlir::acc::DeviceType deviceType) {
124  if (!arrayAttr)
125  return range.take_front(0);
126  if (auto pos = findSegment(*arrayAttr, deviceType)) {
127  int32_t nbOperandsBefore = 0;
128  for (unsigned i = 0; i < *pos; ++i)
129  nbOperandsBefore += (*segments)[i];
130  return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
131  }
132  return range.take_front(0);
133 }
134 
135 static mlir::Value
136 getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
138  std::optional<llvm::ArrayRef<int32_t>> segments,
139  std::optional<mlir::ArrayAttr> hasWaitDevnum,
140  mlir::acc::DeviceType deviceType) {
141  if (!hasDeviceTypeValues(deviceTypeAttr))
142  return {};
143  if (auto pos = findSegment(*deviceTypeAttr, deviceType))
144  if (hasWaitDevnum->getValue()[*pos])
145  return getValuesFromSegments(deviceTypeAttr, operands, segments,
146  deviceType)
147  .front();
148  return {};
149 }
150 
152 getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
154  std::optional<llvm::ArrayRef<int32_t>> segments,
155  std::optional<mlir::ArrayAttr> hasWaitDevnum,
156  mlir::acc::DeviceType deviceType) {
157  auto range =
158  getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
159  if (range.empty())
160  return range;
161  if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
162  if (hasWaitDevnum && *hasWaitDevnum) {
163  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
164  if (boolAttr.getValue())
165  return range.drop_front(1); // first value is devnum
166  }
167  }
168  return range;
169 }
170 
171 template <typename Op>
172 static LogicalResult checkWaitAndAsyncConflict(Op op) {
173  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
174  ++dtypeInt) {
175  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
176 
177  // The async attribute represent the async clause without value. Therefore
178  // the attribute and operand cannot appear at the same time.
179  if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
180  op.hasAsyncOnly(dtype))
181  return op.emitError("async attribute cannot appear with asyncOperand");
182 
183  // The wait attribute represent the wait clause without values. Therefore
184  // the attribute and operands cannot appear at the same time.
185  if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
186  op.hasWaitOnly(dtype))
187  return op.emitError("wait attribute cannot appear with waitOperands");
188  }
189  return success();
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // DataBoundsOp
194 //===----------------------------------------------------------------------===//
195 LogicalResult acc::DataBoundsOp::verify() {
196  auto extent = getExtent();
197  auto upperbound = getUpperbound();
198  if (!extent && !upperbound)
199  return emitError("expected extent or upperbound.");
200  return success();
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // PrivateOp
205 //===----------------------------------------------------------------------===//
206 LogicalResult acc::PrivateOp::verify() {
207  if (getDataClause() != acc::DataClause::acc_private)
208  return emitError(
209  "data clause associated with private operation must match its intent");
210  return success();
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // FirstprivateOp
215 //===----------------------------------------------------------------------===//
216 LogicalResult acc::FirstprivateOp::verify() {
217  if (getDataClause() != acc::DataClause::acc_firstprivate)
218  return emitError("data clause associated with firstprivate operation must "
219  "match its intent");
220  return success();
221 }
222 
223 //===----------------------------------------------------------------------===//
224 // ReductionOp
225 //===----------------------------------------------------------------------===//
226 LogicalResult acc::ReductionOp::verify() {
227  if (getDataClause() != acc::DataClause::acc_reduction)
228  return emitError("data clause associated with reduction operation must "
229  "match its intent");
230  return success();
231 }
232 
233 //===----------------------------------------------------------------------===//
234 // DevicePtrOp
235 //===----------------------------------------------------------------------===//
236 LogicalResult acc::DevicePtrOp::verify() {
237  if (getDataClause() != acc::DataClause::acc_deviceptr)
238  return emitError("data clause associated with deviceptr operation must "
239  "match its intent");
240  return success();
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // PresentOp
245 //===----------------------------------------------------------------------===//
246 LogicalResult acc::PresentOp::verify() {
247  if (getDataClause() != acc::DataClause::acc_present)
248  return emitError(
249  "data clause associated with present operation must match its intent");
250  return success();
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // CopyinOp
255 //===----------------------------------------------------------------------===//
256 LogicalResult acc::CopyinOp::verify() {
257  // Test for all clauses this operation can be decomposed from:
258  if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
259  getDataClause() != acc::DataClause::acc_copyin_readonly &&
260  getDataClause() != acc::DataClause::acc_copy &&
261  getDataClause() != acc::DataClause::acc_reduction)
262  return emitError(
263  "data clause associated with copyin operation must match its intent"
264  " or specify original clause this operation was decomposed from");
265  return success();
266 }
267 
268 bool acc::CopyinOp::isCopyinReadonly() {
269  return getDataClause() == acc::DataClause::acc_copyin_readonly;
270 }
271 
272 //===----------------------------------------------------------------------===//
273 // CreateOp
274 //===----------------------------------------------------------------------===//
275 LogicalResult acc::CreateOp::verify() {
276  // Test for all clauses this operation can be decomposed from:
277  if (getDataClause() != acc::DataClause::acc_create &&
278  getDataClause() != acc::DataClause::acc_create_zero &&
279  getDataClause() != acc::DataClause::acc_copyout &&
280  getDataClause() != acc::DataClause::acc_copyout_zero)
281  return emitError(
282  "data clause associated with create operation must match its intent"
283  " or specify original clause this operation was decomposed from");
284  return success();
285 }
286 
287 bool acc::CreateOp::isCreateZero() {
288  // The zero modifier is encoded in the data clause.
289  return getDataClause() == acc::DataClause::acc_create_zero ||
290  getDataClause() == acc::DataClause::acc_copyout_zero;
291 }
292 
293 //===----------------------------------------------------------------------===//
294 // NoCreateOp
295 //===----------------------------------------------------------------------===//
296 LogicalResult acc::NoCreateOp::verify() {
297  if (getDataClause() != acc::DataClause::acc_no_create)
298  return emitError("data clause associated with no_create operation must "
299  "match its intent");
300  return success();
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // AttachOp
305 //===----------------------------------------------------------------------===//
306 LogicalResult acc::AttachOp::verify() {
307  if (getDataClause() != acc::DataClause::acc_attach)
308  return emitError(
309  "data clause associated with attach operation must match its intent");
310  return success();
311 }
312 
313 //===----------------------------------------------------------------------===//
314 // DeclareDeviceResidentOp
315 //===----------------------------------------------------------------------===//
316 
317 LogicalResult acc::DeclareDeviceResidentOp::verify() {
318  if (getDataClause() != acc::DataClause::acc_declare_device_resident)
319  return emitError("data clause associated with device_resident operation "
320  "must match its intent");
321  return success();
322 }
323 
324 //===----------------------------------------------------------------------===//
325 // DeclareLinkOp
326 //===----------------------------------------------------------------------===//
327 
328 LogicalResult acc::DeclareLinkOp::verify() {
329  if (getDataClause() != acc::DataClause::acc_declare_link)
330  return emitError(
331  "data clause associated with link operation must match its intent");
332  return success();
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // CopyoutOp
337 //===----------------------------------------------------------------------===//
338 LogicalResult acc::CopyoutOp::verify() {
339  // Test for all clauses this operation can be decomposed from:
340  if (getDataClause() != acc::DataClause::acc_copyout &&
341  getDataClause() != acc::DataClause::acc_copyout_zero &&
342  getDataClause() != acc::DataClause::acc_copy &&
343  getDataClause() != acc::DataClause::acc_reduction)
344  return emitError(
345  "data clause associated with copyout operation must match its intent"
346  " or specify original clause this operation was decomposed from");
347  if (!getVarPtr() || !getAccPtr())
348  return emitError("must have both host and device pointers");
349  return success();
350 }
351 
352 bool acc::CopyoutOp::isCopyoutZero() {
353  return getDataClause() == acc::DataClause::acc_copyout_zero;
354 }
355 
356 //===----------------------------------------------------------------------===//
357 // DeleteOp
358 //===----------------------------------------------------------------------===//
359 LogicalResult acc::DeleteOp::verify() {
360  // Test for all clauses this operation can be decomposed from:
361  if (getDataClause() != acc::DataClause::acc_delete &&
362  getDataClause() != acc::DataClause::acc_create &&
363  getDataClause() != acc::DataClause::acc_create_zero &&
364  getDataClause() != acc::DataClause::acc_copyin &&
365  getDataClause() != acc::DataClause::acc_copyin_readonly &&
366  getDataClause() != acc::DataClause::acc_present &&
367  getDataClause() != acc::DataClause::acc_declare_device_resident &&
368  getDataClause() != acc::DataClause::acc_declare_link)
369  return emitError(
370  "data clause associated with delete operation must match its intent"
371  " or specify original clause this operation was decomposed from");
372  if (!getAccPtr())
373  return emitError("must have device pointer");
374  return success();
375 }
376 
377 //===----------------------------------------------------------------------===//
378 // DetachOp
379 //===----------------------------------------------------------------------===//
380 LogicalResult acc::DetachOp::verify() {
381  // Test for all clauses this operation can be decomposed from:
382  if (getDataClause() != acc::DataClause::acc_detach &&
383  getDataClause() != acc::DataClause::acc_attach)
384  return emitError(
385  "data clause associated with detach operation must match its intent"
386  " or specify original clause this operation was decomposed from");
387  if (!getAccPtr())
388  return emitError("must have device pointer");
389  return success();
390 }
391 
392 //===----------------------------------------------------------------------===//
// UpdateHostOp
394 //===----------------------------------------------------------------------===//
395 LogicalResult acc::UpdateHostOp::verify() {
396  // Test for all clauses this operation can be decomposed from:
397  if (getDataClause() != acc::DataClause::acc_update_host &&
398  getDataClause() != acc::DataClause::acc_update_self)
399  return emitError(
400  "data clause associated with host operation must match its intent"
401  " or specify original clause this operation was decomposed from");
402  if (!getVarPtr() || !getAccPtr())
403  return emitError("must have both host and device pointers");
404  return success();
405 }
406 
407 //===----------------------------------------------------------------------===//
// UpdateDeviceOp
409 //===----------------------------------------------------------------------===//
410 LogicalResult acc::UpdateDeviceOp::verify() {
411  // Test for all clauses this operation can be decomposed from:
412  if (getDataClause() != acc::DataClause::acc_update_device)
413  return emitError(
414  "data clause associated with device operation must match its intent"
415  " or specify original clause this operation was decomposed from");
416  return success();
417 }
418 
419 //===----------------------------------------------------------------------===//
420 // UseDeviceOp
421 //===----------------------------------------------------------------------===//
422 LogicalResult acc::UseDeviceOp::verify() {
423  // Test for all clauses this operation can be decomposed from:
424  if (getDataClause() != acc::DataClause::acc_use_device)
425  return emitError(
426  "data clause associated with use_device operation must match its intent"
427  " or specify original clause this operation was decomposed from");
428  return success();
429 }
430 
431 //===----------------------------------------------------------------------===//
432 // CacheOp
433 //===----------------------------------------------------------------------===//
434 LogicalResult acc::CacheOp::verify() {
435  // Test for all clauses this operation can be decomposed from:
436  if (getDataClause() != acc::DataClause::acc_cache &&
437  getDataClause() != acc::DataClause::acc_cache_readonly)
438  return emitError(
439  "data clause associated with cache operation must match its intent"
440  " or specify original clause this operation was decomposed from");
441  return success();
442 }
443 
444 template <typename StructureOp>
445 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
446  unsigned nRegions = 1) {
447 
448  SmallVector<Region *, 2> regions;
449  for (unsigned i = 0; i < nRegions; ++i)
450  regions.push_back(state.addRegion());
451 
452  for (Region *region : regions)
453  if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
454  return failure();
455 
456  return success();
457 }
458 
459 static bool isComputeOperation(Operation *op) {
460  return isa<acc::ParallelOp, acc::LoopOp>(op);
461 }
462 
463 namespace {
464 /// Pattern to remove operation without region that have constant false `ifCond`
465 /// and remove the condition from the operation if the `ifCond` is a true
466 /// constant.
467 template <typename OpTy>
468 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
470 
471  LogicalResult matchAndRewrite(OpTy op,
472  PatternRewriter &rewriter) const override {
473  // Early return if there is no condition.
474  Value ifCond = op.getIfCond();
475  if (!ifCond)
476  return failure();
477 
478  IntegerAttr constAttr;
479  if (!matchPattern(ifCond, m_Constant(&constAttr)))
480  return failure();
481  if (constAttr.getInt())
482  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
483  else
484  rewriter.eraseOp(op);
485 
486  return success();
487  }
488 };
489 
490 /// Replaces the given op with the contents of the given single-block region,
491 /// using the operands of the block terminator to replace operation results.
492 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
493  Region &region, ValueRange blockArgs = {}) {
494  assert(llvm::hasSingleElement(region) && "expected single-region block");
495  Block *block = &region.front();
496  Operation *terminator = block->getTerminator();
497  ValueRange results = terminator->getOperands();
498  rewriter.inlineBlockBefore(block, op, blockArgs);
499  rewriter.replaceOp(op, results);
500  rewriter.eraseOp(terminator);
501 }
502 
503 /// Pattern to remove operation with region that have constant false `ifCond`
504 /// and remove the condition from the operation if the `ifCond` is constant
505 /// true.
506 template <typename OpTy>
507 struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
509 
510  LogicalResult matchAndRewrite(OpTy op,
511  PatternRewriter &rewriter) const override {
512  // Early return if there is no condition.
513  Value ifCond = op.getIfCond();
514  if (!ifCond)
515  return failure();
516 
517  IntegerAttr constAttr;
518  if (!matchPattern(ifCond, m_Constant(&constAttr)))
519  return failure();
520  if (constAttr.getInt())
521  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
522  else
523  replaceOpWithRegion(rewriter, op, op.getRegion());
524 
525  return success();
526  }
527 };
528 
529 } // namespace
530 
531 //===----------------------------------------------------------------------===//
532 // PrivateRecipeOp
533 //===----------------------------------------------------------------------===//
534 
535 static LogicalResult verifyInitLikeSingleArgRegion(
536  Operation *op, Region &region, StringRef regionType, StringRef regionName,
537  Type type, bool verifyYield, bool optional = false) {
538  if (optional && region.empty())
539  return success();
540 
541  if (region.empty())
542  return op->emitOpError() << "expects non-empty " << regionName << " region";
543  Block &firstBlock = region.front();
544  if (firstBlock.getNumArguments() < 1 ||
545  firstBlock.getArgument(0).getType() != type)
546  return op->emitOpError() << "expects " << regionName
547  << " region first "
548  "argument of the "
549  << regionType << " type";
550 
551  if (verifyYield) {
552  for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
553  if (yieldOp.getOperands().size() != 1 ||
554  yieldOp.getOperands().getTypes()[0] != type)
555  return op->emitOpError() << "expects " << regionName
556  << " region to "
557  "yield a value of the "
558  << regionType << " type";
559  }
560  }
561  return success();
562 }
563 
564 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
565  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
566  "privatization", "init", getType(),
567  /*verifyYield=*/false)))
568  return failure();
570  *this, getDestroyRegion(), "privatization", "destroy", getType(),
571  /*verifyYield=*/false, /*optional=*/true)))
572  return failure();
573  return success();
574 }
575 
576 //===----------------------------------------------------------------------===//
577 // FirstprivateRecipeOp
578 //===----------------------------------------------------------------------===//
579 
580 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
581  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
582  "privatization", "init", getType(),
583  /*verifyYield=*/false)))
584  return failure();
585 
586  if (getCopyRegion().empty())
587  return emitOpError() << "expects non-empty copy region";
588 
589  Block &firstBlock = getCopyRegion().front();
590  if (firstBlock.getNumArguments() < 2 ||
591  firstBlock.getArgument(0).getType() != getType())
592  return emitOpError() << "expects copy region with two arguments of the "
593  "privatization type";
594 
595  if (getDestroyRegion().empty())
596  return success();
597 
598  if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
599  "privatization", "destroy",
600  getType(), /*verifyYield=*/false)))
601  return failure();
602 
603  return success();
604 }
605 
606 //===----------------------------------------------------------------------===//
607 // ReductionRecipeOp
608 //===----------------------------------------------------------------------===//
609 
610 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
611  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
612  "init", getType(),
613  /*verifyYield=*/false)))
614  return failure();
615 
616  if (getCombinerRegion().empty())
617  return emitOpError() << "expects non-empty combiner region";
618 
619  Block &reductionBlock = getCombinerRegion().front();
620  if (reductionBlock.getNumArguments() < 2 ||
621  reductionBlock.getArgument(0).getType() != getType() ||
622  reductionBlock.getArgument(1).getType() != getType())
623  return emitOpError() << "expects combiner region with the first two "
624  << "arguments of the reduction type";
625 
626  for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
627  if (yieldOp.getOperands().size() != 1 ||
628  yieldOp.getOperands().getTypes()[0] != getType())
629  return emitOpError() << "expects combiner region to yield a value "
630  "of the reduction type";
631  }
632 
633  return success();
634 }
635 
636 //===----------------------------------------------------------------------===//
// Custom parser and printer for symbol operand lists (e.g. private clause)
638 //===----------------------------------------------------------------------===//
639 
640 static ParseResult parseSymOperandList(
641  mlir::OpAsmParser &parser,
643  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
645  if (failed(parser.parseCommaSeparatedList([&]() {
646  if (parser.parseAttribute(attributes.emplace_back()) ||
647  parser.parseArrow() ||
648  parser.parseOperand(operands.emplace_back()) ||
649  parser.parseColonType(types.emplace_back()))
650  return failure();
651  return success();
652  })))
653  return failure();
654  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
655  attributes.end());
656  symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
657  return success();
658 }
659 
661  mlir::OperandRange operands,
662  mlir::TypeRange types,
663  std::optional<mlir::ArrayAttr> attributes) {
664  llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
665  p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
666  << std::get<1>(it).getType();
667  });
668 }
669 
670 //===----------------------------------------------------------------------===//
671 // ParallelOp
672 //===----------------------------------------------------------------------===//
673 
674 /// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
675 template <typename Op>
676 static LogicalResult checkDataOperands(Op op,
677  const mlir::ValueRange &operands) {
678  for (mlir::Value operand : operands)
679  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
680  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
681  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
682  operand.getDefiningOp()))
683  return op.emitError(
684  "expect data entry/exit operation or acc.getdeviceptr "
685  "as defining op");
686  return success();
687 }
688 
689 template <typename Op>
690 static LogicalResult
691 checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
692  mlir::OperandRange operands, llvm::StringRef operandName,
693  llvm::StringRef symbolName, bool checkOperandType = true) {
694  if (!operands.empty()) {
695  if (!attributes || attributes->size() != operands.size())
696  return op->emitOpError()
697  << "expected as many " << symbolName << " symbol reference as "
698  << operandName << " operands";
699  } else {
700  if (attributes)
701  return op->emitOpError()
702  << "unexpected " << symbolName << " symbol reference";
703  return success();
704  }
705 
707  for (auto args : llvm::zip(operands, *attributes)) {
708  mlir::Value operand = std::get<0>(args);
709 
710  if (!set.insert(operand).second)
711  return op->emitOpError()
712  << operandName << " operand appears more than once";
713 
714  mlir::Type varType = operand.getType();
715  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
716  auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
717  if (!decl)
718  return op->emitOpError()
719  << "expected symbol reference " << symbolRef << " to point to a "
720  << operandName << " declaration";
721 
722  if (checkOperandType && decl.getType() && decl.getType() != varType)
723  return op->emitOpError() << "expected " << operandName << " (" << varType
724  << ") to be the same type as " << operandName
725  << " declaration (" << decl.getType() << ")";
726  }
727 
728  return success();
729 }
730 
731 unsigned ParallelOp::getNumDataOperands() {
732  return getReductionOperands().size() + getGangPrivateOperands().size() +
733  getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
734 }
735 
736 Value ParallelOp::getDataOperand(unsigned i) {
737  unsigned numOptional = getAsyncOperands().size();
738  numOptional += getNumGangs().size();
739  numOptional += getNumWorkers().size();
740  numOptional += getVectorLength().size();
741  numOptional += getIfCond() ? 1 : 0;
742  numOptional += getSelfCond() ? 1 : 0;
743  return getOperand(getWaitOperands().size() + numOptional + i);
744 }
745 
746 template <typename Op>
747 static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
748  ArrayAttr deviceTypes,
749  llvm::StringRef keyword) {
750  if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
751  return op.emitOpError() << keyword << " operands count must match "
752  << keyword << " device_type count";
753  return success();
754 }
755 
756 template <typename Op>
758  Op op, OperandRange operands, DenseI32ArrayAttr segments,
759  ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
760  std::size_t numOperandsInSegments = 0;
761 
762  if (!segments)
763  return success();
764 
765  for (auto segCount : segments.asArrayRef()) {
766  if (maxInSegment != 0 && segCount > maxInSegment)
767  return op.emitOpError() << keyword << " expects a maximum of "
768  << maxInSegment << " values per segment";
769  numOperandsInSegments += segCount;
770  }
771  if (numOperandsInSegments != operands.size())
772  return op.emitOpError()
773  << keyword << " operand count does not match count in segments";
774  if (deviceTypes.getValue().size() != (size_t)segments.size())
775  return op.emitOpError()
776  << keyword << " segment count does not match device_type count";
777  return success();
778 }
779 
780 LogicalResult acc::ParallelOp::verify() {
781  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
782  *this, getPrivatizations(), getGangPrivateOperands(), "private",
783  "privatizations", /*checkOperandType=*/false)))
784  return failure();
785  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
786  *this, getReductionRecipes(), getReductionOperands(), "reduction",
787  "reductions", false)))
788  return failure();
789 
791  *this, getNumGangs(), getNumGangsSegmentsAttr(),
792  getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
793  return failure();
794 
796  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
797  getWaitOperandsDeviceTypeAttr(), "wait")))
798  return failure();
799 
800  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
801  getNumWorkersDeviceTypeAttr(),
802  "num_workers")))
803  return failure();
804 
805  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
806  getVectorLengthDeviceTypeAttr(),
807  "vector_length")))
808  return failure();
809 
810  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
811  getAsyncOperandsDeviceTypeAttr(),
812  "async")))
813  return failure();
814 
815  if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
816  return failure();
817 
818  return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
819 }
820 
821 static mlir::Value
822 getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
824  mlir::acc::DeviceType deviceType) {
825  if (!arrayAttr)
826  return {};
827  if (auto pos = findSegment(*arrayAttr, deviceType))
828  return range[*pos];
829  return {};
830 }
831 
832 bool acc::ParallelOp::hasAsyncOnly() {
833  return hasAsyncOnly(mlir::acc::DeviceType::None);
834 }
835 
836 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
837  return hasDeviceType(getAsyncOnly(), deviceType);
838 }
839 
840 mlir::Value acc::ParallelOp::getAsyncValue() {
841  return getAsyncValue(mlir::acc::DeviceType::None);
842 }
843 
844 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
846  getAsyncOperands(), deviceType);
847 }
848 
849 mlir::Value acc::ParallelOp::getNumWorkersValue() {
850  return getNumWorkersValue(mlir::acc::DeviceType::None);
851 }
852 
854 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
855  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
856  deviceType);
857 }
858 
859 mlir::Value acc::ParallelOp::getVectorLengthValue() {
860  return getVectorLengthValue(mlir::acc::DeviceType::None);
861 }
862 
864 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
865  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
866  getVectorLength(), deviceType);
867 }
868 
869 mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
870  return getNumGangsValues(mlir::acc::DeviceType::None);
871 }
872 
874 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
875  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
876  getNumGangsSegments(), deviceType);
877 }
878 
879 bool acc::ParallelOp::hasWaitOnly() {
880  return hasWaitOnly(mlir::acc::DeviceType::None);
881 }
882 
883 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
884  return hasDeviceType(getWaitOnly(), deviceType);
885 }
886 
887 mlir::Operation::operand_range ParallelOp::getWaitValues() {
888  return getWaitValues(mlir::acc::DeviceType::None);
889 }
890 
892 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
894  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
895  getHasWaitDevnum(), deviceType);
896 }
897 
898 mlir::Value ParallelOp::getWaitDevnum() {
899  return getWaitDevnum(mlir::acc::DeviceType::None);
900 }
901 
902 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
903  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
904  getWaitOperandsSegments(), getHasWaitDevnum(),
905  deviceType);
906 }
907 
908 void ParallelOp::build(mlir::OpBuilder &odsBuilder,
909  mlir::OperationState &odsState,
910  mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
911  mlir::ValueRange vectorLength,
912  mlir::ValueRange asyncOperands,
913  mlir::ValueRange waitOperands, mlir::Value ifCond,
914  mlir::Value selfCond, mlir::ValueRange reductionOperands,
915  mlir::ValueRange gangPrivateOperands,
916  mlir::ValueRange gangFirstPrivateOperands,
917  mlir::ValueRange dataClauseOperands) {
918 
919  ParallelOp::build(
920  odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
921  /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
922  /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
923  /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
924  /*numGangsDeviceType=*/nullptr, numWorkers,
925  /*numWorkersDeviceType=*/nullptr, vectorLength,
926  /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
927  /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
928  gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
929  /*firstprivatizations=*/nullptr, dataClauseOperands,
930  /*defaultAttr=*/nullptr, /*combined=*/nullptr);
931 }
932 
933 static ParseResult parseNumGangs(
934  mlir::OpAsmParser &parser,
936  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
937  mlir::DenseI32ArrayAttr &segments) {
940 
941  do {
942  if (failed(parser.parseLBrace()))
943  return failure();
944 
945  int32_t crtOperandsSize = operands.size();
946  if (failed(parser.parseCommaSeparatedList(
948  if (parser.parseOperand(operands.emplace_back()) ||
949  parser.parseColonType(types.emplace_back()))
950  return failure();
951  return success();
952  })))
953  return failure();
954  seg.push_back(operands.size() - crtOperandsSize);
955 
956  if (failed(parser.parseRBrace()))
957  return failure();
958 
959  if (succeeded(parser.parseOptionalLSquare())) {
960  if (parser.parseAttribute(attributes.emplace_back()) ||
961  parser.parseRSquare())
962  return failure();
963  } else {
964  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
966  }
967  } while (succeeded(parser.parseOptionalComma()));
968 
969  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
970  attributes.end());
971  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
972  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
973 
974  return success();
975 }
976 
978  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
979  if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
980  p << " [" << attr << "]";
981 }
982 
984  mlir::OperandRange operands, mlir::TypeRange types,
985  std::optional<mlir::ArrayAttr> deviceTypes,
986  std::optional<mlir::DenseI32ArrayAttr> segments) {
987  unsigned opIdx = 0;
988  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
989  p << "{";
990  llvm::interleaveComma(
991  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
992  p << operands[opIdx] << " : " << operands[opIdx].getType();
993  ++opIdx;
994  });
995  p << "}";
996  printSingleDeviceType(p, it.value());
997  });
998 }
999 
1001  mlir::OpAsmParser &parser,
1003  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1004  mlir::DenseI32ArrayAttr &segments) {
1007 
1008  do {
1009  if (failed(parser.parseLBrace()))
1010  return failure();
1011 
1012  int32_t crtOperandsSize = operands.size();
1013 
1014  if (failed(parser.parseCommaSeparatedList(
1016  if (parser.parseOperand(operands.emplace_back()) ||
1017  parser.parseColonType(types.emplace_back()))
1018  return failure();
1019  return success();
1020  })))
1021  return failure();
1022 
1023  seg.push_back(operands.size() - crtOperandsSize);
1024 
1025  if (failed(parser.parseRBrace()))
1026  return failure();
1027 
1028  if (succeeded(parser.parseOptionalLSquare())) {
1029  if (parser.parseAttribute(attributes.emplace_back()) ||
1030  parser.parseRSquare())
1031  return failure();
1032  } else {
1033  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1035  }
1036  } while (succeeded(parser.parseOptionalComma()));
1037 
1038  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1039  attributes.end());
1040  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1041  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1042 
1043  return success();
1044 }
1045 
1048  mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1049  std::optional<mlir::DenseI32ArrayAttr> segments) {
1050  unsigned opIdx = 0;
1051  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1052  p << "{";
1053  llvm::interleaveComma(
1054  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1055  p << operands[opIdx] << " : " << operands[opIdx].getType();
1056  ++opIdx;
1057  });
1058  p << "}";
1059  printSingleDeviceType(p, it.value());
1060  });
1061 }
1062 
1063 static ParseResult parseWaitClause(
1064  mlir::OpAsmParser &parser,
1066  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1067  mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
1068  mlir::ArrayAttr &keywordOnly) {
1069  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
1071 
1072  bool needCommaBeforeOperands = false;
1073 
1074  // Keyword only
1075  if (failed(parser.parseOptionalLParen())) {
1076  keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1078  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1079  return success();
1080  }
1081 
1082  // Parse keyword only attributes
1083  if (succeeded(parser.parseOptionalLSquare())) {
1084  if (failed(parser.parseCommaSeparatedList([&]() {
1085  if (parser.parseAttribute(keywordAttrs.emplace_back()))
1086  return failure();
1087  return success();
1088  })))
1089  return failure();
1090  if (parser.parseRSquare())
1091  return failure();
1092  needCommaBeforeOperands = true;
1093  }
1094 
1095  if (needCommaBeforeOperands && failed(parser.parseComma()))
1096  return failure();
1097 
1098  do {
1099  if (failed(parser.parseLBrace()))
1100  return failure();
1101 
1102  int32_t crtOperandsSize = operands.size();
1103 
1104  if (succeeded(parser.parseOptionalKeyword("devnum"))) {
1105  if (failed(parser.parseColon()))
1106  return failure();
1107  devnum.push_back(BoolAttr::get(parser.getContext(), true));
1108  } else {
1109  devnum.push_back(BoolAttr::get(parser.getContext(), false));
1110  }
1111 
1112  if (failed(parser.parseCommaSeparatedList(
1114  if (parser.parseOperand(operands.emplace_back()) ||
1115  parser.parseColonType(types.emplace_back()))
1116  return failure();
1117  return success();
1118  })))
1119  return failure();
1120 
1121  seg.push_back(operands.size() - crtOperandsSize);
1122 
1123  if (failed(parser.parseRBrace()))
1124  return failure();
1125 
1126  if (succeeded(parser.parseOptionalLSquare())) {
1127  if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
1128  parser.parseRSquare())
1129  return failure();
1130  } else {
1131  deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1133  }
1134  } while (succeeded(parser.parseOptionalComma()));
1135 
1136  if (failed(parser.parseRParen()))
1137  return failure();
1138 
1139  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
1140  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1141  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1142  hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
1143 
1144  return success();
1145 }
1146 
1147 static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
1148  if (!hasDeviceTypeValues(attrs))
1149  return false;
1150  if (attrs->size() != 1)
1151  return false;
1152  if (auto deviceTypeAttr =
1153  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1154  return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
1155  return false;
1156 }
1157 
1159  mlir::OperandRange operands, mlir::TypeRange types,
1160  std::optional<mlir::ArrayAttr> deviceTypes,
1161  std::optional<mlir::DenseI32ArrayAttr> segments,
1162  std::optional<mlir::ArrayAttr> hasDevNum,
1163  std::optional<mlir::ArrayAttr> keywordOnly) {
1164 
1165  if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
1166  return;
1167 
1168  p << "(";
1169 
1170  printDeviceTypes(p, keywordOnly);
1171  if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
1172  p << ", ";
1173 
1174  unsigned opIdx = 0;
1175  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1176  p << "{";
1177  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1178  if (boolAttr && boolAttr.getValue())
1179  p << "devnum: ";
1180  llvm::interleaveComma(
1181  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1182  p << operands[opIdx] << " : " << operands[opIdx].getType();
1183  ++opIdx;
1184  });
1185  p << "}";
1186  printSingleDeviceType(p, it.value());
1187  });
1188 
1189  p << ")";
1190 }
1191 
1192 static ParseResult parseDeviceTypeOperands(
1193  mlir::OpAsmParser &parser,
1195  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
1197  if (failed(parser.parseCommaSeparatedList([&]() {
1198  if (parser.parseOperand(operands.emplace_back()) ||
1199  parser.parseColonType(types.emplace_back()))
1200  return failure();
1201  if (succeeded(parser.parseOptionalLSquare())) {
1202  if (parser.parseAttribute(attributes.emplace_back()) ||
1203  parser.parseRSquare())
1204  return failure();
1205  } else {
1206  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1207  parser.getContext(), mlir::acc::DeviceType::None));
1208  }
1209  return success();
1210  })))
1211  return failure();
1212  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1213  attributes.end());
1214  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1215  return success();
1216 }
1217 
1218 static void
1220  mlir::OperandRange operands, mlir::TypeRange types,
1221  std::optional<mlir::ArrayAttr> deviceTypes) {
1222  if (!hasDeviceTypeValues(deviceTypes))
1223  return;
1224  llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
1225  p << std::get<1>(it) << " : " << std::get<1>(it).getType();
1226  printSingleDeviceType(p, std::get<0>(it));
1227  });
1228 }
1229 
1231  mlir::OpAsmParser &parser,
1233  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1234  mlir::ArrayAttr &keywordOnlyDeviceType) {
1235 
1236  llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
1237  bool needCommaBeforeOperands = false;
1238 
1239  if (failed(parser.parseOptionalLParen())) {
1240  // Keyword only
1241  keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1243  keywordOnlyDeviceType =
1244  ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
1245  return success();
1246  }
1247 
1248  // Parse keyword only attributes
1249  if (succeeded(parser.parseOptionalLSquare())) {
1250  // Parse keyword only attributes
1251  if (failed(parser.parseCommaSeparatedList([&]() {
1252  if (parser.parseAttribute(
1253  keywordOnlyDeviceTypeAttributes.emplace_back()))
1254  return failure();
1255  return success();
1256  })))
1257  return failure();
1258  if (parser.parseRSquare())
1259  return failure();
1260  needCommaBeforeOperands = true;
1261  }
1262 
1263  if (needCommaBeforeOperands && failed(parser.parseComma()))
1264  return failure();
1265 
1267  if (failed(parser.parseCommaSeparatedList([&]() {
1268  if (parser.parseOperand(operands.emplace_back()) ||
1269  parser.parseColonType(types.emplace_back()))
1270  return failure();
1271  if (succeeded(parser.parseOptionalLSquare())) {
1272  if (parser.parseAttribute(attributes.emplace_back()) ||
1273  parser.parseRSquare())
1274  return failure();
1275  } else {
1276  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1277  parser.getContext(), mlir::acc::DeviceType::None));
1278  }
1279  return success();
1280  })))
1281  return failure();
1282 
1283  if (failed(parser.parseRParen()))
1284  return failure();
1285 
1286  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1287  attributes.end());
1288  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1289  return success();
1290 }
1291 
1294  mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1295  std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1296 
1297  if (operands.begin() == operands.end() &&
1298  hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
1299  return;
1300  }
1301 
1302  p << "(";
1303  printDeviceTypes(p, keywordOnlyDeviceTypes);
1304  if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
1305  hasDeviceTypeValues(deviceTypes))
1306  p << ", ";
1307  printDeviceTypeOperands(p, op, operands, types, deviceTypes);
1308  p << ")";
1309 }
1310 
1311 static ParseResult
1313  mlir::acc::CombinedConstructsTypeAttr &attr) {
1314  if (succeeded(parser.parseOptionalKeyword("combined"))) {
1315  if (parser.parseLParen())
1316  return failure();
1317  if (succeeded(parser.parseOptionalKeyword("kernels"))) {
1319  parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1320  } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
1322  parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1323  } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
1325  parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1326  } else {
1327  parser.emitError(parser.getCurrentLocation(),
1328  "expected compute construct name");
1329  return failure();
1330  }
1331  if (parser.parseRParen())
1332  return failure();
1333  }
1334  return success();
1335 }
1336 
1337 static void
1339  mlir::acc::CombinedConstructsTypeAttr attr) {
1340  if (attr) {
1341  switch (attr.getValue()) {
1342  case mlir::acc::CombinedConstructsType::KernelsLoop:
1343  p << "combined(kernels)";
1344  break;
1345  case mlir::acc::CombinedConstructsType::ParallelLoop:
1346  p << "combined(parallel)";
1347  break;
1348  case mlir::acc::CombinedConstructsType::SerialLoop:
1349  p << "combined(serial)";
1350  break;
1351  };
1352  }
1353 }
1354 
1355 //===----------------------------------------------------------------------===//
1356 // SerialOp
1357 //===----------------------------------------------------------------------===//
1358 
1359 unsigned SerialOp::getNumDataOperands() {
1360  return getReductionOperands().size() + getGangPrivateOperands().size() +
1361  getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
1362 }
1363 
1364 Value SerialOp::getDataOperand(unsigned i) {
1365  unsigned numOptional = getAsyncOperands().size();
1366  numOptional += getIfCond() ? 1 : 0;
1367  numOptional += getSelfCond() ? 1 : 0;
1368  return getOperand(getWaitOperands().size() + numOptional + i);
1369 }
1370 
1371 bool acc::SerialOp::hasAsyncOnly() {
1372  return hasAsyncOnly(mlir::acc::DeviceType::None);
1373 }
1374 
1375 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1376  return hasDeviceType(getAsyncOnly(), deviceType);
1377 }
1378 
1379 mlir::Value acc::SerialOp::getAsyncValue() {
1380  return getAsyncValue(mlir::acc::DeviceType::None);
1381 }
1382 
1383 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1385  getAsyncOperands(), deviceType);
1386 }
1387 
1388 bool acc::SerialOp::hasWaitOnly() {
1389  return hasWaitOnly(mlir::acc::DeviceType::None);
1390 }
1391 
1392 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1393  return hasDeviceType(getWaitOnly(), deviceType);
1394 }
1395 
1396 mlir::Operation::operand_range SerialOp::getWaitValues() {
1397  return getWaitValues(mlir::acc::DeviceType::None);
1398 }
1399 
1401 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1403  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1404  getHasWaitDevnum(), deviceType);
1405 }
1406 
1407 mlir::Value SerialOp::getWaitDevnum() {
1408  return getWaitDevnum(mlir::acc::DeviceType::None);
1409 }
1410 
1411 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1412  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1413  getWaitOperandsSegments(), getHasWaitDevnum(),
1414  deviceType);
1415 }
1416 
1417 LogicalResult acc::SerialOp::verify() {
1418  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1419  *this, getPrivatizations(), getGangPrivateOperands(), "private",
1420  "privatizations", /*checkOperandType=*/false)))
1421  return failure();
1422  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1423  *this, getReductionRecipes(), getReductionOperands(), "reduction",
1424  "reductions", false)))
1425  return failure();
1426 
1428  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1429  getWaitOperandsDeviceTypeAttr(), "wait")))
1430  return failure();
1431 
1432  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1433  getAsyncOperandsDeviceTypeAttr(),
1434  "async")))
1435  return failure();
1436 
1437  if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
1438  return failure();
1439 
1440  return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
1441 }
1442 
1443 //===----------------------------------------------------------------------===//
1444 // KernelsOp
1445 //===----------------------------------------------------------------------===//
1446 
1447 unsigned KernelsOp::getNumDataOperands() {
1448  return getDataClauseOperands().size();
1449 }
1450 
1451 Value KernelsOp::getDataOperand(unsigned i) {
1452  unsigned numOptional = getAsyncOperands().size();
1453  numOptional += getWaitOperands().size();
1454  numOptional += getNumGangs().size();
1455  numOptional += getNumWorkers().size();
1456  numOptional += getVectorLength().size();
1457  numOptional += getIfCond() ? 1 : 0;
1458  numOptional += getSelfCond() ? 1 : 0;
1459  return getOperand(numOptional + i);
1460 }
1461 
1462 bool acc::KernelsOp::hasAsyncOnly() {
1463  return hasAsyncOnly(mlir::acc::DeviceType::None);
1464 }
1465 
1466 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1467  return hasDeviceType(getAsyncOnly(), deviceType);
1468 }
1469 
1470 mlir::Value acc::KernelsOp::getAsyncValue() {
1471  return getAsyncValue(mlir::acc::DeviceType::None);
1472 }
1473 
1474 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1476  getAsyncOperands(), deviceType);
1477 }
1478 
1479 mlir::Value acc::KernelsOp::getNumWorkersValue() {
1480  return getNumWorkersValue(mlir::acc::DeviceType::None);
1481 }
1482 
1484 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1485  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1486  deviceType);
1487 }
1488 
1489 mlir::Value acc::KernelsOp::getVectorLengthValue() {
1490  return getVectorLengthValue(mlir::acc::DeviceType::None);
1491 }
1492 
1494 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1495  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1496  getVectorLength(), deviceType);
1497 }
1498 
1499 mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
1500  return getNumGangsValues(mlir::acc::DeviceType::None);
1501 }
1502 
1504 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1505  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1506  getNumGangsSegments(), deviceType);
1507 }
1508 
1509 bool acc::KernelsOp::hasWaitOnly() {
1510  return hasWaitOnly(mlir::acc::DeviceType::None);
1511 }
1512 
1513 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1514  return hasDeviceType(getWaitOnly(), deviceType);
1515 }
1516 
1517 mlir::Operation::operand_range KernelsOp::getWaitValues() {
1518  return getWaitValues(mlir::acc::DeviceType::None);
1519 }
1520 
1522 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1524  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1525  getHasWaitDevnum(), deviceType);
1526 }
1527 
1528 mlir::Value KernelsOp::getWaitDevnum() {
1529  return getWaitDevnum(mlir::acc::DeviceType::None);
1530 }
1531 
1532 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1533  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1534  getWaitOperandsSegments(), getHasWaitDevnum(),
1535  deviceType);
1536 }
1537 
1538 LogicalResult acc::KernelsOp::verify() {
1540  *this, getNumGangs(), getNumGangsSegmentsAttr(),
1541  getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1542  return failure();
1543 
1545  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1546  getWaitOperandsDeviceTypeAttr(), "wait")))
1547  return failure();
1548 
1549  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1550  getNumWorkersDeviceTypeAttr(),
1551  "num_workers")))
1552  return failure();
1553 
1554  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1555  getVectorLengthDeviceTypeAttr(),
1556  "vector_length")))
1557  return failure();
1558 
1559  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1560  getAsyncOperandsDeviceTypeAttr(),
1561  "async")))
1562  return failure();
1563 
1564  if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
1565  return failure();
1566 
1567  return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
1568 }
1569 
1570 //===----------------------------------------------------------------------===//
1571 // HostDataOp
1572 //===----------------------------------------------------------------------===//
1573 
1574 LogicalResult acc::HostDataOp::verify() {
1575  if (getDataClauseOperands().empty())
1576  return emitError("at least one operand must appear on the host_data "
1577  "operation");
1578 
1579  for (mlir::Value operand : getDataClauseOperands())
1580  if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
1581  return emitError("expect data entry operation as defining op");
1582  return success();
1583 }
1584 
1585 void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
1586  MLIRContext *context) {
1587  results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
1588 }
1589 
1590 //===----------------------------------------------------------------------===//
1591 // LoopOp
1592 //===----------------------------------------------------------------------===//
1593 
1594 static ParseResult parseGangValue(
1595  OpAsmParser &parser, llvm::StringRef keyword,
1598  llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
1599  bool &needCommaBetweenValues, bool &newValue) {
1600  if (succeeded(parser.parseOptionalKeyword(keyword))) {
1601  if (parser.parseEqual())
1602  return failure();
1603  if (parser.parseOperand(operands.emplace_back()) ||
1604  parser.parseColonType(types.emplace_back()))
1605  return failure();
1606  attributes.push_back(gangArgType);
1607  needCommaBetweenValues = true;
1608  newValue = true;
1609  }
1610  return success();
1611 }
1612 
1613 static ParseResult parseGangClause(
1614  OpAsmParser &parser,
1616  llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
1617  mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
1618  mlir::ArrayAttr &gangOnlyDeviceType) {
1619  llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
1620  llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
1621  llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
1623  bool needCommaBetweenValues = false;
1624  bool needCommaBeforeOperands = false;
1625 
1626  if (failed(parser.parseOptionalLParen())) {
1627  // Gang only keyword
1628  gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1630  gangOnlyDeviceType =
1631  ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
1632  return success();
1633  }
1634 
1635  // Parse gang only attributes
1636  if (succeeded(parser.parseOptionalLSquare())) {
1637  // Parse gang only attributes
1638  if (failed(parser.parseCommaSeparatedList([&]() {
1639  if (parser.parseAttribute(
1640  gangOnlyDeviceTypeAttributes.emplace_back()))
1641  return failure();
1642  return success();
1643  })))
1644  return failure();
1645  if (parser.parseRSquare())
1646  return failure();
1647  needCommaBeforeOperands = true;
1648  }
1649 
1650  auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
1651  mlir::acc::GangArgType::Num);
1652  auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
1653  mlir::acc::GangArgType::Dim);
1654  auto argStatic = mlir::acc::GangArgTypeAttr::get(
1655  parser.getContext(), mlir::acc::GangArgType::Static);
1656 
1657  do {
1658  if (needCommaBeforeOperands) {
1659  needCommaBeforeOperands = false;
1660  continue;
1661  }
1662 
1663  if (failed(parser.parseLBrace()))
1664  return failure();
1665 
1666  int32_t crtOperandsSize = gangOperands.size();
1667  while (true) {
1668  bool newValue = false;
1669  bool needValue = false;
1670  if (needCommaBetweenValues) {
1671  if (succeeded(parser.parseOptionalComma()))
1672  needValue = true; // expect a new value after comma.
1673  else
1674  break;
1675  }
1676 
1677  if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
1678  gangOperands, gangOperandsType,
1679  gangArgTypeAttributes, argNum,
1680  needCommaBetweenValues, newValue)))
1681  return failure();
1682  if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
1683  gangOperands, gangOperandsType,
1684  gangArgTypeAttributes, argDim,
1685  needCommaBetweenValues, newValue)))
1686  return failure();
1687  if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
1688  gangOperands, gangOperandsType,
1689  gangArgTypeAttributes, argStatic,
1690  needCommaBetweenValues, newValue)))
1691  return failure();
1692 
1693  if (!newValue && needValue) {
1694  parser.emitError(parser.getCurrentLocation(),
1695  "new value expected after comma");
1696  return failure();
1697  }
1698 
1699  if (!newValue)
1700  break;
1701  }
1702 
1703  if (gangOperands.empty())
1704  return parser.emitError(
1705  parser.getCurrentLocation(),
1706  "expect at least one of num, dim or static values");
1707 
1708  if (failed(parser.parseRBrace()))
1709  return failure();
1710 
1711  if (succeeded(parser.parseOptionalLSquare())) {
1712  if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
1713  parser.parseRSquare())
1714  return failure();
1715  } else {
1716  deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1718  }
1719 
1720  seg.push_back(gangOperands.size() - crtOperandsSize);
1721 
1722  } while (succeeded(parser.parseOptionalComma()));
1723 
1724  if (failed(parser.parseRParen()))
1725  return failure();
1726 
1727  llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
1728  gangArgTypeAttributes.end());
1729  gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
1730  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
1731 
1733  gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
1734  gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
1735 
1736  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1737  return success();
1738 }
1739 
1741  mlir::OperandRange operands, mlir::TypeRange types,
1742  std::optional<mlir::ArrayAttr> gangArgTypes,
1743  std::optional<mlir::ArrayAttr> deviceTypes,
1744  std::optional<mlir::DenseI32ArrayAttr> segments,
1745  std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
1746 
1747  if (operands.begin() == operands.end() &&
1748  hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
1749  return;
1750  }
1751 
1752  p << "(";
1753 
1754  printDeviceTypes(p, gangOnlyDeviceTypes);
1755 
1756  if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
1757  hasDeviceTypeValues(deviceTypes))
1758  p << ", ";
1759 
1760  if (hasDeviceTypeValues(deviceTypes)) {
1761  unsigned opIdx = 0;
1762  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1763  p << "{";
1764  llvm::interleaveComma(
1765  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1766  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
1767  (*gangArgTypes)[opIdx]);
1768  if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
1769  p << LoopOp::getGangNumKeyword();
1770  else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
1771  p << LoopOp::getGangDimKeyword();
1772  else if (gangArgTypeAttr.getValue() ==
1773  mlir::acc::GangArgType::Static)
1774  p << LoopOp::getGangStaticKeyword();
1775  p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
1776  ++opIdx;
1777  });
1778  p << "}";
1779  printSingleDeviceType(p, it.value());
1780  });
1781  }
1782  p << ")";
1783 }
1784 
1786  std::optional<mlir::ArrayAttr> segments,
1787  llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
1788  if (!segments)
1789  return false;
1790  for (auto attr : *segments) {
1791  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1792  if (deviceTypes.contains(deviceTypeAttr.getValue()))
1793  return true;
1794  deviceTypes.insert(deviceTypeAttr.getValue());
1795  }
1796  return false;
1797 }
1798 
1799 /// Check for duplicates in the DeviceType array attribute.
1800 LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
1801  llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
1802  if (!deviceTypes)
1803  return success();
1804  for (auto attr : deviceTypes) {
1805  auto deviceTypeAttr =
1806  mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
1807  if (!deviceTypeAttr)
1808  return failure();
1809  if (crtDeviceTypes.contains(deviceTypeAttr.getValue()))
1810  return failure();
1811  crtDeviceTypes.insert(deviceTypeAttr.getValue());
1812  }
1813  return success();
1814 }
1815 
// Verifies acc.loop. It checks, in order:
//   - upperbound / inclusiveUpperbound sizes;
//   - per-clause device_type consistency for collapse, gang, worker, vector
//     and tile;
//   - mutual exclusion of auto / independent / seq;
//   - incompatibility of seq with gang / worker / vector;
//   - the private and reduction recipe symbol lists;
//   - the allowed combined-construct kinds;
//   - a non-empty region.
1816 LogicalResult acc::LoopOp::verify() {
1817  if (!getUpperbound().empty() && getInclusiveUpperbound() &&
1818  (getUpperbound().size() != getInclusiveUpperbound()->size()))
1819  return emitError() << "inclusiveUpperbound size is expected to be the same"
1820  << " as upperbound size";
1821 
1822  // Check collapse
1823  if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
1824  return emitOpError() << "collapse device_type attr must be define when"
1825  << " collapse attr is present";
1826 
1827  if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
1828  getCollapseAttr().getValue().size() !=
1829  getCollapseDeviceTypeAttr().getValue().size())
1830  return emitOpError() << "collapse attribute count must match collapse"
1831  << " device_type count";
1832  if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
1833  return emitOpError()
1834  << "duplicate device_type found in collapseDeviceType attribute";
1835 
1836  // Check gang
1837  if (!getGangOperands().empty()) {
1838  if (!getGangOperandsArgType())
1839  return emitOpError() << "gangOperandsArgType attribute must be defined"
1840  << " when gang operands are present";
1841 
1842  if (getGangOperands().size() !=
1843  getGangOperandsArgTypeAttr().getValue().size())
1844  return emitOpError() << "gangOperandsArgType attribute count must match"
1845  << " gangOperands count";
1846  }
  // NOTE(review): the getGangAttr() null check is redundant; checkDeviceTypes
  // already returns success for a null attribute.
1847  if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
1848  return emitOpError() << "duplicate device_type found in gang attribute";
1849 
1851  *this, getGangOperands(), getGangOperandsSegmentsAttr(),
1852  getGangOperandsDeviceTypeAttr(), "gang")))
1853  return failure();
1854 
1855  // Check worker
1856  if (failed(checkDeviceTypes(getWorkerAttr())))
1857  return emitOpError() << "duplicate device_type found in worker attribute";
1858  if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
1859  return emitOpError() << "duplicate device_type found in "
1860  "workerNumOperandsDeviceType attribute";
1861  if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
1862  getWorkerNumOperandsDeviceTypeAttr(),
1863  "worker")))
1864  return failure();
1865 
1866  // Check vector
1867  if (failed(checkDeviceTypes(getVectorAttr())))
1868  return emitOpError() << "duplicate device_type found in vector attribute";
1869  if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
1870  return emitOpError() << "duplicate device_type found in "
1871  "vectorOperandsDeviceType attribute";
1872  if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
1873  getVectorOperandsDeviceTypeAttr(),
1874  "vector")))
1875  return failure();
1876 
1878  *this, getTileOperands(), getTileOperandsSegmentsAttr(),
1879  getTileOperandsDeviceTypeAttr(), "tile")))
1880  return failure();
1881 
1882  // auto, independent and seq attribute are mutually exclusive.
  // One set is shared across all three attributes, so a given device type
  // may appear in at most one of auto / independent / seq.
1883  llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
1884  if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
1885  hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
1886  hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
1887  return emitError() << "only one of \"" << acc::LoopOp::getAutoAttrStrName()
1888  << "\", " << getIndependentAttrName() << ", "
1889  << getSeqAttrName()
1890  << " can be present at the same time";
1891  }
1892 
1893  // Gang, worker and vector are incompatible with seq.
1894  if (getSeqAttr()) {
1895  for (auto attr : getSeqAttr()) {
  // NOTE(review): dyn_cast result used without a null check; mlir::cast
  // would express the invariant more clearly.
1896  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1897  if (hasVector(deviceTypeAttr.getValue()) ||
1898  getVectorValue(deviceTypeAttr.getValue()) ||
1899  hasWorker(deviceTypeAttr.getValue()) ||
1900  getWorkerValue(deviceTypeAttr.getValue()) ||
1901  hasGang(deviceTypeAttr.getValue()) ||
1902  getGangValue(mlir::acc::GangArgType::Num,
1903  deviceTypeAttr.getValue()) ||
1904  getGangValue(mlir::acc::GangArgType::Dim,
1905  deviceTypeAttr.getValue()) ||
1906  getGangValue(mlir::acc::GangArgType::Static,
1907  deviceTypeAttr.getValue()))
1908  return emitError()
1909  << "gang, worker or vector cannot appear with the seq attr";
1910  }
1911  }
1912 
1913  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1914  *this, getPrivatizations(), getPrivateOperands(), "private",
1915  "privatizations", false)))
1916  return failure();
1917 
1918  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1919  *this, getReductionRecipes(), getReductionOperands(), "reduction",
1920  "reductions", false)))
1921  return failure();
1922 
  // Only the *Loop combined-construct kinds are meaningful on acc.loop.
1923  if (getCombined().has_value() &&
1924  (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
1925  getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
1926  getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
1927  return emitError("unexpected combined constructs attribute");
1928  }
1929 
1930  // Check non-empty body().
1931  if (getRegion().empty())
1932  return emitError("expected non-empty body.");
1933 
1934  return success();
1935 }
1936 
1937 unsigned LoopOp::getNumDataOperands() {
1938  return getReductionOperands().size() + getPrivateOperands().size();
1939 }
1940 
1941 Value LoopOp::getDataOperand(unsigned i) {
1942  unsigned numOptional =
1943  getLowerbound().size() + getUpperbound().size() + getStep().size();
1944  numOptional += getGangOperands().size();
1945  numOptional += getVectorOperands().size();
1946  numOptional += getWorkerNumOperands().size();
1947  numOptional += getTileOperands().size();
1948  numOptional += getCacheOperands().size();
1949  return getOperand(numOptional + i);
1950 }
1951 
1952 bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
1953 
1954 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
1955  return hasDeviceType(getAuto_(), deviceType);
1956 }
1957 
1958 bool LoopOp::hasIndependent() {
1959  return hasIndependent(mlir::acc::DeviceType::None);
1960 }
1961 
1962 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
1963  return hasDeviceType(getIndependent(), deviceType);
1964 }
1965 
1966 bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
1967 
1968 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
1969  return hasDeviceType(getSeq(), deviceType);
1970 }
1971 
1972 mlir::Value LoopOp::getVectorValue() {
1973  return getVectorValue(mlir::acc::DeviceType::None);
1974 }
1975 
1976 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
1977  return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
1978  getVectorOperands(), deviceType);
1979 }
1980 
1981 bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
1982 
1983 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
1984  return hasDeviceType(getVector(), deviceType);
1985 }
1986 
1987 mlir::Value LoopOp::getWorkerValue() {
1988  return getWorkerValue(mlir::acc::DeviceType::None);
1989 }
1990 
1991 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
1992  return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
1993  getWorkerNumOperands(), deviceType);
1994 }
1995 
1996 bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
1997 
1998 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
1999  return hasDeviceType(getWorker(), deviceType);
2000 }
2001 
2002 mlir::Operation::operand_range LoopOp::getTileValues() {
2003  return getTileValues(mlir::acc::DeviceType::None);
2004 }
2005 
2007 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
  // Returns the tile operands associated with `deviceType`, selected by
  // getValuesFromSegments from the flat tile operand list using the tile
  // device-type and segment-size attributes.
2008  return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
2009  getTileOperandsSegments(), deviceType);
2010 }
2011 
2012 std::optional<int64_t> LoopOp::getCollapseValue() {
2013  return getCollapseValue(mlir::acc::DeviceType::None);
2014 }
2015 
2016 std::optional<int64_t>
2017 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2018  if (!getCollapseAttr())
2019  return std::nullopt;
2020  if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2021  auto intAttr =
2022  mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2023  return intAttr.getValue().getZExtValue();
2024  }
2025  return std::nullopt;
2026 }
2027 
2028 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2029  return getGangValue(gangArgType, mlir::acc::DeviceType::None);
2030 }
2031 
// Returns the first gang operand of kind `gangArgType` (num, dim or static)
// inside the segment of gang operands that belongs to `deviceType`, or a null
// Value when there is no such operand.
2032 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2033  mlir::acc::DeviceType deviceType) {
2034  if (getGangOperands().empty())
2035  return {};
  // NOTE(review): getGangOperandsDeviceType() / getGangOperandsSegments() /
  // getGangOperandsArgType() are dereferenced unconditionally. This presumes
  // they are always set whenever gang operands exist; confirm the verifier
  // guarantees that.
2036  if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
  // Operands of earlier device-type segments come first in the flat list;
  // sum their sizes to find where this segment starts.
2037  int32_t nbOperandsBefore = 0;
2038  for (unsigned i = 0; i < *pos; ++i)
2039  nbOperandsBefore += (*getGangOperandsSegments())[i];
2041  getGangOperands()
2042  .drop_front(nbOperandsBefore)
2043  .take_front((*getGangOperandsSegments())[*pos]);
2044 
  // The arg-type array runs parallel to the flat gang operand list, so its
  // index starts at the same offset as the segment.
2045  int32_t argTypeIdx = nbOperandsBefore;
2046  for (auto value : values) {
  // NOTE(review): dyn_cast result used without a null check; mlir::cast
  // would state the invariant.
2047  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2048  (*getGangOperandsArgType())[argTypeIdx]);
2049  if (gangArgTypeAttr.getValue() == gangArgType)
2050  return value;
2051  ++argTypeIdx;
2052  }
2053  }
2054  return {};
2055 }
2056 
2057 bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2058 
2059 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2060  return hasDeviceType(getGang(), deviceType);
2061 }
2062 
2063 llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
2064  return {&getRegion()};
2065 }
2066 
2067 /// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
2068 /// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
2069 /// `(` ssa-id-and-type-list `)`
2070 /// region
/// The `control(...)` prefix is optional. Without it, the region is parsed
/// with no induction variables. With it, the lower bound, upper bound and
/// step lists are each required to contain exactly as many operands as
/// there are induction variables.
2071 ParseResult
2074  SmallVectorImpl<Type> &lowerboundType,
2076  SmallVectorImpl<Type> &upperboundType,
2078  SmallVectorImpl<Type> &stepType) {
2079 
2080  SmallVector<OpAsmParser::Argument> inductionVars;
2081  if (succeeded(
2082  parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
2083  if (parser.parseLParen() ||
2084  parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
2085  /*allowType=*/true) ||
2086  parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
2087  parser.parseOperandList(lowerbound, inductionVars.size(),
2089  parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
2090  parser.parseKeyword("to") || parser.parseLParen() ||
2091  parser.parseOperandList(upperbound, inductionVars.size(),
2093  parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
2094  parser.parseKeyword("step") || parser.parseLParen() ||
2095  parser.parseOperandList(step, inductionVars.size(),
2097  parser.parseColonTypeList(stepType) || parser.parseRParen())
2098  return failure();
2099  }
  // The parsed induction variables become the region's entry block arguments.
2100  return parser.parseRegion(region, inductionVars);
2101 }
2102 
2104  ValueRange lowerbound, TypeRange lowerboundType,
2105  ValueRange upperbound, TypeRange upperboundType,
2106  ValueRange steps, TypeRange stepType) {
  // Prints the optional `control(...) = (...) to (...) step (...)` prefix,
  // but only when the region's entry block has arguments (the induction
  // variables). It then prints the region without re-printing those
  // arguments.
2107  ValueRange regionArgs = region.front().getArguments();
2108  if (!regionArgs.empty()) {
2109  p << acc::LoopOp::getControlKeyword() << "(";
2110  llvm::interleaveComma(regionArgs, p,
2111  [&p](Value v) { p << v << " : " << v.getType(); });
  // NOTE(review): `") "` followed by `" step ("` emits two spaces before
  // `step`. This is harmless for parsing but inconsistent with `to`.
2112  p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
2113  << upperbound << " : " << upperboundType << ") " << " step (" << steps
2114  << " : " << stepType << ") ";
2115  }
2116  p.printRegion(region, /*printEntryBlockArgs=*/false);
2117 }
2118 
2119 //===----------------------------------------------------------------------===//
2120 // DataOp
2121 //===----------------------------------------------------------------------===//
2122 
2123 LogicalResult acc::DataOp::verify() {
2124  // 2.6.5. Data Construct restriction
2125  // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
2126  // attach, or default clause must appear on a data construct.
2127  if (getOperands().empty() && !getDefaultAttr())
2128  return emitError("at least one operand or the default attribute "
2129  "must appear on the data operation");
2130 
2131  for (mlir::Value operand : getDataClauseOperands())
2132  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2133  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2134  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2135  operand.getDefiningOp()))
2136  return emitError("expect data entry/exit operation or acc.getdeviceptr "
2137  "as defining op");
2138 
2139  if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this)))
2140  return failure();
2141 
2142  return success();
2143 }
2144 
2145 unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
2146 
2147 Value DataOp::getDataOperand(unsigned i) {
2148  unsigned numOptional = getIfCond() ? 1 : 0;
2149  numOptional += getAsyncOperands().size() ? 1 : 0;
2150  numOptional += getWaitOperands().size();
2151  return getOperand(numOptional + i);
2152 }
2153 
2154 bool acc::DataOp::hasAsyncOnly() {
2155  return hasAsyncOnly(mlir::acc::DeviceType::None);
2156 }
2157 
2158 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2159  return hasDeviceType(getAsyncOnly(), deviceType);
2160 }
2161 
2162 mlir::Value DataOp::getAsyncValue() {
2163  return getAsyncValue(mlir::acc::DeviceType::None);
2164 }
2165 
// Returns the async operand associated with `deviceType` (looked up within
// the async operands by device type), or a null Value if there is none.
2166 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2168  getAsyncOperands(), deviceType);
2169 }
2170 
2171 bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
2172 
2173 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2174  return hasDeviceType(getWaitOnly(), deviceType);
2175 }
2176 
2177 mlir::Operation::operand_range DataOp::getWaitValues() {
2178  return getWaitValues(mlir::acc::DeviceType::None);
2179 }
2180 
2182 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  // Returns the wait operands recorded for `deviceType`. The getHasWaitDevnum()
  // flags are passed along so the helper can tell which segments start with a
  // devnum operand.
2184  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2185  getHasWaitDevnum(), deviceType);
2186 }
2187 
2188 mlir::Value DataOp::getWaitDevnum() {
2189  return getWaitDevnum(mlir::acc::DeviceType::None);
2190 }
2191 
2192 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2193  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2194  getWaitOperandsSegments(), getHasWaitDevnum(),
2195  deviceType);
2196 }
2197 
2198 //===----------------------------------------------------------------------===//
2199 // ExitDataOp
2200 //===----------------------------------------------------------------------===//
2201 
2202 LogicalResult acc::ExitDataOp::verify() {
2203  // 2.6.6. Data Exit Directive restriction
2204  // At least one copyout, delete, or detach clause must appear on an exit data
2205  // directive.
2206  if (getDataClauseOperands().empty())
2207  return emitError("at least one operand must be present in dataOperands on "
2208  "the exit data operation");
2209 
2210  // The async attribute represent the async clause without value. Therefore the
2211  // attribute and operand cannot appear at the same time.
2212  if (getAsyncOperand() && getAsync())
2213  return emitError("async attribute cannot appear with asyncOperand");
2214 
2215  // The wait attribute represent the wait clause without values. Therefore the
2216  // attribute and operands cannot appear at the same time.
2217  if (!getWaitOperands().empty() && getWait())
2218  return emitError("wait attribute cannot appear with waitOperands");
2219 
2220  if (getWaitDevnum() && getWaitOperands().empty())
2221  return emitError("wait_devnum cannot appear without waitOperands");
2222 
2223  return success();
2224 }
2225 
2226 unsigned ExitDataOp::getNumDataOperands() {
2227  return getDataClauseOperands().size();
2228 }
2229 
2230 Value ExitDataOp::getDataOperand(unsigned i) {
2231  unsigned numOptional = getIfCond() ? 1 : 0;
2232  numOptional += getAsyncOperand() ? 1 : 0;
2233  numOptional += getWaitDevnum() ? 1 : 0;
2234  return getOperand(getWaitOperands().size() + numOptional + i);
2235 }
2236 
2237 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2238  MLIRContext *context) {
2239  results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
2240 }
2241 
2242 //===----------------------------------------------------------------------===//
2243 // EnterDataOp
2244 //===----------------------------------------------------------------------===//
2245 
2246 LogicalResult acc::EnterDataOp::verify() {
2247  // 2.6.6. Data Enter Directive restriction
2248  // At least one copyin, create, or attach clause must appear on an enter data
2249  // directive.
2250  if (getDataClauseOperands().empty())
2251  return emitError("at least one operand must be present in dataOperands on "
2252  "the enter data operation");
2253 
2254  // The async attribute represent the async clause without value. Therefore the
2255  // attribute and operand cannot appear at the same time.
2256  if (getAsyncOperand() && getAsync())
2257  return emitError("async attribute cannot appear with asyncOperand");
2258 
2259  // The wait attribute represent the wait clause without values. Therefore the
2260  // attribute and operands cannot appear at the same time.
2261  if (!getWaitOperands().empty() && getWait())
2262  return emitError("wait attribute cannot appear with waitOperands");
2263 
2264  if (getWaitDevnum() && getWaitOperands().empty())
2265  return emitError("wait_devnum cannot appear without waitOperands");
2266 
2267  for (mlir::Value operand : getDataClauseOperands())
2268  if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
2269  operand.getDefiningOp()))
2270  return emitError("expect data entry operation as defining op");
2271 
2272  return success();
2273 }
2274 
2275 unsigned EnterDataOp::getNumDataOperands() {
2276  return getDataClauseOperands().size();
2277 }
2278 
2279 Value EnterDataOp::getDataOperand(unsigned i) {
2280  unsigned numOptional = getIfCond() ? 1 : 0;
2281  numOptional += getAsyncOperand() ? 1 : 0;
2282  numOptional += getWaitDevnum() ? 1 : 0;
2283  return getOperand(getWaitOperands().size() + numOptional + i);
2284 }
2285 
2286 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2287  MLIRContext *context) {
2288  results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
2289 }
2290 
2291 //===----------------------------------------------------------------------===//
2292 // AtomicReadOp
2293 //===----------------------------------------------------------------------===//
2294 
2295 LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
2296 
2297 //===----------------------------------------------------------------------===//
2298 // AtomicWriteOp
2299 //===----------------------------------------------------------------------===//
2300 
2301 LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
2302 
2303 //===----------------------------------------------------------------------===//
2304 // AtomicUpdateOp
2305 //===----------------------------------------------------------------------===//
2306 
2307 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2308  PatternRewriter &rewriter) {
2309  if (op.isNoOp()) {
2310  rewriter.eraseOp(op);
2311  return success();
2312  }
2313 
2314  if (Value writeVal = op.getWriteOpVal()) {
2315  rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
2316  return success();
2317  }
2318 
2319  return failure();
2320 }
2321 
2322 LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
2323 
2324 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
2325 
2326 //===----------------------------------------------------------------------===//
2327 // AtomicCaptureOp
2328 //===----------------------------------------------------------------------===//
2329 
2330 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2331  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2332  return op;
2333  return dyn_cast<AtomicReadOp>(getSecondOp());
2334 }
2335 
2336 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2337  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2338  return op;
2339  return dyn_cast<AtomicWriteOp>(getSecondOp());
2340 }
2341 
2342 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2343  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2344  return op;
2345  return dyn_cast<AtomicUpdateOp>(getSecondOp());
2346 }
2347 
2348 LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
2349 
2350 //===----------------------------------------------------------------------===//
2351 // DeclareEnterOp
2352 //===----------------------------------------------------------------------===//
2353 
// Shared verifier for the declare ops. It checks that:
//   - unless `requireAtLeastOneOperand` is false, at least one operand is
//     present;
//   - every operand is produced by one of the allowed data entry ops;
//   - when the operand's varPtr has a defining op, that op carries an acc
//     declare attribute whose data clause matches the operand's data clause;
//   - if that attribute is marked implicit, the operand's implicit flag also
//     matches.
2354 template <typename Op>
2355 static LogicalResult
2357  bool requireAtLeastOneOperand = true) {
2358  if (operands.empty() && requireAtLeastOneOperand)
2359  return emitError(
2360  op->getLoc(),
2361  "at least one operand must appear on the declare operation");
2362 
2363  for (mlir::Value operand : operands) {
  // NOTE(review): operand.getDefiningOp() is null for a block argument, and
  // mlir::isa asserts on a null pointer. mlir::isa_and_nonnull would turn
  // that case into the diagnostic below.
2364  if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2365  acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
2366  acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
2367  operand.getDefiningOp()))
2368  return op.emitError(
2369  "expect valid declare data entry operation or acc.getdeviceptr "
2370  "as defining op");
2371 
2372  mlir::Value varPtr{getVarPtr(operand.getDefiningOp())};
2373  assert(varPtr && "declare operands can only be data entry operations which "
2374  "must have varPtr");
2375  std::optional<mlir::acc::DataClause> dataClauseOptional{
2376  getDataClause(operand.getDefiningOp())};
2377  assert(dataClauseOptional.has_value() &&
2378  "declare operands can only be data entry operations which must have "
2379  "dataClause");
2380 
2381  // If varPtr has no defining op - there is nothing to check further.
2382  if (!varPtr.getDefiningOp())
2383  continue;
2384 
2385  // Check that the varPtr has a declare attribute.
2386  auto declareAttribute{
2387  varPtr.getDefiningOp()->getAttr(mlir::acc::getDeclareAttrName())};
2388  if (!declareAttribute)
2389  return op.emitError(
2390  "expect declare attribute on variable in declare operation");
2391 
2392  auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute);
2393  if (declAttr.getDataClause().getValue() != dataClauseOptional.value())
2394  return op.emitError(
2395  "expect matching declare attribute on variable in declare operation");
2396 
2397  // If the variable is marked with implicit attribute, the matching declare
2398  // data action must also be marked implicit. The reverse is not checked
2399  // since implicit data action may be inserted to do actions like updating
2400  // device copy, in which case the variable is not necessarily implicitly
2401  // declare'd.
2402  if (declAttr.getImplicit() &&
2403  declAttr.getImplicit() != acc::getImplicitFlag(operand.getDefiningOp()))
2404  return op.emitError(
2405  "implicitness must match between declare op and flag on variable");
2406  }
2407 
2408  return success();
2409 }
2410 
2411 LogicalResult acc::DeclareEnterOp::verify() {
2412  return checkDeclareOperands(*this, this->getDataClauseOperands());
2413 }
2414 
2415 //===----------------------------------------------------------------------===//
2416 // DeclareExitOp
2417 //===----------------------------------------------------------------------===//
2418 
2419 LogicalResult acc::DeclareExitOp::verify() {
2420  if (getToken())
2421  return checkDeclareOperands(*this, this->getDataClauseOperands(),
2422  /*requireAtLeastOneOperand=*/false);
2423  return checkDeclareOperands(*this, this->getDataClauseOperands());
2424 }
2425 
2426 //===----------------------------------------------------------------------===//
2427 // DeclareOp
2428 //===----------------------------------------------------------------------===//
2429 
2430 LogicalResult acc::DeclareOp::verify() {
2431  return checkDeclareOperands(*this, this->getDataClauseOperands());
2432 }
2433 
2434 //===----------------------------------------------------------------------===//
2435 // RoutineOp
2436 //===----------------------------------------------------------------------===//
2437 
2438 static unsigned getParallelismForDeviceType(acc::RoutineOp op,
2439  acc::DeviceType dtype) {
2440  unsigned parallelism = 0;
2441  parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
2442  parallelism += op.hasWorker(dtype) ? 1 : 0;
2443  parallelism += op.hasVector(dtype) ? 1 : 0;
2444  parallelism += op.hasSeq(dtype) ? 1 : 0;
2445  return parallelism;
2446 }
2447 
// Verifies acc.routine.
//   - At most one of gang / worker / vector / seq may apply to the default
//     (None) device type.
//   - For each specific device type, at most one may apply.
//   - If a level is already given for the default device type, a
//     device-specific level is also rejected.
2448 LogicalResult acc::RoutineOp::verify() {
2449  unsigned baseParallelism =
2451 
2452  if (baseParallelism > 1)
2453  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2454  "be present at the same time";
2455 
  // NOTE(review): if getMaxEnumValForDeviceType() returns the largest
  // enumerator value (inclusive), the `!=` bound never visits the last device
  // type, and `<=` would be needed. Confirm against the generated enum.
2456  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
2457  ++dtypeInt) {
2458  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
2459  if (dtype == acc::DeviceType::None)
2460  continue;
2461  unsigned parallelism = getParallelismForDeviceType(*this, dtype);
2462 
2463  if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
2464  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2465  "be present at the same time";
2466  }
2467 
2468  return success();
2469 }
2470 
2471 static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
2472  mlir::ArrayAttr &deviceTypes) {
2473  llvm::SmallVector<mlir::Attribute> bindNameAttrs;
2474  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
2475 
2476  if (failed(parser.parseCommaSeparatedList([&]() {
2477  if (parser.parseAttribute(bindNameAttrs.emplace_back()))
2478  return failure();
2479  if (failed(parser.parseOptionalLSquare())) {
2480  deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2481  parser.getContext(), mlir::acc::DeviceType::None));
2482  } else {
2483  if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2484  parser.parseRSquare())
2485  return failure();
2486  }
2487  return success();
2488  })))
2489  return failure();
2490 
2491  bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
2492  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2493 
2494  return success();
2495 }
2496 
2498  std::optional<mlir::ArrayAttr> bindName,
2499  std::optional<mlir::ArrayAttr> deviceTypes) {
  // Prints each bind name followed by its device type (as printed by
  // printSingleDeviceType), comma-separated; this mirrors parseBindName.
  // NOTE(review): both optionals are dereferenced unconditionally, which
  // presumes the printer is only invoked when the attributes are present.
2500  llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
2501  [&](const auto &pair) {
2502  p << std::get<0>(pair);
2503  printSingleDeviceType(p, std::get<1>(pair));
2504  });
2505 }
2506 
// Parses the routine `gang` clause:
//   - a bare `gang` yields gang = [None];
//   - otherwise `(` [ `[` device-type-list `]` `,` ]
//     `dim` `:` attr [ `[` device-type `]` ] (`,` ...)* `)`.
// Dim entries without an explicit device type get DeviceType::None.
2507 static ParseResult parseRoutineGangClause(OpAsmParser &parser,
2508  mlir::ArrayAttr &gang,
2509  mlir::ArrayAttr &gangDim,
2510  mlir::ArrayAttr &gangDimDeviceTypes) {
2511 
2512  llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
2513  gangDimDeviceTypeAttrs;
2514  bool needCommaBeforeOperands = false;
2515 
2516  // Gang keyword only
2517  if (failed(parser.parseOptionalLParen())) {
2518  gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2520  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2521  return success();
2522  }
2523 
2524  // Parse keyword only attributes
2525  if (succeeded(parser.parseOptionalLSquare())) {
2526  if (failed(parser.parseCommaSeparatedList([&]() {
2527  if (parser.parseAttribute(gangAttrs.emplace_back()))
2528  return failure();
2529  return success();
2530  })))
2531  return failure();
2532  if (parser.parseRSquare())
2533  return failure();
2534  needCommaBeforeOperands = true;
2535  }
2536 
  // NOTE(review): after a `[...]` list a comma and at least one `dim` entry
  // are mandatory, so `([...])` without dims is rejected. The printer can
  // emit device types without dims; confirm that form round-trips.
2537  if (needCommaBeforeOperands && failed(parser.parseComma()))
2538  return failure();
2539 
2540  if (failed(parser.parseCommaSeparatedList([&]() {
2541  if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
2542  parser.parseColon() ||
2543  parser.parseAttribute(gangDimAttrs.emplace_back()))
2544  return failure();
2545  if (succeeded(parser.parseOptionalLSquare())) {
2546  if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
2547  parser.parseRSquare())
2548  return failure();
2549  } else {
2550  gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2551  parser.getContext(), mlir::acc::DeviceType::None));
2552  }
2553  return success();
2554  })))
2555  return failure();
2556 
2557  if (failed(parser.parseRParen()))
2558  return failure();
2559 
2560  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2561  gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
2562  gangDimDeviceTypes =
2563  ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
2564 
2565  return success();
2566 }
2567 
2569  std::optional<mlir::ArrayAttr> gang,
2570  std::optional<mlir::ArrayAttr> gangDim,
2571  std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
  // Prints the parenthesized part of the routine `gang` clause. Nothing is
  // printed for the bare-keyword form, i.e. when gang is exactly [None] and
  // there are no dim entries.
2572 
2573  if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
2574  gang->size() == 1) {
  // NOTE(review): dyn_cast result used without a null check.
2575  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
2576  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
2577  return;
2578  }
2579 
2580  p << "(";
2581 
2582  printDeviceTypes(p, gang);
2583 
2584  if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
2585  p << ", ";
2586 
  // Each dim entry is printed as `dim: <value>` plus its device type.
2587  if (hasDeviceTypeValues(gangDimDeviceTypes))
2588  llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
2589  [&](const auto &pair) {
2590  p << acc::RoutineOp::getGangDimKeyword() << ": ";
2591  p << std::get<0>(pair);
2592  printSingleDeviceType(p, std::get<1>(pair));
2593  });
2594 
2595  p << ")";
2596 }
2597 
// Parses an optional `([` device-type-list `])` suffix. If no `(` follows,
// the result is [None] (the keyword-only form).
2598 static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
2599  mlir::ArrayAttr &deviceTypes) {
2601  // Keyword only
2602  if (failed(parser.parseOptionalLParen())) {
2603  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2605  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
2606  return success();
2607  }
2608 
  // NOTE(review): if `(` is present but `[` is not, the closing `)` is never
  // consumed and an empty array is returned. That leaves the paren for the
  // caller and will likely surface as a confusing parse error.
2609  // Parse device type attributes
2610  if (succeeded(parser.parseOptionalLSquare())) {
2611  if (failed(parser.parseCommaSeparatedList([&]() {
2612  if (parser.parseAttribute(attributes.emplace_back()))
2613  return failure();
2614  return success();
2615  })))
2616  return failure();
2617  if (parser.parseRSquare() || parser.parseRParen())
2618  return failure();
2619  }
2620  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
2621  return success();
2622 }
2623 
// Prints `([dt, ...])` for a device-type array. It prints nothing when the
// array is absent or empty (per hasDeviceTypeValues), or when it is exactly
// [None], the keyword-only form accepted by parseDeviceTypeArrayAttr.
2624 static void
2626  std::optional<mlir::ArrayAttr> deviceTypes) {
2627 
2628  if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
  // NOTE(review): dyn_cast result used without a null check.
2629  auto deviceTypeAttr =
2630  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
2631  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
2632  return;
2633  }
2634 
2635  if (!hasDeviceTypeValues(deviceTypes))
2636  return;
2637 
2638  p << "([";
2639  llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
2640  auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2641  p << dTypeAttr;
2642  });
2643  p << "])";
2644 }
2645 
2646 bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2647 
2648 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
2649  return hasDeviceType(getWorker(), deviceType);
2650 }
2651 
2652 bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
2653 
2654 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
2655  return hasDeviceType(getVector(), deviceType);
2656 }
2657 
2658 bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
2659 
2660 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
2661  return hasDeviceType(getSeq(), deviceType);
2662 }
2663 
2664 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
2665  return getBindNameValue(mlir::acc::DeviceType::None);
2666 }
2667 
2668 std::optional<llvm::StringRef>
2669 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
2670  if (!hasDeviceTypeValues(getBindNameDeviceType()))
2671  return std::nullopt;
2672  if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
2673  auto attr = (*getBindName())[*pos];
2674  auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
2675  return stringAttr.getValue();
2676  }
2677  return std::nullopt;
2678 }
2679 
2680 bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2681 
2682 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
2683  return hasDeviceType(getGang(), deviceType);
2684 }
2685 
2686 std::optional<int64_t> RoutineOp::getGangDimValue() {
2687  return getGangDimValue(mlir::acc::DeviceType::None);
2688 }
2689 
2690 std::optional<int64_t>
2691 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
2692  if (!hasDeviceTypeValues(getGangDimDeviceType()))
2693  return std::nullopt;
2694  if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
2695  auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
2696  return intAttr.getInt();
2697  }
2698  return std::nullopt;
2699 }
2700 
2701 //===----------------------------------------------------------------------===//
2702 // InitOp
2703 //===----------------------------------------------------------------------===//
2704 
2705 LogicalResult acc::InitOp::verify() {
2706  Operation *currOp = *this;
2707  while ((currOp = currOp->getParentOp()))
2708  if (isComputeOperation(currOp))
2709  return emitOpError("cannot be nested in a compute operation");
2710  return success();
2711 }
2712 
2713 //===----------------------------------------------------------------------===//
2714 // ShutdownOp
2715 //===----------------------------------------------------------------------===//
2716 
2717 LogicalResult acc::ShutdownOp::verify() {
2718  Operation *currOp = *this;
2719  while ((currOp = currOp->getParentOp()))
2720  if (isComputeOperation(currOp))
2721  return emitOpError("cannot be nested in a compute operation");
2722  return success();
2723 }
2724 
2725 //===----------------------------------------------------------------------===//
2726 // SetOp
2727 //===----------------------------------------------------------------------===//
2728 
2729 LogicalResult acc::SetOp::verify() {
2730  Operation *currOp = *this;
2731  while ((currOp = currOp->getParentOp()))
2732  if (isComputeOperation(currOp))
2733  return emitOpError("cannot be nested in a compute operation");
2734  if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
2735  return emitOpError("at least one default_async, device_num, or device_type "
2736  "operand must appear");
2737  return success();
2738 }
2739 
2740 //===----------------------------------------------------------------------===//
2741 // UpdateOp
2742 //===----------------------------------------------------------------------===//
2743 
2744 LogicalResult acc::UpdateOp::verify() {
2745  // At least one of host or device should have a value.
2746  if (getDataClauseOperands().empty())
2747  return emitError("at least one value must be present in dataOperands");
2748 
2749  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
2750  getAsyncOperandsDeviceTypeAttr(),
2751  "async")))
2752  return failure();
2753 
2755  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2756  getWaitOperandsDeviceTypeAttr(), "wait")))
2757  return failure();
2758 
2759  if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this)))
2760  return failure();
2761 
2762  for (mlir::Value operand : getDataClauseOperands())
2763  if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
2764  operand.getDefiningOp()))
2765  return emitError("expect data entry/exit operation or acc.getdeviceptr "
2766  "as defining op");
2767 
2768  return success();
2769 }
2770 
2771 unsigned UpdateOp::getNumDataOperands() {
2772  return getDataClauseOperands().size();
2773 }
2774 
2775 Value UpdateOp::getDataOperand(unsigned i) {
2776  unsigned numOptional = getAsyncOperands().size();
2777  numOptional += getIfCond() ? 1 : 0;
2778  return getOperand(getWaitOperands().size() + numOptional + i);
2779 }
2780 
2781 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
2782  MLIRContext *context) {
2783  results.add<RemoveConstantIfCondition<UpdateOp>>(context);
2784 }
2785 
2786 bool UpdateOp::hasAsyncOnly() {
2787  return hasAsyncOnly(mlir::acc::DeviceType::None);
2788 }
2789 
2790 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2791  return hasDeviceType(getAsync(), deviceType);
2792 }
2793 
2794 mlir::Value UpdateOp::getAsyncValue() {
2795  return getAsyncValue(mlir::acc::DeviceType::None);
2796 }
2797 
2798 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2800  return {};
2801 
2802  if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
2803  return getAsyncOperands()[*pos];
2804 
2805  return {};
2806 }
2807 
2808 bool UpdateOp::hasWaitOnly() {
2809  return hasWaitOnly(mlir::acc::DeviceType::None);
2810 }
2811 
2812 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2813  return hasDeviceType(getWaitOnly(), deviceType);
2814 }
2815 
2816 mlir::Operation::operand_range UpdateOp::getWaitValues() {
2817  return getWaitValues(mlir::acc::DeviceType::None);
2818 }
2819 
2821 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2823  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2824  getHasWaitDevnum(), deviceType);
2825 }
2826 
2827 mlir::Value UpdateOp::getWaitDevnum() {
2828  return getWaitDevnum(mlir::acc::DeviceType::None);
2829 }
2830 
2831 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2832  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2833  getWaitOperandsSegments(), getHasWaitDevnum(),
2834  deviceType);
2835 }
2836 
2837 //===----------------------------------------------------------------------===//
2838 // WaitOp
2839 //===----------------------------------------------------------------------===//
2840 
2841 LogicalResult acc::WaitOp::verify() {
2842  // The async attribute represent the async clause without value. Therefore the
2843  // attribute and operand cannot appear at the same time.
2844  if (getAsyncOperand() && getAsync())
2845  return emitError("async attribute cannot appear with asyncOperand");
2846 
2847  if (getWaitDevnum() && getWaitOperands().empty())
2848  return emitError("wait_devnum cannot appear without waitOperands");
2849 
2850  return success();
2851 }
2852 
2853 #define GET_OP_CLASSES
2854 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
2855 
2856 #define GET_ATTRDEF_CLASSES
2857 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
2858 
2859 #define GET_TYPEDEF_CLASSES
2860 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
2861 
2862 //===----------------------------------------------------------------------===//
2863 // acc dialect utilities
2864 //===----------------------------------------------------------------------===//
2865 
// mlir::acc::getVarPtr (signature line elided in this listing): data entry
// ops, acc.copyout and acc.update_host report their varPtr; any other op
// yields a null mlir::Value.
2867  auto varPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
2868  .Case<ACC_DATA_ENTRY_OPS>(
2869  [&](auto entry) { return entry.getVarPtr(); })
2870  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
2871  [&](auto exit) { return exit.getVarPtr(); })
2872  .Default([&](mlir::Operation *) { return mlir::Value(); })};
2873  return varPtr;
2874 }
2875 
// mlir::acc::getAccPtr: matched data clause ops return their accPtr; any other
// op yields a null mlir::Value.
// NOTE(review): the .Case<...> op list (line 2878) is missing from this
// listing; confirm which ops are matched.
2877  auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
2879  [&](auto dataClause) { return dataClause.getAccPtr(); })
2880  .Default([&](mlir::Operation *) { return mlir::Value(); })};
2881  return accPtr;
2882 }
2883 
// mlir::acc::getVarPtrPtr: only data entry ops expose a varPtrPtr; any other
// op yields a null mlir::Value.
2885  auto varPtrPtr{
2887  .Case<ACC_DATA_ENTRY_OPS>(
2888  [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
2889  .Default([&](mlir::Operation *) { return mlir::Value(); })};
2890  return varPtrPtr;
2891 }
2892 
// mlir::acc::getBounds: for data entry and exit ops, copy the op's bounds
// operands into the returned SmallVector<mlir::Value>. The Default branch body
// (line 2903) is elided in this listing; presumably it returns an empty
// vector. Confirm.
2897  accDataClauseOp)
2898  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
2900  dataClause.getBounds().begin(), dataClause.getBounds().end());
2901  })
2902  .Default([&](mlir::Operation *) {
2904  })};
2905  return bounds;
2906 }
2907 
// mlir::acc::getAsyncOperands: for data entry and exit ops, copy the op's
// async operands into the returned SmallVector<mlir::Value>. The Default
// branch body (line 2918) is elided in this listing; presumably it returns an
// empty vector. Confirm.
2911  accDataClauseOp)
2912  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
2914  dataClause.getAsyncOperands().begin(),
2915  dataClause.getAsyncOperands().end());
2916  })
2917  .Default([&](mlir::Operation *) {
2919  });
2920 }
2921 
// mlir::acc::getAsyncOperandsDeviceType: for data entry and exit ops, return
// the asyncOperandsDeviceType array attribute; any other op yields a null
// ArrayAttr.
2922 mlir::ArrayAttr
2925  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
2926  return dataClause.getAsyncOperandsDeviceTypeAttr();
2927  })
2928  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
2929 }
2930 
// mlir::acc::getAsyncOnly: matched data clause ops return their asyncOnly
// device-type array attribute; any other op yields a null ArrayAttr.
// NOTE(review): the TypeSwitch header and .Case<...> op list (lines 2932-2933)
// are missing from this listing.
2931 mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
2934  [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
2935  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
2936 }
2937 
// mlir::acc::getVarName: data entry ops report their name; any other op
// yields std::nullopt.
2938 std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
2939  auto name{
2941  .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
2942  .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
2943  return {};
2944  })};
2945  return name;
2946 }
2947 
// mlir::acc::getDataClause: data entry ops report their DataClause; any other
// op yields std::nullopt.
2948 std::optional<mlir::acc::DataClause>
2950  auto dataClause{
2952  accDataEntryOp)
2953  .Case<ACC_DATA_ENTRY_OPS>(
2954  [&](auto entry) { return entry.getDataClause(); })
2955  .Default([&](mlir::Operation *) { return std::nullopt; })};
2956  return dataClause;
2957 }
2958 
// mlir::acc::getImplicitFlag: data entry ops report their implicit flag; any
// other op yields false.
2960  auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
2961  .Case<ACC_DATA_ENTRY_OPS>(
2962  [&](auto entry) { return entry.getImplicit(); })
2963  .Default([&](mlir::Operation *) { return false; })};
2964  return implicit;
2965 }
2966 
// mlir::acc::getDataOperands: matched ops return their data clause operands;
// any other op yields an empty mlir::ValueRange.
// NOTE(review): the TypeSwitch header and .Case<...> op list (lines 2969-2970)
// are missing from this listing.
2968  auto dataOperands{
2971  [&](auto entry) { return entry.getDataClauseOperands(); })
2972  .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
2973  return dataOperands;
2974 }
2975 
// mlir::acc::getMutableDataOperands: matched ops return a mutable range over
// their data clause operands.
// NOTE(review): for any other op the Default branch returns nullptr. That
// presumably constructs a MutableOperandRange with a null owner; confirm
// callers never use the result for non-matching ops.
2978  auto dataOperands{
2981  [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
2982  .Default([&](mlir::Operation *) { return nullptr; })};
2983  return dataOperands;
2984 }
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:112
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2112
@ None
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition: OpenACC.cpp:2568
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition: OpenACC.cpp:445
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition: OpenACC.cpp:1785
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition: OpenACC.cpp:747
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition: OpenACC.cpp:1800
static bool isComputeOperation(Operation *op)
Definition: OpenACC.cpp:459
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition: OpenACC.cpp:1147
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:2471
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition: OpenACC.cpp:1158
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition: OpenACC.cpp:1063
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition: OpenACC.cpp:76
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:2625
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition: OpenACC.cpp:1594
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition: OpenACC.cpp:1312
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition: OpenACC.cpp:2356
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:96
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition: OpenACC.cpp:2072
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition: OpenACC.cpp:676
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:1192
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:822
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:120
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:933
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition: OpenACC.cpp:2103
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:2598
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition: OpenACC.cpp:2507
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1046
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:1219
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindName, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:2497
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:1000
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
Definition: OpenACC.cpp:660
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:152
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition: OpenACC.cpp:1613
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition: OpenACC.cpp:535
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition: OpenACC.cpp:977
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:107
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
Definition: OpenACC.cpp:691
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition: OpenACC.cpp:1292
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:82
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition: OpenACC.cpp:1740
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:136
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition: OpenACC.cpp:1230
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition: OpenACC.cpp:172
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition: OpenACC.cpp:757
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition: OpenACC.cpp:2438
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:983
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition: OpenACC.cpp:1338
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
Definition: OpenACC.cpp:640
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition: OpenACC.h:66
#define ACC_DATA_ENTRY_OPS
Definition: OpenACC.h:42
#define ACC_DATA_EXIT_OPS
Definition: OpenACC.h:50
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:210
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
mlir::Value getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtr from a data clause operation.
Definition: OpenACC.cpp:2866
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition: OpenACC.cpp:2949
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition: OpenACC.cpp:2977
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition: OpenACC.cpp:2894
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition: OpenACC.cpp:2967
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition: OpenACC.cpp:2938
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition: OpenACC.cpp:2959
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition: OpenACC.cpp:2909
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition: OpenACC.cpp:2884
mlir::Value getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accPtr from a data clause operation.
Definition: OpenACC.cpp:2876
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:2931
static constexpr StringLiteral getDeclareAttrName()
Used to obtain the attribute name for declare.
Definition: OpenACC.h:139
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:2923
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.