MLIR  20.0.0git
OpenACC.cpp
Go to the documentation of this file.
1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
21 
22 using namespace mlir;
23 using namespace acc;
24 
25 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
26 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
27 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
28 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
29 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
30 
31 namespace {
32 struct MemRefPointerLikeModel
33  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
34  MemRefType> {
35  Type getElementType(Type pointer) const {
36  return llvm::cast<MemRefType>(pointer).getElementType();
37  }
38 };
39 
40 struct LLVMPointerPointerLikeModel
41  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
42  LLVM::LLVMPointerType> {
43  Type getElementType(Type pointer) const { return Type(); }
44 };
45 } // namespace
46 
47 //===----------------------------------------------------------------------===//
48 // OpenACC operations
49 //===----------------------------------------------------------------------===//
50 
/// Dialect initialization hook: registers the TableGen-generated OpenACC
/// operations, attributes and types, then attaches the
/// MemRefPointerLikeModel / LLVMPointerPointerLikeModel external models to
/// MemRefType and LLVM::LLVMPointerType.
void OpenACCDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
      >();

  // By attaching interfaces here, we make the OpenACC dialect dependent on
  // the other dialects. This is probably better than having dialects like LLVM
  // and memref be dependent on OpenACC.
  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
      *getContext());
}
72 
73 //===----------------------------------------------------------------------===//
74 // device_type support helpers
75 //===----------------------------------------------------------------------===//
76 
77 static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
78  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
79  return true;
80  return false;
81 }
82 
83 static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
84  mlir::acc::DeviceType deviceType) {
85  if (!hasDeviceTypeValues(arrayAttr))
86  return false;
87 
88  for (auto attr : *arrayAttr) {
89  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
90  if (deviceTypeAttr.getValue() == deviceType)
91  return true;
92  }
93 
94  return false;
95 }
96 
98  std::optional<mlir::ArrayAttr> deviceTypes) {
99  if (!hasDeviceTypeValues(deviceTypes))
100  return;
101 
102  p << "[";
103  llvm::interleaveComma(*deviceTypes, p,
104  [&](mlir::Attribute attr) { p << attr; });
105  p << "]";
106 }
107 
108 static std::optional<unsigned> findSegment(ArrayAttr segments,
109  mlir::acc::DeviceType deviceType) {
110  unsigned segmentIdx = 0;
111  for (auto attr : segments) {
112  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
113  if (deviceTypeAttr.getValue() == deviceType)
114  return std::make_optional(segmentIdx);
115  ++segmentIdx;
116  }
117  return std::nullopt;
118 }
119 
121 getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
123  std::optional<llvm::ArrayRef<int32_t>> segments,
124  mlir::acc::DeviceType deviceType) {
125  if (!arrayAttr)
126  return range.take_front(0);
127  if (auto pos = findSegment(*arrayAttr, deviceType)) {
128  int32_t nbOperandsBefore = 0;
129  for (unsigned i = 0; i < *pos; ++i)
130  nbOperandsBefore += (*segments)[i];
131  return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
132  }
133  return range.take_front(0);
134 }
135 
136 static mlir::Value
137 getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
139  std::optional<llvm::ArrayRef<int32_t>> segments,
140  std::optional<mlir::ArrayAttr> hasWaitDevnum,
141  mlir::acc::DeviceType deviceType) {
142  if (!hasDeviceTypeValues(deviceTypeAttr))
143  return {};
144  if (auto pos = findSegment(*deviceTypeAttr, deviceType))
145  if (hasWaitDevnum->getValue()[*pos])
146  return getValuesFromSegments(deviceTypeAttr, operands, segments,
147  deviceType)
148  .front();
149  return {};
150 }
151 
153 getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
155  std::optional<llvm::ArrayRef<int32_t>> segments,
156  std::optional<mlir::ArrayAttr> hasWaitDevnum,
157  mlir::acc::DeviceType deviceType) {
158  auto range =
159  getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
160  if (range.empty())
161  return range;
162  if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
163  if (hasWaitDevnum && *hasWaitDevnum) {
164  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
165  if (boolAttr.getValue())
166  return range.drop_front(1); // first value is devnum
167  }
168  }
169  return range;
170 }
171 
172 template <typename Op>
173 static LogicalResult checkWaitAndAsyncConflict(Op op) {
174  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
175  ++dtypeInt) {
176  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
177 
178  // The async attribute represent the async clause without value. Therefore
179  // the attribute and operand cannot appear at the same time.
180  if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
181  op.hasAsyncOnly(dtype))
182  return op.emitError("async attribute cannot appear with asyncOperand");
183 
184  // The wait attribute represent the wait clause without values. Therefore
185  // the attribute and operands cannot appear at the same time.
186  if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
187  op.hasWaitOnly(dtype))
188  return op.emitError("wait attribute cannot appear with waitOperands");
189  }
190  return success();
191 }
192 
193 //===----------------------------------------------------------------------===//
194 // DataBoundsOp
195 //===----------------------------------------------------------------------===//
196 LogicalResult acc::DataBoundsOp::verify() {
197  auto extent = getExtent();
198  auto upperbound = getUpperbound();
199  if (!extent && !upperbound)
200  return emitError("expected extent or upperbound.");
201  return success();
202 }
203 
204 //===----------------------------------------------------------------------===//
205 // PrivateOp
206 //===----------------------------------------------------------------------===//
207 LogicalResult acc::PrivateOp::verify() {
208  if (getDataClause() != acc::DataClause::acc_private)
209  return emitError(
210  "data clause associated with private operation must match its intent");
211  return success();
212 }
213 
214 //===----------------------------------------------------------------------===//
215 // FirstprivateOp
216 //===----------------------------------------------------------------------===//
217 LogicalResult acc::FirstprivateOp::verify() {
218  if (getDataClause() != acc::DataClause::acc_firstprivate)
219  return emitError("data clause associated with firstprivate operation must "
220  "match its intent");
221  return success();
222 }
223 
224 //===----------------------------------------------------------------------===//
225 // ReductionOp
226 //===----------------------------------------------------------------------===//
227 LogicalResult acc::ReductionOp::verify() {
228  if (getDataClause() != acc::DataClause::acc_reduction)
229  return emitError("data clause associated with reduction operation must "
230  "match its intent");
231  return success();
232 }
233 
234 //===----------------------------------------------------------------------===//
235 // DevicePtrOp
236 //===----------------------------------------------------------------------===//
237 LogicalResult acc::DevicePtrOp::verify() {
238  if (getDataClause() != acc::DataClause::acc_deviceptr)
239  return emitError("data clause associated with deviceptr operation must "
240  "match its intent");
241  return success();
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // PresentOp
246 //===----------------------------------------------------------------------===//
247 LogicalResult acc::PresentOp::verify() {
248  if (getDataClause() != acc::DataClause::acc_present)
249  return emitError(
250  "data clause associated with present operation must match its intent");
251  return success();
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // CopyinOp
256 //===----------------------------------------------------------------------===//
257 LogicalResult acc::CopyinOp::verify() {
258  // Test for all clauses this operation can be decomposed from:
259  if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
260  getDataClause() != acc::DataClause::acc_copyin_readonly &&
261  getDataClause() != acc::DataClause::acc_copy &&
262  getDataClause() != acc::DataClause::acc_reduction)
263  return emitError(
264  "data clause associated with copyin operation must match its intent"
265  " or specify original clause this operation was decomposed from");
266  return success();
267 }
268 
269 bool acc::CopyinOp::isCopyinReadonly() {
270  return getDataClause() == acc::DataClause::acc_copyin_readonly;
271 }
272 
273 //===----------------------------------------------------------------------===//
274 // CreateOp
275 //===----------------------------------------------------------------------===//
276 LogicalResult acc::CreateOp::verify() {
277  // Test for all clauses this operation can be decomposed from:
278  if (getDataClause() != acc::DataClause::acc_create &&
279  getDataClause() != acc::DataClause::acc_create_zero &&
280  getDataClause() != acc::DataClause::acc_copyout &&
281  getDataClause() != acc::DataClause::acc_copyout_zero)
282  return emitError(
283  "data clause associated with create operation must match its intent"
284  " or specify original clause this operation was decomposed from");
285  return success();
286 }
287 
288 bool acc::CreateOp::isCreateZero() {
289  // The zero modifier is encoded in the data clause.
290  return getDataClause() == acc::DataClause::acc_create_zero ||
291  getDataClause() == acc::DataClause::acc_copyout_zero;
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // NoCreateOp
296 //===----------------------------------------------------------------------===//
297 LogicalResult acc::NoCreateOp::verify() {
298  if (getDataClause() != acc::DataClause::acc_no_create)
299  return emitError("data clause associated with no_create operation must "
300  "match its intent");
301  return success();
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // AttachOp
306 //===----------------------------------------------------------------------===//
307 LogicalResult acc::AttachOp::verify() {
308  if (getDataClause() != acc::DataClause::acc_attach)
309  return emitError(
310  "data clause associated with attach operation must match its intent");
311  return success();
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // DeclareDeviceResidentOp
316 //===----------------------------------------------------------------------===//
317 
318 LogicalResult acc::DeclareDeviceResidentOp::verify() {
319  if (getDataClause() != acc::DataClause::acc_declare_device_resident)
320  return emitError("data clause associated with device_resident operation "
321  "must match its intent");
322  return success();
323 }
324 
325 //===----------------------------------------------------------------------===//
326 // DeclareLinkOp
327 //===----------------------------------------------------------------------===//
328 
329 LogicalResult acc::DeclareLinkOp::verify() {
330  if (getDataClause() != acc::DataClause::acc_declare_link)
331  return emitError(
332  "data clause associated with link operation must match its intent");
333  return success();
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // CopyoutOp
338 //===----------------------------------------------------------------------===//
339 LogicalResult acc::CopyoutOp::verify() {
340  // Test for all clauses this operation can be decomposed from:
341  if (getDataClause() != acc::DataClause::acc_copyout &&
342  getDataClause() != acc::DataClause::acc_copyout_zero &&
343  getDataClause() != acc::DataClause::acc_copy &&
344  getDataClause() != acc::DataClause::acc_reduction)
345  return emitError(
346  "data clause associated with copyout operation must match its intent"
347  " or specify original clause this operation was decomposed from");
348  if (!getVarPtr() || !getAccPtr())
349  return emitError("must have both host and device pointers");
350  return success();
351 }
352 
353 bool acc::CopyoutOp::isCopyoutZero() {
354  return getDataClause() == acc::DataClause::acc_copyout_zero;
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // DeleteOp
359 //===----------------------------------------------------------------------===//
360 LogicalResult acc::DeleteOp::verify() {
361  // Test for all clauses this operation can be decomposed from:
362  if (getDataClause() != acc::DataClause::acc_delete &&
363  getDataClause() != acc::DataClause::acc_create &&
364  getDataClause() != acc::DataClause::acc_create_zero &&
365  getDataClause() != acc::DataClause::acc_copyin &&
366  getDataClause() != acc::DataClause::acc_copyin_readonly &&
367  getDataClause() != acc::DataClause::acc_present &&
368  getDataClause() != acc::DataClause::acc_declare_device_resident &&
369  getDataClause() != acc::DataClause::acc_declare_link)
370  return emitError(
371  "data clause associated with delete operation must match its intent"
372  " or specify original clause this operation was decomposed from");
373  if (!getAccPtr())
374  return emitError("must have device pointer");
375  return success();
376 }
377 
378 //===----------------------------------------------------------------------===//
379 // DetachOp
380 //===----------------------------------------------------------------------===//
381 LogicalResult acc::DetachOp::verify() {
382  // Test for all clauses this operation can be decomposed from:
383  if (getDataClause() != acc::DataClause::acc_detach &&
384  getDataClause() != acc::DataClause::acc_attach)
385  return emitError(
386  "data clause associated with detach operation must match its intent"
387  " or specify original clause this operation was decomposed from");
388  if (!getAccPtr())
389  return emitError("must have device pointer");
390  return success();
391 }
392 
393 //===----------------------------------------------------------------------===//
// UpdateHostOp
395 //===----------------------------------------------------------------------===//
396 LogicalResult acc::UpdateHostOp::verify() {
397  // Test for all clauses this operation can be decomposed from:
398  if (getDataClause() != acc::DataClause::acc_update_host &&
399  getDataClause() != acc::DataClause::acc_update_self)
400  return emitError(
401  "data clause associated with host operation must match its intent"
402  " or specify original clause this operation was decomposed from");
403  if (!getVarPtr() || !getAccPtr())
404  return emitError("must have both host and device pointers");
405  return success();
406 }
407 
408 //===----------------------------------------------------------------------===//
// UpdateDeviceOp
410 //===----------------------------------------------------------------------===//
411 LogicalResult acc::UpdateDeviceOp::verify() {
412  // Test for all clauses this operation can be decomposed from:
413  if (getDataClause() != acc::DataClause::acc_update_device)
414  return emitError(
415  "data clause associated with device operation must match its intent"
416  " or specify original clause this operation was decomposed from");
417  return success();
418 }
419 
420 //===----------------------------------------------------------------------===//
421 // UseDeviceOp
422 //===----------------------------------------------------------------------===//
423 LogicalResult acc::UseDeviceOp::verify() {
424  // Test for all clauses this operation can be decomposed from:
425  if (getDataClause() != acc::DataClause::acc_use_device)
426  return emitError(
427  "data clause associated with use_device operation must match its intent"
428  " or specify original clause this operation was decomposed from");
429  return success();
430 }
431 
432 //===----------------------------------------------------------------------===//
433 // CacheOp
434 //===----------------------------------------------------------------------===//
435 LogicalResult acc::CacheOp::verify() {
436  // Test for all clauses this operation can be decomposed from:
437  if (getDataClause() != acc::DataClause::acc_cache &&
438  getDataClause() != acc::DataClause::acc_cache_readonly)
439  return emitError(
440  "data clause associated with cache operation must match its intent"
441  " or specify original clause this operation was decomposed from");
442  return success();
443 }
444 
445 template <typename StructureOp>
446 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
447  unsigned nRegions = 1) {
448 
449  SmallVector<Region *, 2> regions;
450  for (unsigned i = 0; i < nRegions; ++i)
451  regions.push_back(state.addRegion());
452 
453  for (Region *region : regions)
454  if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
455  return failure();
456 
457  return success();
458 }
459 
460 static bool isComputeOperation(Operation *op) {
461  return isa<acc::ParallelOp, acc::LoopOp>(op);
462 }
463 
464 namespace {
465 /// Pattern to remove operation without region that have constant false `ifCond`
466 /// and remove the condition from the operation if the `ifCond` is a true
467 /// constant.
468 template <typename OpTy>
469 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
471 
472  LogicalResult matchAndRewrite(OpTy op,
473  PatternRewriter &rewriter) const override {
474  // Early return if there is no condition.
475  Value ifCond = op.getIfCond();
476  if (!ifCond)
477  return failure();
478 
479  IntegerAttr constAttr;
480  if (!matchPattern(ifCond, m_Constant(&constAttr)))
481  return failure();
482  if (constAttr.getInt())
483  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
484  else
485  rewriter.eraseOp(op);
486 
487  return success();
488  }
489 };
490 
491 /// Replaces the given op with the contents of the given single-block region,
492 /// using the operands of the block terminator to replace operation results.
493 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
494  Region &region, ValueRange blockArgs = {}) {
495  assert(llvm::hasSingleElement(region) && "expected single-region block");
496  Block *block = &region.front();
497  Operation *terminator = block->getTerminator();
498  ValueRange results = terminator->getOperands();
499  rewriter.inlineBlockBefore(block, op, blockArgs);
500  rewriter.replaceOp(op, results);
501  rewriter.eraseOp(terminator);
502 }
503 
504 /// Pattern to remove operation with region that have constant false `ifCond`
505 /// and remove the condition from the operation if the `ifCond` is constant
506 /// true.
507 template <typename OpTy>
508 struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
510 
511  LogicalResult matchAndRewrite(OpTy op,
512  PatternRewriter &rewriter) const override {
513  // Early return if there is no condition.
514  Value ifCond = op.getIfCond();
515  if (!ifCond)
516  return failure();
517 
518  IntegerAttr constAttr;
519  if (!matchPattern(ifCond, m_Constant(&constAttr)))
520  return failure();
521  if (constAttr.getInt())
522  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
523  else
524  replaceOpWithRegion(rewriter, op, op.getRegion());
525 
526  return success();
527  }
528 };
529 
530 } // namespace
531 
532 //===----------------------------------------------------------------------===//
533 // PrivateRecipeOp
534 //===----------------------------------------------------------------------===//
535 
536 static LogicalResult verifyInitLikeSingleArgRegion(
537  Operation *op, Region &region, StringRef regionType, StringRef regionName,
538  Type type, bool verifyYield, bool optional = false) {
539  if (optional && region.empty())
540  return success();
541 
542  if (region.empty())
543  return op->emitOpError() << "expects non-empty " << regionName << " region";
544  Block &firstBlock = region.front();
545  if (firstBlock.getNumArguments() < 1 ||
546  firstBlock.getArgument(0).getType() != type)
547  return op->emitOpError() << "expects " << regionName
548  << " region first "
549  "argument of the "
550  << regionType << " type";
551 
552  if (verifyYield) {
553  for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
554  if (yieldOp.getOperands().size() != 1 ||
555  yieldOp.getOperands().getTypes()[0] != type)
556  return op->emitOpError() << "expects " << regionName
557  << " region to "
558  "yield a value of the "
559  << regionType << " type";
560  }
561  }
562  return success();
563 }
564 
565 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
566  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
567  "privatization", "init", getType(),
568  /*verifyYield=*/false)))
569  return failure();
571  *this, getDestroyRegion(), "privatization", "destroy", getType(),
572  /*verifyYield=*/false, /*optional=*/true)))
573  return failure();
574  return success();
575 }
576 
577 //===----------------------------------------------------------------------===//
578 // FirstprivateRecipeOp
579 //===----------------------------------------------------------------------===//
580 
581 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
582  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
583  "privatization", "init", getType(),
584  /*verifyYield=*/false)))
585  return failure();
586 
587  if (getCopyRegion().empty())
588  return emitOpError() << "expects non-empty copy region";
589 
590  Block &firstBlock = getCopyRegion().front();
591  if (firstBlock.getNumArguments() < 2 ||
592  firstBlock.getArgument(0).getType() != getType())
593  return emitOpError() << "expects copy region with two arguments of the "
594  "privatization type";
595 
596  if (getDestroyRegion().empty())
597  return success();
598 
599  if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
600  "privatization", "destroy",
601  getType(), /*verifyYield=*/false)))
602  return failure();
603 
604  return success();
605 }
606 
607 //===----------------------------------------------------------------------===//
608 // ReductionRecipeOp
609 //===----------------------------------------------------------------------===//
610 
611 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
612  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
613  "init", getType(),
614  /*verifyYield=*/false)))
615  return failure();
616 
617  if (getCombinerRegion().empty())
618  return emitOpError() << "expects non-empty combiner region";
619 
620  Block &reductionBlock = getCombinerRegion().front();
621  if (reductionBlock.getNumArguments() < 2 ||
622  reductionBlock.getArgument(0).getType() != getType() ||
623  reductionBlock.getArgument(1).getType() != getType())
624  return emitOpError() << "expects combiner region with the first two "
625  << "arguments of the reduction type";
626 
627  for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
628  if (yieldOp.getOperands().size() != 1 ||
629  yieldOp.getOperands().getTypes()[0] != getType())
630  return emitOpError() << "expects combiner region to yield a value "
631  "of the reduction type";
632  }
633 
634  return success();
635 }
636 
637 //===----------------------------------------------------------------------===//
// Custom parser and printer for symbol operand lists (e.g. the private clause)
639 //===----------------------------------------------------------------------===//
640 
641 static ParseResult parseSymOperandList(
642  mlir::OpAsmParser &parser,
644  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
646  if (failed(parser.parseCommaSeparatedList([&]() {
647  if (parser.parseAttribute(attributes.emplace_back()) ||
648  parser.parseArrow() ||
649  parser.parseOperand(operands.emplace_back()) ||
650  parser.parseColonType(types.emplace_back()))
651  return failure();
652  return success();
653  })))
654  return failure();
655  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
656  attributes.end());
657  symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
658  return success();
659 }
660 
662  mlir::OperandRange operands,
663  mlir::TypeRange types,
664  std::optional<mlir::ArrayAttr> attributes) {
665  llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
666  p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
667  << std::get<1>(it).getType();
668  });
669 }
670 
671 //===----------------------------------------------------------------------===//
672 // ParallelOp
673 //===----------------------------------------------------------------------===//
674 
675 /// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
676 template <typename Op>
677 static LogicalResult checkDataOperands(Op op,
678  const mlir::ValueRange &operands) {
679  for (mlir::Value operand : operands)
680  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
681  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
682  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
683  operand.getDefiningOp()))
684  return op.emitError(
685  "expect data entry/exit operation or acc.getdeviceptr "
686  "as defining op");
687  return success();
688 }
689 
690 template <typename Op>
691 static LogicalResult
692 checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
693  mlir::OperandRange operands, llvm::StringRef operandName,
694  llvm::StringRef symbolName, bool checkOperandType = true) {
695  if (!operands.empty()) {
696  if (!attributes || attributes->size() != operands.size())
697  return op->emitOpError()
698  << "expected as many " << symbolName << " symbol reference as "
699  << operandName << " operands";
700  } else {
701  if (attributes)
702  return op->emitOpError()
703  << "unexpected " << symbolName << " symbol reference";
704  return success();
705  }
706 
708  for (auto args : llvm::zip(operands, *attributes)) {
709  mlir::Value operand = std::get<0>(args);
710 
711  if (!set.insert(operand).second)
712  return op->emitOpError()
713  << operandName << " operand appears more than once";
714 
715  mlir::Type varType = operand.getType();
716  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
717  auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
718  if (!decl)
719  return op->emitOpError()
720  << "expected symbol reference " << symbolRef << " to point to a "
721  << operandName << " declaration";
722 
723  if (checkOperandType && decl.getType() && decl.getType() != varType)
724  return op->emitOpError() << "expected " << operandName << " (" << varType
725  << ") to be the same type as " << operandName
726  << " declaration (" << decl.getType() << ")";
727  }
728 
729  return success();
730 }
731 
732 unsigned ParallelOp::getNumDataOperands() {
733  return getReductionOperands().size() + getPrivateOperands().size() +
734  getFirstprivateOperands().size() + getDataClauseOperands().size();
735 }
736 
737 Value ParallelOp::getDataOperand(unsigned i) {
738  unsigned numOptional = getAsyncOperands().size();
739  numOptional += getNumGangs().size();
740  numOptional += getNumWorkers().size();
741  numOptional += getVectorLength().size();
742  numOptional += getIfCond() ? 1 : 0;
743  numOptional += getSelfCond() ? 1 : 0;
744  return getOperand(getWaitOperands().size() + numOptional + i);
745 }
746 
747 template <typename Op>
748 static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
749  ArrayAttr deviceTypes,
750  llvm::StringRef keyword) {
751  if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
752  return op.emitOpError() << keyword << " operands count must match "
753  << keyword << " device_type count";
754  return success();
755 }
756 
757 template <typename Op>
759  Op op, OperandRange operands, DenseI32ArrayAttr segments,
760  ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
761  std::size_t numOperandsInSegments = 0;
762  std::size_t nbOfSegments = 0;
763 
764  if (segments) {
765  for (auto segCount : segments.asArrayRef()) {
766  if (maxInSegment != 0 && segCount > maxInSegment)
767  return op.emitOpError() << keyword << " expects a maximum of "
768  << maxInSegment << " values per segment";
769  numOperandsInSegments += segCount;
770  ++nbOfSegments;
771  }
772  }
773 
774  if ((numOperandsInSegments != operands.size()) ||
775  (!deviceTypes && !operands.empty()))
776  return op.emitOpError()
777  << keyword << " operand count does not match count in segments";
778  if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
779  return op.emitOpError()
780  << keyword << " segment count does not match device_type count";
781  return success();
782 }
783 
/// Verifies an acc.parallel operation:
///  - the private, firstprivate and reduction symbol/operand lists;
///  - the device_type bookkeeping of num_gangs, wait, num_workers,
///    vector_length and async;
///  - the async/wait attribute-vs-operand conflicts;
///  - the defining ops of the data clause operands.
LogicalResult acc::ParallelOp::verify() {
  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
          *this, getPrivatizations(), getPrivateOperands(), "private",
          "privatizations", /*checkOperandType=*/false)))
    return failure();
  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
          *this, getFirstprivatizations(), getFirstprivateOperands(),
          "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
    return failure();
  // The trailing `false` is the checkOperandType argument.
  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
          *this, getReductionRecipes(), getReductionOperands(), "reduction",
          "reductions", false)))
    return failure();

  // num_gangs is segmented per device_type; the trailing 3 is presumably the
  // maximum number of values per segment — confirm against the callee.
  // NOTE(review): the opening `if (failed(...(` line of this statement and of
  // the wait check below is missing from this copy of the file.
      *this, getNumGangs(), getNumGangsSegmentsAttr(),
      getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
    return failure();

      *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
      getWaitOperandsDeviceTypeAttr(), "wait")))
    return failure();

  // num_workers, vector_length and async hold one operand per device_type.
  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
                                        getNumWorkersDeviceTypeAttr(),
                                        "num_workers")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
                                        getVectorLengthDeviceTypeAttr(),
                                        "vector_length")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
                                        getAsyncOperandsDeviceTypeAttr(),
                                        "async")))
    return failure();

  if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
    return failure();

  return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
}
828 
829 static mlir::Value
830 getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
832  mlir::acc::DeviceType deviceType) {
833  if (!arrayAttr)
834  return {};
835  if (auto pos = findSegment(*arrayAttr, deviceType))
836  return range[*pos];
837  return {};
838 }
839 
840 bool acc::ParallelOp::hasAsyncOnly() {
841  return hasAsyncOnly(mlir::acc::DeviceType::None);
842 }
843 
844 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
845  return hasDeviceType(getAsyncOnly(), deviceType);
846 }
847 
848 mlir::Value acc::ParallelOp::getAsyncValue() {
849  return getAsyncValue(mlir::acc::DeviceType::None);
850 }
851 
852 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
854  getAsyncOperands(), deviceType);
855 }
856 
857 mlir::Value acc::ParallelOp::getNumWorkersValue() {
858  return getNumWorkersValue(mlir::acc::DeviceType::None);
859 }
860 
862 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
863  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
864  deviceType);
865 }
866 
867 mlir::Value acc::ParallelOp::getVectorLengthValue() {
868  return getVectorLengthValue(mlir::acc::DeviceType::None);
869 }
870 
872 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
873  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
874  getVectorLength(), deviceType);
875 }
876 
877 mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
878  return getNumGangsValues(mlir::acc::DeviceType::None);
879 }
880 
882 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
883  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
884  getNumGangsSegments(), deviceType);
885 }
886 
887 bool acc::ParallelOp::hasWaitOnly() {
888  return hasWaitOnly(mlir::acc::DeviceType::None);
889 }
890 
891 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
892  return hasDeviceType(getWaitOnly(), deviceType);
893 }
894 
895 mlir::Operation::operand_range ParallelOp::getWaitValues() {
896  return getWaitValues(mlir::acc::DeviceType::None);
897 }
898 
900 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
902  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
903  getHasWaitDevnum(), deviceType);
904 }
905 
906 mlir::Value ParallelOp::getWaitDevnum() {
907  return getWaitDevnum(mlir::acc::DeviceType::None);
908 }
909 
910 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
911  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
912  getWaitOperandsSegments(), getHasWaitDevnum(),
913  deviceType);
914 }
915 
916 void ParallelOp::build(mlir::OpBuilder &odsBuilder,
917  mlir::OperationState &odsState,
918  mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
919  mlir::ValueRange vectorLength,
920  mlir::ValueRange asyncOperands,
921  mlir::ValueRange waitOperands, mlir::Value ifCond,
922  mlir::Value selfCond, mlir::ValueRange reductionOperands,
923  mlir::ValueRange gangPrivateOperands,
924  mlir::ValueRange gangFirstPrivateOperands,
925  mlir::ValueRange dataClauseOperands) {
926 
927  ParallelOp::build(
928  odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
929  /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
930  /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
931  /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
932  /*numGangsDeviceType=*/nullptr, numWorkers,
933  /*numWorkersDeviceType=*/nullptr, vectorLength,
934  /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
935  /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
936  gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
937  /*firstprivatizations=*/nullptr, dataClauseOperands,
938  /*defaultAttr=*/nullptr, /*combined=*/nullptr);
939 }
940 
941 static ParseResult parseNumGangs(
942  mlir::OpAsmParser &parser,
944  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
945  mlir::DenseI32ArrayAttr &segments) {
948 
949  do {
950  if (failed(parser.parseLBrace()))
951  return failure();
952 
953  int32_t crtOperandsSize = operands.size();
954  if (failed(parser.parseCommaSeparatedList(
956  if (parser.parseOperand(operands.emplace_back()) ||
957  parser.parseColonType(types.emplace_back()))
958  return failure();
959  return success();
960  })))
961  return failure();
962  seg.push_back(operands.size() - crtOperandsSize);
963 
964  if (failed(parser.parseRBrace()))
965  return failure();
966 
967  if (succeeded(parser.parseOptionalLSquare())) {
968  if (parser.parseAttribute(attributes.emplace_back()) ||
969  parser.parseRSquare())
970  return failure();
971  } else {
972  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
974  }
975  } while (succeeded(parser.parseOptionalComma()));
976 
977  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
978  attributes.end());
979  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
980  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
981 
982  return success();
983 }
984 
986  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
987  if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
988  p << " [" << attr << "]";
989 }
990 
992  mlir::OperandRange operands, mlir::TypeRange types,
993  std::optional<mlir::ArrayAttr> deviceTypes,
994  std::optional<mlir::DenseI32ArrayAttr> segments) {
995  unsigned opIdx = 0;
996  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
997  p << "{";
998  llvm::interleaveComma(
999  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1000  p << operands[opIdx] << " : " << operands[opIdx].getType();
1001  ++opIdx;
1002  });
1003  p << "}";
1004  printSingleDeviceType(p, it.value());
1005  });
1006 }
1007 
1009  mlir::OpAsmParser &parser,
1011  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1012  mlir::DenseI32ArrayAttr &segments) {
1015 
1016  do {
1017  if (failed(parser.parseLBrace()))
1018  return failure();
1019 
1020  int32_t crtOperandsSize = operands.size();
1021 
1022  if (failed(parser.parseCommaSeparatedList(
1024  if (parser.parseOperand(operands.emplace_back()) ||
1025  parser.parseColonType(types.emplace_back()))
1026  return failure();
1027  return success();
1028  })))
1029  return failure();
1030 
1031  seg.push_back(operands.size() - crtOperandsSize);
1032 
1033  if (failed(parser.parseRBrace()))
1034  return failure();
1035 
1036  if (succeeded(parser.parseOptionalLSquare())) {
1037  if (parser.parseAttribute(attributes.emplace_back()) ||
1038  parser.parseRSquare())
1039  return failure();
1040  } else {
1041  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1043  }
1044  } while (succeeded(parser.parseOptionalComma()));
1045 
1046  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1047  attributes.end());
1048  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1049  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1050 
1051  return success();
1052 }
1053 
1056  mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1057  std::optional<mlir::DenseI32ArrayAttr> segments) {
1058  unsigned opIdx = 0;
1059  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1060  p << "{";
1061  llvm::interleaveComma(
1062  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1063  p << operands[opIdx] << " : " << operands[opIdx].getType();
1064  ++opIdx;
1065  });
1066  p << "}";
1067  printSingleDeviceType(p, it.value());
1068  });
1069 }
1070 
1071 static ParseResult parseWaitClause(
1072  mlir::OpAsmParser &parser,
1074  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1075  mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
1076  mlir::ArrayAttr &keywordOnly) {
1077  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
1079 
1080  bool needCommaBeforeOperands = false;
1081 
1082  // Keyword only
1083  if (failed(parser.parseOptionalLParen())) {
1084  keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1086  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1087  return success();
1088  }
1089 
1090  // Parse keyword only attributes
1091  if (succeeded(parser.parseOptionalLSquare())) {
1092  if (failed(parser.parseCommaSeparatedList([&]() {
1093  if (parser.parseAttribute(keywordAttrs.emplace_back()))
1094  return failure();
1095  return success();
1096  })))
1097  return failure();
1098  if (parser.parseRSquare())
1099  return failure();
1100  needCommaBeforeOperands = true;
1101  }
1102 
1103  if (needCommaBeforeOperands && failed(parser.parseComma()))
1104  return failure();
1105 
1106  do {
1107  if (failed(parser.parseLBrace()))
1108  return failure();
1109 
1110  int32_t crtOperandsSize = operands.size();
1111 
1112  if (succeeded(parser.parseOptionalKeyword("devnum"))) {
1113  if (failed(parser.parseColon()))
1114  return failure();
1115  devnum.push_back(BoolAttr::get(parser.getContext(), true));
1116  } else {
1117  devnum.push_back(BoolAttr::get(parser.getContext(), false));
1118  }
1119 
1120  if (failed(parser.parseCommaSeparatedList(
1122  if (parser.parseOperand(operands.emplace_back()) ||
1123  parser.parseColonType(types.emplace_back()))
1124  return failure();
1125  return success();
1126  })))
1127  return failure();
1128 
1129  seg.push_back(operands.size() - crtOperandsSize);
1130 
1131  if (failed(parser.parseRBrace()))
1132  return failure();
1133 
1134  if (succeeded(parser.parseOptionalLSquare())) {
1135  if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
1136  parser.parseRSquare())
1137  return failure();
1138  } else {
1139  deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1141  }
1142  } while (succeeded(parser.parseOptionalComma()));
1143 
1144  if (failed(parser.parseRParen()))
1145  return failure();
1146 
1147  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
1148  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1149  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1150  hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
1151 
1152  return success();
1153 }
1154 
1155 static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
1156  if (!hasDeviceTypeValues(attrs))
1157  return false;
1158  if (attrs->size() != 1)
1159  return false;
1160  if (auto deviceTypeAttr =
1161  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1162  return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
1163  return false;
1164 }
1165 
1167  mlir::OperandRange operands, mlir::TypeRange types,
1168  std::optional<mlir::ArrayAttr> deviceTypes,
1169  std::optional<mlir::DenseI32ArrayAttr> segments,
1170  std::optional<mlir::ArrayAttr> hasDevNum,
1171  std::optional<mlir::ArrayAttr> keywordOnly) {
1172 
1173  if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
1174  return;
1175 
1176  p << "(";
1177 
1178  printDeviceTypes(p, keywordOnly);
1179  if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
1180  p << ", ";
1181 
1182  unsigned opIdx = 0;
1183  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1184  p << "{";
1185  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1186  if (boolAttr && boolAttr.getValue())
1187  p << "devnum: ";
1188  llvm::interleaveComma(
1189  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1190  p << operands[opIdx] << " : " << operands[opIdx].getType();
1191  ++opIdx;
1192  });
1193  p << "}";
1194  printSingleDeviceType(p, it.value());
1195  });
1196 
1197  p << ")";
1198 }
1199 
1200 static ParseResult parseDeviceTypeOperands(
1201  mlir::OpAsmParser &parser,
1203  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
1205  if (failed(parser.parseCommaSeparatedList([&]() {
1206  if (parser.parseOperand(operands.emplace_back()) ||
1207  parser.parseColonType(types.emplace_back()))
1208  return failure();
1209  if (succeeded(parser.parseOptionalLSquare())) {
1210  if (parser.parseAttribute(attributes.emplace_back()) ||
1211  parser.parseRSquare())
1212  return failure();
1213  } else {
1214  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1215  parser.getContext(), mlir::acc::DeviceType::None));
1216  }
1217  return success();
1218  })))
1219  return failure();
1220  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1221  attributes.end());
1222  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1223  return success();
1224 }
1225 
1226 static void
1228  mlir::OperandRange operands, mlir::TypeRange types,
1229  std::optional<mlir::ArrayAttr> deviceTypes) {
1230  if (!hasDeviceTypeValues(deviceTypes))
1231  return;
1232  llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
1233  p << std::get<1>(it) << " : " << std::get<1>(it).getType();
1234  printSingleDeviceType(p, std::get<0>(it));
1235  });
1236 }
1237 
1239  mlir::OpAsmParser &parser,
1241  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1242  mlir::ArrayAttr &keywordOnlyDeviceType) {
1243 
1244  llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
1245  bool needCommaBeforeOperands = false;
1246 
1247  if (failed(parser.parseOptionalLParen())) {
1248  // Keyword only
1249  keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1251  keywordOnlyDeviceType =
1252  ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
1253  return success();
1254  }
1255 
1256  // Parse keyword only attributes
1257  if (succeeded(parser.parseOptionalLSquare())) {
1258  // Parse keyword only attributes
1259  if (failed(parser.parseCommaSeparatedList([&]() {
1260  if (parser.parseAttribute(
1261  keywordOnlyDeviceTypeAttributes.emplace_back()))
1262  return failure();
1263  return success();
1264  })))
1265  return failure();
1266  if (parser.parseRSquare())
1267  return failure();
1268  needCommaBeforeOperands = true;
1269  }
1270 
1271  if (needCommaBeforeOperands && failed(parser.parseComma()))
1272  return failure();
1273 
1275  if (failed(parser.parseCommaSeparatedList([&]() {
1276  if (parser.parseOperand(operands.emplace_back()) ||
1277  parser.parseColonType(types.emplace_back()))
1278  return failure();
1279  if (succeeded(parser.parseOptionalLSquare())) {
1280  if (parser.parseAttribute(attributes.emplace_back()) ||
1281  parser.parseRSquare())
1282  return failure();
1283  } else {
1284  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1285  parser.getContext(), mlir::acc::DeviceType::None));
1286  }
1287  return success();
1288  })))
1289  return failure();
1290 
1291  if (failed(parser.parseRParen()))
1292  return failure();
1293 
1294  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1295  attributes.end());
1296  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1297  return success();
1298 }
1299 
1302  mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1303  std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1304 
1305  if (operands.begin() == operands.end() &&
1306  hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
1307  return;
1308  }
1309 
1310  p << "(";
1311  printDeviceTypes(p, keywordOnlyDeviceTypes);
1312  if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
1313  hasDeviceTypeValues(deviceTypes))
1314  p << ", ";
1315  printDeviceTypeOperands(p, op, operands, types, deviceTypes);
1316  p << ")";
1317 }
1318 
1319 static ParseResult
1321  mlir::acc::CombinedConstructsTypeAttr &attr) {
1322  if (succeeded(parser.parseOptionalKeyword("combined"))) {
1323  if (parser.parseLParen())
1324  return failure();
1325  if (succeeded(parser.parseOptionalKeyword("kernels"))) {
1327  parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1328  } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
1330  parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1331  } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
1333  parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1334  } else {
1335  parser.emitError(parser.getCurrentLocation(),
1336  "expected compute construct name");
1337  return failure();
1338  }
1339  if (parser.parseRParen())
1340  return failure();
1341  }
1342  return success();
1343 }
1344 
1345 static void
1347  mlir::acc::CombinedConstructsTypeAttr attr) {
1348  if (attr) {
1349  switch (attr.getValue()) {
1350  case mlir::acc::CombinedConstructsType::KernelsLoop:
1351  p << "combined(kernels)";
1352  break;
1353  case mlir::acc::CombinedConstructsType::ParallelLoop:
1354  p << "combined(parallel)";
1355  break;
1356  case mlir::acc::CombinedConstructsType::SerialLoop:
1357  p << "combined(serial)";
1358  break;
1359  };
1360  }
1361 }
1362 
1363 //===----------------------------------------------------------------------===//
1364 // SerialOp
1365 //===----------------------------------------------------------------------===//
1366 
1367 unsigned SerialOp::getNumDataOperands() {
1368  return getReductionOperands().size() + getPrivateOperands().size() +
1369  getFirstprivateOperands().size() + getDataClauseOperands().size();
1370 }
1371 
1372 Value SerialOp::getDataOperand(unsigned i) {
1373  unsigned numOptional = getAsyncOperands().size();
1374  numOptional += getIfCond() ? 1 : 0;
1375  numOptional += getSelfCond() ? 1 : 0;
1376  return getOperand(getWaitOperands().size() + numOptional + i);
1377 }
1378 
1379 bool acc::SerialOp::hasAsyncOnly() {
1380  return hasAsyncOnly(mlir::acc::DeviceType::None);
1381 }
1382 
1383 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1384  return hasDeviceType(getAsyncOnly(), deviceType);
1385 }
1386 
1387 mlir::Value acc::SerialOp::getAsyncValue() {
1388  return getAsyncValue(mlir::acc::DeviceType::None);
1389 }
1390 
1391 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1393  getAsyncOperands(), deviceType);
1394 }
1395 
1396 bool acc::SerialOp::hasWaitOnly() {
1397  return hasWaitOnly(mlir::acc::DeviceType::None);
1398 }
1399 
1400 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1401  return hasDeviceType(getWaitOnly(), deviceType);
1402 }
1403 
1404 mlir::Operation::operand_range SerialOp::getWaitValues() {
1405  return getWaitValues(mlir::acc::DeviceType::None);
1406 }
1407 
1409 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1411  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1412  getHasWaitDevnum(), deviceType);
1413 }
1414 
1415 mlir::Value SerialOp::getWaitDevnum() {
1416  return getWaitDevnum(mlir::acc::DeviceType::None);
1417 }
1418 
1419 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1420  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1421  getWaitOperandsSegments(), getHasWaitDevnum(),
1422  deviceType);
1423 }
1424 
1425 LogicalResult acc::SerialOp::verify() {
1426  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1427  *this, getPrivatizations(), getPrivateOperands(), "private",
1428  "privatizations", /*checkOperandType=*/false)))
1429  return failure();
1430  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1431  *this, getFirstprivatizations(), getFirstprivateOperands(),
1432  "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
1433  return failure();
1434  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1435  *this, getReductionRecipes(), getReductionOperands(), "reduction",
1436  "reductions", false)))
1437  return failure();
1438 
1440  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1441  getWaitOperandsDeviceTypeAttr(), "wait")))
1442  return failure();
1443 
1444  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1445  getAsyncOperandsDeviceTypeAttr(),
1446  "async")))
1447  return failure();
1448 
1449  if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
1450  return failure();
1451 
1452  return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
1453 }
1454 
1455 //===----------------------------------------------------------------------===//
1456 // KernelsOp
1457 //===----------------------------------------------------------------------===//
1458 
1459 unsigned KernelsOp::getNumDataOperands() {
1460  return getDataClauseOperands().size();
1461 }
1462 
1463 Value KernelsOp::getDataOperand(unsigned i) {
1464  unsigned numOptional = getAsyncOperands().size();
1465  numOptional += getWaitOperands().size();
1466  numOptional += getNumGangs().size();
1467  numOptional += getNumWorkers().size();
1468  numOptional += getVectorLength().size();
1469  numOptional += getIfCond() ? 1 : 0;
1470  numOptional += getSelfCond() ? 1 : 0;
1471  return getOperand(numOptional + i);
1472 }
1473 
1474 bool acc::KernelsOp::hasAsyncOnly() {
1475  return hasAsyncOnly(mlir::acc::DeviceType::None);
1476 }
1477 
1478 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1479  return hasDeviceType(getAsyncOnly(), deviceType);
1480 }
1481 
1482 mlir::Value acc::KernelsOp::getAsyncValue() {
1483  return getAsyncValue(mlir::acc::DeviceType::None);
1484 }
1485 
1486 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1488  getAsyncOperands(), deviceType);
1489 }
1490 
1491 mlir::Value acc::KernelsOp::getNumWorkersValue() {
1492  return getNumWorkersValue(mlir::acc::DeviceType::None);
1493 }
1494 
1496 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1497  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1498  deviceType);
1499 }
1500 
1501 mlir::Value acc::KernelsOp::getVectorLengthValue() {
1502  return getVectorLengthValue(mlir::acc::DeviceType::None);
1503 }
1504 
1506 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1507  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1508  getVectorLength(), deviceType);
1509 }
1510 
1511 mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
1512  return getNumGangsValues(mlir::acc::DeviceType::None);
1513 }
1514 
1516 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1517  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1518  getNumGangsSegments(), deviceType);
1519 }
1520 
1521 bool acc::KernelsOp::hasWaitOnly() {
1522  return hasWaitOnly(mlir::acc::DeviceType::None);
1523 }
1524 
1525 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1526  return hasDeviceType(getWaitOnly(), deviceType);
1527 }
1528 
1529 mlir::Operation::operand_range KernelsOp::getWaitValues() {
1530  return getWaitValues(mlir::acc::DeviceType::None);
1531 }
1532 
1534 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1536  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1537  getHasWaitDevnum(), deviceType);
1538 }
1539 
1540 mlir::Value KernelsOp::getWaitDevnum() {
1541  return getWaitDevnum(mlir::acc::DeviceType::None);
1542 }
1543 
1544 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1545  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1546  getWaitOperandsSegments(), getHasWaitDevnum(),
1547  deviceType);
1548 }
1549 
1550 LogicalResult acc::KernelsOp::verify() {
1552  *this, getNumGangs(), getNumGangsSegmentsAttr(),
1553  getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1554  return failure();
1555 
1557  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1558  getWaitOperandsDeviceTypeAttr(), "wait")))
1559  return failure();
1560 
1561  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1562  getNumWorkersDeviceTypeAttr(),
1563  "num_workers")))
1564  return failure();
1565 
1566  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1567  getVectorLengthDeviceTypeAttr(),
1568  "vector_length")))
1569  return failure();
1570 
1571  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1572  getAsyncOperandsDeviceTypeAttr(),
1573  "async")))
1574  return failure();
1575 
1576  if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
1577  return failure();
1578 
1579  return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
1580 }
1581 
1582 //===----------------------------------------------------------------------===//
1583 // HostDataOp
1584 //===----------------------------------------------------------------------===//
1585 
1586 LogicalResult acc::HostDataOp::verify() {
1587  if (getDataClauseOperands().empty())
1588  return emitError("at least one operand must appear on the host_data "
1589  "operation");
1590 
1591  for (mlir::Value operand : getDataClauseOperands())
1592  if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
1593  return emitError("expect data entry operation as defining op");
1594  return success();
1595 }
1596 
1597 void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
1598  MLIRContext *context) {
1599  results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
1600 }
1601 
1602 //===----------------------------------------------------------------------===//
1603 // LoopOp
1604 //===----------------------------------------------------------------------===//
1605 
1606 static ParseResult parseGangValue(
1607  OpAsmParser &parser, llvm::StringRef keyword,
1610  llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
1611  bool &needCommaBetweenValues, bool &newValue) {
1612  if (succeeded(parser.parseOptionalKeyword(keyword))) {
1613  if (parser.parseEqual())
1614  return failure();
1615  if (parser.parseOperand(operands.emplace_back()) ||
1616  parser.parseColonType(types.emplace_back()))
1617  return failure();
1618  attributes.push_back(gangArgType);
1619  needCommaBetweenValues = true;
1620  newValue = true;
1621  }
1622  return success();
1623 }
1624 
1625 static ParseResult parseGangClause(
1626  OpAsmParser &parser,
1628  llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
1629  mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
1630  mlir::ArrayAttr &gangOnlyDeviceType) {
1631  llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
1632  llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
1633  llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
1635  bool needCommaBetweenValues = false;
1636  bool needCommaBeforeOperands = false;
1637 
1638  if (failed(parser.parseOptionalLParen())) {
1639  // Gang only keyword
1640  gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1642  gangOnlyDeviceType =
1643  ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
1644  return success();
1645  }
1646 
1647  // Parse gang only attributes
1648  if (succeeded(parser.parseOptionalLSquare())) {
1649  // Parse gang only attributes
1650  if (failed(parser.parseCommaSeparatedList([&]() {
1651  if (parser.parseAttribute(
1652  gangOnlyDeviceTypeAttributes.emplace_back()))
1653  return failure();
1654  return success();
1655  })))
1656  return failure();
1657  if (parser.parseRSquare())
1658  return failure();
1659  needCommaBeforeOperands = true;
1660  }
1661 
1662  auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
1663  mlir::acc::GangArgType::Num);
1664  auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
1665  mlir::acc::GangArgType::Dim);
1666  auto argStatic = mlir::acc::GangArgTypeAttr::get(
1667  parser.getContext(), mlir::acc::GangArgType::Static);
1668 
1669  do {
1670  if (needCommaBeforeOperands) {
1671  needCommaBeforeOperands = false;
1672  continue;
1673  }
1674 
1675  if (failed(parser.parseLBrace()))
1676  return failure();
1677 
1678  int32_t crtOperandsSize = gangOperands.size();
1679  while (true) {
1680  bool newValue = false;
1681  bool needValue = false;
1682  if (needCommaBetweenValues) {
1683  if (succeeded(parser.parseOptionalComma()))
1684  needValue = true; // expect a new value after comma.
1685  else
1686  break;
1687  }
1688 
1689  if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
1690  gangOperands, gangOperandsType,
1691  gangArgTypeAttributes, argNum,
1692  needCommaBetweenValues, newValue)))
1693  return failure();
1694  if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
1695  gangOperands, gangOperandsType,
1696  gangArgTypeAttributes, argDim,
1697  needCommaBetweenValues, newValue)))
1698  return failure();
1699  if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
1700  gangOperands, gangOperandsType,
1701  gangArgTypeAttributes, argStatic,
1702  needCommaBetweenValues, newValue)))
1703  return failure();
1704 
1705  if (!newValue && needValue) {
1706  parser.emitError(parser.getCurrentLocation(),
1707  "new value expected after comma");
1708  return failure();
1709  }
1710 
1711  if (!newValue)
1712  break;
1713  }
1714 
1715  if (gangOperands.empty())
1716  return parser.emitError(
1717  parser.getCurrentLocation(),
1718  "expect at least one of num, dim or static values");
1719 
1720  if (failed(parser.parseRBrace()))
1721  return failure();
1722 
1723  if (succeeded(parser.parseOptionalLSquare())) {
1724  if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
1725  parser.parseRSquare())
1726  return failure();
1727  } else {
1728  deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1730  }
1731 
1732  seg.push_back(gangOperands.size() - crtOperandsSize);
1733 
1734  } while (succeeded(parser.parseOptionalComma()));
1735 
1736  if (failed(parser.parseRParen()))
1737  return failure();
1738 
1739  llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
1740  gangArgTypeAttributes.end());
1741  gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
1742  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
1743 
1745  gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
1746  gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
1747 
1748  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1749  return success();
1750 }
1751 
1753  mlir::OperandRange operands, mlir::TypeRange types,
1754  std::optional<mlir::ArrayAttr> gangArgTypes,
1755  std::optional<mlir::ArrayAttr> deviceTypes,
1756  std::optional<mlir::DenseI32ArrayAttr> segments,
1757  std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
1758 
1759  if (operands.begin() == operands.end() &&
1760  hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
1761  return;
1762  }
1763 
1764  p << "(";
1765 
1766  printDeviceTypes(p, gangOnlyDeviceTypes);
1767 
1768  if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
1769  hasDeviceTypeValues(deviceTypes))
1770  p << ", ";
1771 
1772  if (hasDeviceTypeValues(deviceTypes)) {
1773  unsigned opIdx = 0;
1774  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1775  p << "{";
1776  llvm::interleaveComma(
1777  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1778  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
1779  (*gangArgTypes)[opIdx]);
1780  if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
1781  p << LoopOp::getGangNumKeyword();
1782  else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
1783  p << LoopOp::getGangDimKeyword();
1784  else if (gangArgTypeAttr.getValue() ==
1785  mlir::acc::GangArgType::Static)
1786  p << LoopOp::getGangStaticKeyword();
1787  p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
1788  ++opIdx;
1789  });
1790  p << "}";
1791  printSingleDeviceType(p, it.value());
1792  });
1793  }
1794  p << ")";
1795 }
1796 
1798  std::optional<mlir::ArrayAttr> segments,
1799  llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
1800  if (!segments)
1801  return false;
1802  for (auto attr : *segments) {
1803  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1804  if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
1805  return true;
1806  }
1807  return false;
1808 }
1809 
1810 /// Check for duplicates in the DeviceType array attribute.
1811 LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
1812  llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
1813  if (!deviceTypes)
1814  return success();
1815  for (auto attr : deviceTypes) {
1816  auto deviceTypeAttr =
1817  mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
1818  if (!deviceTypeAttr)
1819  return failure();
1820  if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
1821  return failure();
1822  }
1823  return success();
1824 }
1825 
1826 LogicalResult acc::LoopOp::verify() {
1827  if (!getUpperbound().empty() && getInclusiveUpperbound() &&
1828  (getUpperbound().size() != getInclusiveUpperbound()->size()))
1829  return emitError() << "inclusiveUpperbound size is expected to be the same"
1830  << " as upperbound size";
1831 
1832  // Check collapse
1833  if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
1834  return emitOpError() << "collapse device_type attr must be define when"
1835  << " collapse attr is present";
1836 
1837  if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
1838  getCollapseAttr().getValue().size() !=
1839  getCollapseDeviceTypeAttr().getValue().size())
1840  return emitOpError() << "collapse attribute count must match collapse"
1841  << " device_type count";
1842  if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
1843  return emitOpError()
1844  << "duplicate device_type found in collapseDeviceType attribute";
1845 
1846  // Check gang
1847  if (!getGangOperands().empty()) {
1848  if (!getGangOperandsArgType())
1849  return emitOpError() << "gangOperandsArgType attribute must be defined"
1850  << " when gang operands are present";
1851 
1852  if (getGangOperands().size() !=
1853  getGangOperandsArgTypeAttr().getValue().size())
1854  return emitOpError() << "gangOperandsArgType attribute count must match"
1855  << " gangOperands count";
1856  }
1857  if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
1858  return emitOpError() << "duplicate device_type found in gang attribute";
1859 
1861  *this, getGangOperands(), getGangOperandsSegmentsAttr(),
1862  getGangOperandsDeviceTypeAttr(), "gang")))
1863  return failure();
1864 
1865  // Check worker
1866  if (failed(checkDeviceTypes(getWorkerAttr())))
1867  return emitOpError() << "duplicate device_type found in worker attribute";
1868  if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
1869  return emitOpError() << "duplicate device_type found in "
1870  "workerNumOperandsDeviceType attribute";
1871  if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
1872  getWorkerNumOperandsDeviceTypeAttr(),
1873  "worker")))
1874  return failure();
1875 
1876  // Check vector
1877  if (failed(checkDeviceTypes(getVectorAttr())))
1878  return emitOpError() << "duplicate device_type found in vector attribute";
1879  if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
1880  return emitOpError() << "duplicate device_type found in "
1881  "vectorOperandsDeviceType attribute";
1882  if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
1883  getVectorOperandsDeviceTypeAttr(),
1884  "vector")))
1885  return failure();
1886 
1888  *this, getTileOperands(), getTileOperandsSegmentsAttr(),
1889  getTileOperandsDeviceTypeAttr(), "tile")))
1890  return failure();
1891 
1892  // auto, independent and seq attribute are mutually exclusive.
1893  llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
1894  if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
1895  hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
1896  hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
1897  return emitError() << "only one of \"" << acc::LoopOp::getAutoAttrStrName()
1898  << "\", " << getIndependentAttrName() << ", "
1899  << getSeqAttrName()
1900  << " can be present at the same time";
1901  }
1902 
1903  // Gang, worker and vector are incompatible with seq.
1904  if (getSeqAttr()) {
1905  for (auto attr : getSeqAttr()) {
1906  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1907  if (hasVector(deviceTypeAttr.getValue()) ||
1908  getVectorValue(deviceTypeAttr.getValue()) ||
1909  hasWorker(deviceTypeAttr.getValue()) ||
1910  getWorkerValue(deviceTypeAttr.getValue()) ||
1911  hasGang(deviceTypeAttr.getValue()) ||
1912  getGangValue(mlir::acc::GangArgType::Num,
1913  deviceTypeAttr.getValue()) ||
1914  getGangValue(mlir::acc::GangArgType::Dim,
1915  deviceTypeAttr.getValue()) ||
1916  getGangValue(mlir::acc::GangArgType::Static,
1917  deviceTypeAttr.getValue()))
1918  return emitError()
1919  << "gang, worker or vector cannot appear with the seq attr";
1920  }
1921  }
1922 
1923  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1924  *this, getPrivatizations(), getPrivateOperands(), "private",
1925  "privatizations", false)))
1926  return failure();
1927 
1928  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1929  *this, getReductionRecipes(), getReductionOperands(), "reduction",
1930  "reductions", false)))
1931  return failure();
1932 
1933  if (getCombined().has_value() &&
1934  (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
1935  getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
1936  getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
1937  return emitError("unexpected combined constructs attribute");
1938  }
1939 
1940  // Check non-empty body().
1941  if (getRegion().empty())
1942  return emitError("expected non-empty body.");
1943 
1944  return success();
1945 }
1946 
1947 unsigned LoopOp::getNumDataOperands() {
1948  return getReductionOperands().size() + getPrivateOperands().size();
1949 }
1950 
1951 Value LoopOp::getDataOperand(unsigned i) {
1952  unsigned numOptional =
1953  getLowerbound().size() + getUpperbound().size() + getStep().size();
1954  numOptional += getGangOperands().size();
1955  numOptional += getVectorOperands().size();
1956  numOptional += getWorkerNumOperands().size();
1957  numOptional += getTileOperands().size();
1958  numOptional += getCacheOperands().size();
1959  return getOperand(numOptional + i);
1960 }
1961 
1962 bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
1963 
1964 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
1965  return hasDeviceType(getAuto_(), deviceType);
1966 }
1967 
1968 bool LoopOp::hasIndependent() {
1969  return hasIndependent(mlir::acc::DeviceType::None);
1970 }
1971 
1972 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
1973  return hasDeviceType(getIndependent(), deviceType);
1974 }
1975 
1976 bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
1977 
1978 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
1979  return hasDeviceType(getSeq(), deviceType);
1980 }
1981 
1982 mlir::Value LoopOp::getVectorValue() {
1983  return getVectorValue(mlir::acc::DeviceType::None);
1984 }
1985 
1986 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
1987  return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
1988  getVectorOperands(), deviceType);
1989 }
1990 
1991 bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
1992 
1993 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
1994  return hasDeviceType(getVector(), deviceType);
1995 }
1996 
1997 mlir::Value LoopOp::getWorkerValue() {
1998  return getWorkerValue(mlir::acc::DeviceType::None);
1999 }
2000 
2001 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
2002  return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
2003  getWorkerNumOperands(), deviceType);
2004 }
2005 
2006 bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2007 
2008 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2009  return hasDeviceType(getWorker(), deviceType);
2010 }
2011 
2012 mlir::Operation::operand_range LoopOp::getTileValues() {
2013  return getTileValues(mlir::acc::DeviceType::None);
2014 }
2015 
2017 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
2018  return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
2019  getTileOperandsSegments(), deviceType);
2020 }
2021 
2022 std::optional<int64_t> LoopOp::getCollapseValue() {
2023  return getCollapseValue(mlir::acc::DeviceType::None);
2024 }
2025 
2026 std::optional<int64_t>
2027 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2028  if (!getCollapseAttr())
2029  return std::nullopt;
2030  if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2031  auto intAttr =
2032  mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2033  return intAttr.getValue().getZExtValue();
2034  }
2035  return std::nullopt;
2036 }
2037 
2038 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2039  return getGangValue(gangArgType, mlir::acc::DeviceType::None);
2040 }
2041 
2042 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2043  mlir::acc::DeviceType deviceType) {
2044  if (getGangOperands().empty())
2045  return {};
2046  if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
2047  int32_t nbOperandsBefore = 0;
2048  for (unsigned i = 0; i < *pos; ++i)
2049  nbOperandsBefore += (*getGangOperandsSegments())[i];
2051  getGangOperands()
2052  .drop_front(nbOperandsBefore)
2053  .take_front((*getGangOperandsSegments())[*pos]);
2054 
2055  int32_t argTypeIdx = nbOperandsBefore;
2056  for (auto value : values) {
2057  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2058  (*getGangOperandsArgType())[argTypeIdx]);
2059  if (gangArgTypeAttr.getValue() == gangArgType)
2060  return value;
2061  ++argTypeIdx;
2062  }
2063  }
2064  return {};
2065 }
2066 
2067 bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2068 
2069 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2070  return hasDeviceType(getGang(), deviceType);
2071 }
2072 
2073 llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
2074  return {&getRegion()};
2075 }
2076 
2077 /// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
2078 /// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
2079 /// `(` ssa-id-and-type-list `)`
2080 /// region
2081 ParseResult
2084  SmallVectorImpl<Type> &lowerboundType,
2086  SmallVectorImpl<Type> &upperboundType,
2088  SmallVectorImpl<Type> &stepType) {
2089 
2090  SmallVector<OpAsmParser::Argument> inductionVars;
2091  if (succeeded(
2092  parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
2093  if (parser.parseLParen() ||
2094  parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
2095  /*allowType=*/true) ||
2096  parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
2097  parser.parseOperandList(lowerbound, inductionVars.size(),
2099  parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
2100  parser.parseKeyword("to") || parser.parseLParen() ||
2101  parser.parseOperandList(upperbound, inductionVars.size(),
2103  parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
2104  parser.parseKeyword("step") || parser.parseLParen() ||
2105  parser.parseOperandList(step, inductionVars.size(),
2107  parser.parseColonTypeList(stepType) || parser.parseRParen())
2108  return failure();
2109  }
2110  return parser.parseRegion(region, inductionVars);
2111 }
2112 
2114  ValueRange lowerbound, TypeRange lowerboundType,
2115  ValueRange upperbound, TypeRange upperboundType,
2116  ValueRange steps, TypeRange stepType) {
2117  ValueRange regionArgs = region.front().getArguments();
2118  if (!regionArgs.empty()) {
2119  p << acc::LoopOp::getControlKeyword() << "(";
2120  llvm::interleaveComma(regionArgs, p,
2121  [&p](Value v) { p << v << " : " << v.getType(); });
2122  p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
2123  << upperbound << " : " << upperboundType << ") " << " step (" << steps
2124  << " : " << stepType << ") ";
2125  }
2126  p.printRegion(region, /*printEntryBlockArgs=*/false);
2127 }
2128 
2129 //===----------------------------------------------------------------------===//
2130 // DataOp
2131 //===----------------------------------------------------------------------===//
2132 
2133 LogicalResult acc::DataOp::verify() {
2134  // 2.6.5. Data Construct restriction
2135  // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
2136  // attach, or default clause must appear on a data construct.
2137  if (getOperands().empty() && !getDefaultAttr())
2138  return emitError("at least one operand or the default attribute "
2139  "must appear on the data operation");
2140 
2141  for (mlir::Value operand : getDataClauseOperands())
2142  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2143  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2144  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2145  operand.getDefiningOp()))
2146  return emitError("expect data entry/exit operation or acc.getdeviceptr "
2147  "as defining op");
2148 
2149  if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this)))
2150  return failure();
2151 
2152  return success();
2153 }
2154 
2155 unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
2156 
2157 Value DataOp::getDataOperand(unsigned i) {
2158  unsigned numOptional = getIfCond() ? 1 : 0;
2159  numOptional += getAsyncOperands().size() ? 1 : 0;
2160  numOptional += getWaitOperands().size();
2161  return getOperand(numOptional + i);
2162 }
2163 
2164 bool acc::DataOp::hasAsyncOnly() {
2165  return hasAsyncOnly(mlir::acc::DeviceType::None);
2166 }
2167 
2168 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2169  return hasDeviceType(getAsyncOnly(), deviceType);
2170 }
2171 
2172 mlir::Value DataOp::getAsyncValue() {
2173  return getAsyncValue(mlir::acc::DeviceType::None);
2174 }
2175 
2176 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2178  getAsyncOperands(), deviceType);
2179 }
2180 
2181 bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
2182 
2183 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2184  return hasDeviceType(getWaitOnly(), deviceType);
2185 }
2186 
2187 mlir::Operation::operand_range DataOp::getWaitValues() {
2188  return getWaitValues(mlir::acc::DeviceType::None);
2189 }
2190 
2192 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2194  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2195  getHasWaitDevnum(), deviceType);
2196 }
2197 
2198 mlir::Value DataOp::getWaitDevnum() {
2199  return getWaitDevnum(mlir::acc::DeviceType::None);
2200 }
2201 
2202 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2203  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2204  getWaitOperandsSegments(), getHasWaitDevnum(),
2205  deviceType);
2206 }
2207 
2208 //===----------------------------------------------------------------------===//
2209 // ExitDataOp
2210 //===----------------------------------------------------------------------===//
2211 
2212 LogicalResult acc::ExitDataOp::verify() {
2213  // 2.6.6. Data Exit Directive restriction
2214  // At least one copyout, delete, or detach clause must appear on an exit data
2215  // directive.
2216  if (getDataClauseOperands().empty())
2217  return emitError("at least one operand must be present in dataOperands on "
2218  "the exit data operation");
2219 
2220  // The async attribute represent the async clause without value. Therefore the
2221  // attribute and operand cannot appear at the same time.
2222  if (getAsyncOperand() && getAsync())
2223  return emitError("async attribute cannot appear with asyncOperand");
2224 
2225  // The wait attribute represent the wait clause without values. Therefore the
2226  // attribute and operands cannot appear at the same time.
2227  if (!getWaitOperands().empty() && getWait())
2228  return emitError("wait attribute cannot appear with waitOperands");
2229 
2230  if (getWaitDevnum() && getWaitOperands().empty())
2231  return emitError("wait_devnum cannot appear without waitOperands");
2232 
2233  return success();
2234 }
2235 
2236 unsigned ExitDataOp::getNumDataOperands() {
2237  return getDataClauseOperands().size();
2238 }
2239 
2240 Value ExitDataOp::getDataOperand(unsigned i) {
2241  unsigned numOptional = getIfCond() ? 1 : 0;
2242  numOptional += getAsyncOperand() ? 1 : 0;
2243  numOptional += getWaitDevnum() ? 1 : 0;
2244  return getOperand(getWaitOperands().size() + numOptional + i);
2245 }
2246 
2247 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2248  MLIRContext *context) {
2249  results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
2250 }
2251 
2252 //===----------------------------------------------------------------------===//
2253 // EnterDataOp
2254 //===----------------------------------------------------------------------===//
2255 
2256 LogicalResult acc::EnterDataOp::verify() {
2257  // 2.6.6. Data Enter Directive restriction
2258  // At least one copyin, create, or attach clause must appear on an enter data
2259  // directive.
2260  if (getDataClauseOperands().empty())
2261  return emitError("at least one operand must be present in dataOperands on "
2262  "the enter data operation");
2263 
2264  // The async attribute represent the async clause without value. Therefore the
2265  // attribute and operand cannot appear at the same time.
2266  if (getAsyncOperand() && getAsync())
2267  return emitError("async attribute cannot appear with asyncOperand");
2268 
2269  // The wait attribute represent the wait clause without values. Therefore the
2270  // attribute and operands cannot appear at the same time.
2271  if (!getWaitOperands().empty() && getWait())
2272  return emitError("wait attribute cannot appear with waitOperands");
2273 
2274  if (getWaitDevnum() && getWaitOperands().empty())
2275  return emitError("wait_devnum cannot appear without waitOperands");
2276 
2277  for (mlir::Value operand : getDataClauseOperands())
2278  if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
2279  operand.getDefiningOp()))
2280  return emitError("expect data entry operation as defining op");
2281 
2282  return success();
2283 }
2284 
2285 unsigned EnterDataOp::getNumDataOperands() {
2286  return getDataClauseOperands().size();
2287 }
2288 
2289 Value EnterDataOp::getDataOperand(unsigned i) {
2290  unsigned numOptional = getIfCond() ? 1 : 0;
2291  numOptional += getAsyncOperand() ? 1 : 0;
2292  numOptional += getWaitDevnum() ? 1 : 0;
2293  return getOperand(getWaitOperands().size() + numOptional + i);
2294 }
2295 
2296 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2297  MLIRContext *context) {
2298  results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
2299 }
2300 
2301 //===----------------------------------------------------------------------===//
2302 // AtomicReadOp
2303 //===----------------------------------------------------------------------===//
2304 
2305 LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
2306 
2307 //===----------------------------------------------------------------------===//
2308 // AtomicWriteOp
2309 //===----------------------------------------------------------------------===//
2310 
2311 LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
2312 
2313 //===----------------------------------------------------------------------===//
2314 // AtomicUpdateOp
2315 //===----------------------------------------------------------------------===//
2316 
2317 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2318  PatternRewriter &rewriter) {
2319  if (op.isNoOp()) {
2320  rewriter.eraseOp(op);
2321  return success();
2322  }
2323 
2324  if (Value writeVal = op.getWriteOpVal()) {
2325  rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
2326  return success();
2327  }
2328 
2329  return failure();
2330 }
2331 
2332 LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
2333 
2334 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
2335 
2336 //===----------------------------------------------------------------------===//
2337 // AtomicCaptureOp
2338 //===----------------------------------------------------------------------===//
2339 
2340 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2341  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2342  return op;
2343  return dyn_cast<AtomicReadOp>(getSecondOp());
2344 }
2345 
2346 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2347  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2348  return op;
2349  return dyn_cast<AtomicWriteOp>(getSecondOp());
2350 }
2351 
2352 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2353  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2354  return op;
2355  return dyn_cast<AtomicUpdateOp>(getSecondOp());
2356 }
2357 
2358 LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
2359 
2360 //===----------------------------------------------------------------------===//
2361 // DeclareEnterOp
2362 //===----------------------------------------------------------------------===//
2363 
2364 template <typename Op>
2365 static LogicalResult
2367  bool requireAtLeastOneOperand = true) {
2368  if (operands.empty() && requireAtLeastOneOperand)
2369  return emitError(
2370  op->getLoc(),
2371  "at least one operand must appear on the declare operation");
2372 
2373  for (mlir::Value operand : operands) {
2374  if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2375  acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
2376  acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
2377  operand.getDefiningOp()))
2378  return op.emitError(
2379  "expect valid declare data entry operation or acc.getdeviceptr "
2380  "as defining op");
2381 
2382  mlir::Value varPtr{getVarPtr(operand.getDefiningOp())};
2383  assert(varPtr && "declare operands can only be data entry operations which "
2384  "must have varPtr");
2385  std::optional<mlir::acc::DataClause> dataClauseOptional{
2386  getDataClause(operand.getDefiningOp())};
2387  assert(dataClauseOptional.has_value() &&
2388  "declare operands can only be data entry operations which must have "
2389  "dataClause");
2390 
2391  // If varPtr has no defining op - there is nothing to check further.
2392  if (!varPtr.getDefiningOp())
2393  continue;
2394 
2395  // Check that the varPtr has a declare attribute.
2396  auto declareAttribute{
2397  varPtr.getDefiningOp()->getAttr(mlir::acc::getDeclareAttrName())};
2398  if (!declareAttribute)
2399  return op.emitError(
2400  "expect declare attribute on variable in declare operation");
2401 
2402  auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute);
2403  if (declAttr.getDataClause().getValue() != dataClauseOptional.value())
2404  return op.emitError(
2405  "expect matching declare attribute on variable in declare operation");
2406 
2407  // If the variable is marked with implicit attribute, the matching declare
2408  // data action must also be marked implicit. The reverse is not checked
2409  // since implicit data action may be inserted to do actions like updating
2410  // device copy, in which case the variable is not necessarily implicitly
2411  // declare'd.
2412  if (declAttr.getImplicit() &&
2413  declAttr.getImplicit() != acc::getImplicitFlag(operand.getDefiningOp()))
2414  return op.emitError(
2415  "implicitness must match between declare op and flag on variable");
2416  }
2417 
2418  return success();
2419 }
2420 
2421 LogicalResult acc::DeclareEnterOp::verify() {
2422  return checkDeclareOperands(*this, this->getDataClauseOperands());
2423 }
2424 
2425 //===----------------------------------------------------------------------===//
2426 // DeclareExitOp
2427 //===----------------------------------------------------------------------===//
2428 
2429 LogicalResult acc::DeclareExitOp::verify() {
2430  if (getToken())
2431  return checkDeclareOperands(*this, this->getDataClauseOperands(),
2432  /*requireAtLeastOneOperand=*/false);
2433  return checkDeclareOperands(*this, this->getDataClauseOperands());
2434 }
2435 
2436 //===----------------------------------------------------------------------===//
2437 // DeclareOp
2438 //===----------------------------------------------------------------------===//
2439 
2440 LogicalResult acc::DeclareOp::verify() {
2441  return checkDeclareOperands(*this, this->getDataClauseOperands());
2442 }
2443 
2444 //===----------------------------------------------------------------------===//
2445 // RoutineOp
2446 //===----------------------------------------------------------------------===//
2447 
2448 static unsigned getParallelismForDeviceType(acc::RoutineOp op,
2449  acc::DeviceType dtype) {
2450  unsigned parallelism = 0;
2451  parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
2452  parallelism += op.hasWorker(dtype) ? 1 : 0;
2453  parallelism += op.hasVector(dtype) ? 1 : 0;
2454  parallelism += op.hasSeq(dtype) ? 1 : 0;
2455  return parallelism;
2456 }
2457 
2458 LogicalResult acc::RoutineOp::verify() {
2459  unsigned baseParallelism =
2461 
2462  if (baseParallelism > 1)
2463  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2464  "be present at the same time";
2465 
2466  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
2467  ++dtypeInt) {
2468  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
2469  if (dtype == acc::DeviceType::None)
2470  continue;
2471  unsigned parallelism = getParallelismForDeviceType(*this, dtype);
2472 
2473  if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
2474  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2475  "be present at the same time";
2476  }
2477 
2478  return success();
2479 }
2480 
2481 static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
2482  mlir::ArrayAttr &deviceTypes) {
2483  llvm::SmallVector<mlir::Attribute> bindNameAttrs;
2484  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
2485 
2486  if (failed(parser.parseCommaSeparatedList([&]() {
2487  if (parser.parseAttribute(bindNameAttrs.emplace_back()))
2488  return failure();
2489  if (failed(parser.parseOptionalLSquare())) {
2490  deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2491  parser.getContext(), mlir::acc::DeviceType::None));
2492  } else {
2493  if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2494  parser.parseRSquare())
2495  return failure();
2496  }
2497  return success();
2498  })))
2499  return failure();
2500 
2501  bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
2502  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2503 
2504  return success();
2505 }
2506 
2508  std::optional<mlir::ArrayAttr> bindName,
2509  std::optional<mlir::ArrayAttr> deviceTypes) {
2510  llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
2511  [&](const auto &pair) {
2512  p << std::get<0>(pair);
2513  printSingleDeviceType(p, std::get<1>(pair));
2514  });
2515 }
2516 
2517 static ParseResult parseRoutineGangClause(OpAsmParser &parser,
2518  mlir::ArrayAttr &gang,
2519  mlir::ArrayAttr &gangDim,
2520  mlir::ArrayAttr &gangDimDeviceTypes) {
2521 
2522  llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
2523  gangDimDeviceTypeAttrs;
2524  bool needCommaBeforeOperands = false;
2525 
2526  // Gang keyword only
2527  if (failed(parser.parseOptionalLParen())) {
2528  gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2530  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2531  return success();
2532  }
2533 
2534  // Parse keyword only attributes
2535  if (succeeded(parser.parseOptionalLSquare())) {
2536  if (failed(parser.parseCommaSeparatedList([&]() {
2537  if (parser.parseAttribute(gangAttrs.emplace_back()))
2538  return failure();
2539  return success();
2540  })))
2541  return failure();
2542  if (parser.parseRSquare())
2543  return failure();
2544  needCommaBeforeOperands = true;
2545  }
2546 
2547  if (needCommaBeforeOperands && failed(parser.parseComma()))
2548  return failure();
2549 
2550  if (failed(parser.parseCommaSeparatedList([&]() {
2551  if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
2552  parser.parseColon() ||
2553  parser.parseAttribute(gangDimAttrs.emplace_back()))
2554  return failure();
2555  if (succeeded(parser.parseOptionalLSquare())) {
2556  if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
2557  parser.parseRSquare())
2558  return failure();
2559  } else {
2560  gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2561  parser.getContext(), mlir::acc::DeviceType::None));
2562  }
2563  return success();
2564  })))
2565  return failure();
2566 
2567  if (failed(parser.parseRParen()))
2568  return failure();
2569 
2570  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2571  gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
2572  gangDimDeviceTypes =
2573  ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
2574 
2575  return success();
2576 }
2577 
2579  std::optional<mlir::ArrayAttr> gang,
2580  std::optional<mlir::ArrayAttr> gangDim,
2581  std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
2582 
2583  if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
2584  gang->size() == 1) {
2585  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
2586  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
2587  return;
2588  }
2589 
2590  p << "(";
2591 
2592  printDeviceTypes(p, gang);
2593 
2594  if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
2595  p << ", ";
2596 
2597  if (hasDeviceTypeValues(gangDimDeviceTypes))
2598  llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
2599  [&](const auto &pair) {
2600  p << acc::RoutineOp::getGangDimKeyword() << ": ";
2601  p << std::get<0>(pair);
2602  printSingleDeviceType(p, std::get<1>(pair));
2603  });
2604 
2605  p << ")";
2606 }
2607 
2608 static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
2609  mlir::ArrayAttr &deviceTypes) {
2611  // Keyword only
2612  if (failed(parser.parseOptionalLParen())) {
2613  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2615  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
2616  return success();
2617  }
2618 
2619  // Parse device type attributes
2620  if (succeeded(parser.parseOptionalLSquare())) {
2621  if (failed(parser.parseCommaSeparatedList([&]() {
2622  if (parser.parseAttribute(attributes.emplace_back()))
2623  return failure();
2624  return success();
2625  })))
2626  return failure();
2627  if (parser.parseRSquare() || parser.parseRParen())
2628  return failure();
2629  }
2630  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
2631  return success();
2632 }
2633 
2634 static void
2636  std::optional<mlir::ArrayAttr> deviceTypes) {
2637 
2638  if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
2639  auto deviceTypeAttr =
2640  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
2641  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
2642  return;
2643  }
2644 
2645  if (!hasDeviceTypeValues(deviceTypes))
2646  return;
2647 
2648  p << "([";
2649  llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
2650  auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2651  p << dTypeAttr;
2652  });
2653  p << "])";
2654 }
2655 
2656 bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2657 
2658 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
2659  return hasDeviceType(getWorker(), deviceType);
2660 }
2661 
2662 bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
2663 
2664 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
2665  return hasDeviceType(getVector(), deviceType);
2666 }
2667 
2668 bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
2669 
2670 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
2671  return hasDeviceType(getSeq(), deviceType);
2672 }
2673 
2674 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
2675  return getBindNameValue(mlir::acc::DeviceType::None);
2676 }
2677 
2678 std::optional<llvm::StringRef>
2679 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
2680  if (!hasDeviceTypeValues(getBindNameDeviceType()))
2681  return std::nullopt;
2682  if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
2683  auto attr = (*getBindName())[*pos];
2684  auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
2685  return stringAttr.getValue();
2686  }
2687  return std::nullopt;
2688 }
2689 
2690 bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2691 
2692 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
2693  return hasDeviceType(getGang(), deviceType);
2694 }
2695 
2696 std::optional<int64_t> RoutineOp::getGangDimValue() {
2697  return getGangDimValue(mlir::acc::DeviceType::None);
2698 }
2699 
2700 std::optional<int64_t>
2701 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
2702  if (!hasDeviceTypeValues(getGangDimDeviceType()))
2703  return std::nullopt;
2704  if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
2705  auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
2706  return intAttr.getInt();
2707  }
2708  return std::nullopt;
2709 }
2710 
2711 //===----------------------------------------------------------------------===//
2712 // InitOp
2713 //===----------------------------------------------------------------------===//
2714 
2715 LogicalResult acc::InitOp::verify() {
2716  Operation *currOp = *this;
2717  while ((currOp = currOp->getParentOp()))
2718  if (isComputeOperation(currOp))
2719  return emitOpError("cannot be nested in a compute operation");
2720  return success();
2721 }
2722 
2723 //===----------------------------------------------------------------------===//
2724 // ShutdownOp
2725 //===----------------------------------------------------------------------===//
2726 
2727 LogicalResult acc::ShutdownOp::verify() {
2728  Operation *currOp = *this;
2729  while ((currOp = currOp->getParentOp()))
2730  if (isComputeOperation(currOp))
2731  return emitOpError("cannot be nested in a compute operation");
2732  return success();
2733 }
2734 
2735 //===----------------------------------------------------------------------===//
2736 // SetOp
2737 //===----------------------------------------------------------------------===//
2738 
2739 LogicalResult acc::SetOp::verify() {
2740  Operation *currOp = *this;
2741  while ((currOp = currOp->getParentOp()))
2742  if (isComputeOperation(currOp))
2743  return emitOpError("cannot be nested in a compute operation");
2744  if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
2745  return emitOpError("at least one default_async, device_num, or device_type "
2746  "operand must appear");
2747  return success();
2748 }
2749 
2750 //===----------------------------------------------------------------------===//
2751 // UpdateOp
2752 //===----------------------------------------------------------------------===//
2753 
2754 LogicalResult acc::UpdateOp::verify() {
2755  // At least one of host or device should have a value.
2756  if (getDataClauseOperands().empty())
2757  return emitError("at least one value must be present in dataOperands");
2758 
2759  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
2760  getAsyncOperandsDeviceTypeAttr(),
2761  "async")))
2762  return failure();
2763 
2765  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2766  getWaitOperandsDeviceTypeAttr(), "wait")))
2767  return failure();
2768 
2769  if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this)))
2770  return failure();
2771 
2772  for (mlir::Value operand : getDataClauseOperands())
2773  if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
2774  operand.getDefiningOp()))
2775  return emitError("expect data entry/exit operation or acc.getdeviceptr "
2776  "as defining op");
2777 
2778  return success();
2779 }
2780 
2781 unsigned UpdateOp::getNumDataOperands() {
2782  return getDataClauseOperands().size();
2783 }
2784 
2785 Value UpdateOp::getDataOperand(unsigned i) {
2786  unsigned numOptional = getAsyncOperands().size();
2787  numOptional += getIfCond() ? 1 : 0;
2788  return getOperand(getWaitOperands().size() + numOptional + i);
2789 }
2790 
2791 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
2792  MLIRContext *context) {
2793  results.add<RemoveConstantIfCondition<UpdateOp>>(context);
2794 }
2795 
2796 bool UpdateOp::hasAsyncOnly() {
2797  return hasAsyncOnly(mlir::acc::DeviceType::None);
2798 }
2799 
2800 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2801  return hasDeviceType(getAsync(), deviceType);
2802 }
2803 
2804 mlir::Value UpdateOp::getAsyncValue() {
2805  return getAsyncValue(mlir::acc::DeviceType::None);
2806 }
2807 
2808 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2810  return {};
2811 
2812  if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
2813  return getAsyncOperands()[*pos];
2814 
2815  return {};
2816 }
2817 
2818 bool UpdateOp::hasWaitOnly() {
2819  return hasWaitOnly(mlir::acc::DeviceType::None);
2820 }
2821 
2822 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2823  return hasDeviceType(getWaitOnly(), deviceType);
2824 }
2825 
2826 mlir::Operation::operand_range UpdateOp::getWaitValues() {
2827  return getWaitValues(mlir::acc::DeviceType::None);
2828 }
2829 
2831 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2833  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2834  getHasWaitDevnum(), deviceType);
2835 }
2836 
2837 mlir::Value UpdateOp::getWaitDevnum() {
2838  return getWaitDevnum(mlir::acc::DeviceType::None);
2839 }
2840 
2841 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2842  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2843  getWaitOperandsSegments(), getHasWaitDevnum(),
2844  deviceType);
2845 }
2846 
2847 //===----------------------------------------------------------------------===//
2848 // WaitOp
2849 //===----------------------------------------------------------------------===//
2850 
2851 LogicalResult acc::WaitOp::verify() {
2852  // The async attribute represent the async clause without value. Therefore the
2853  // attribute and operand cannot appear at the same time.
2854  if (getAsyncOperand() && getAsync())
2855  return emitError("async attribute cannot appear with asyncOperand");
2856 
2857  if (getWaitDevnum() && getWaitOperands().empty())
2858  return emitError("wait_devnum cannot appear without waitOperands");
2859 
2860  return success();
2861 }
2862 
2863 #define GET_OP_CLASSES
2864 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
2865 
2866 #define GET_ATTRDEF_CLASSES
2867 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
2868 
2869 #define GET_TYPEDEF_CLASSES
2870 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
2871 
2872 //===----------------------------------------------------------------------===//
2873 // acc dialect utilities
2874 //===----------------------------------------------------------------------===//
2875 
2877  auto varPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
2878  .Case<ACC_DATA_ENTRY_OPS>(
2879  [&](auto entry) { return entry.getVarPtr(); })
2880  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
2881  [&](auto exit) { return exit.getVarPtr(); })
2882  .Default([&](mlir::Operation *) { return mlir::Value(); })};
2883  return varPtr;
2884 }
2885 
2887  auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
2889  [&](auto dataClause) { return dataClause.getAccPtr(); })
2890  .Default([&](mlir::Operation *) { return mlir::Value(); })};
2891  return accPtr;
2892 }
2893 
2895  auto varPtrPtr{
2897  .Case<ACC_DATA_ENTRY_OPS>(
2898  [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
2899  .Default([&](mlir::Operation *) { return mlir::Value(); })};
2900  return varPtrPtr;
2901 }
2902 
2907  accDataClauseOp)
2908  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
2910  dataClause.getBounds().begin(), dataClause.getBounds().end());
2911  })
2912  .Default([&](mlir::Operation *) {
2914  })};
2915  return bounds;
2916 }
2917 
2921  accDataClauseOp)
2922  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
2924  dataClause.getAsyncOperands().begin(),
2925  dataClause.getAsyncOperands().end());
2926  })
2927  .Default([&](mlir::Operation *) {
2929  });
2930 }
2931 
2932 mlir::ArrayAttr
2935  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
2936  return dataClause.getAsyncOperandsDeviceTypeAttr();
2937  })
2938  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
2939 }
2940 
2941 mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
2944  [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
2945  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
2946 }
2947 
2948 std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
2949  auto name{
2951  .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
2952  .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
2953  return {};
2954  })};
2955  return name;
2956 }
2957 
2958 std::optional<mlir::acc::DataClause>
2960  auto dataClause{
2962  accDataEntryOp)
2963  .Case<ACC_DATA_ENTRY_OPS>(
2964  [&](auto entry) { return entry.getDataClause(); })
2965  .Default([&](mlir::Operation *) { return std::nullopt; })};
2966  return dataClause;
2967 }
2968 
2970  auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
2971  .Case<ACC_DATA_ENTRY_OPS>(
2972  [&](auto entry) { return entry.getImplicit(); })
2973  .Default([&](mlir::Operation *) { return false; })};
2974  return implicit;
2975 }
2976 
2978  auto dataOperands{
2981  [&](auto entry) { return entry.getDataClauseOperands(); })
2982  .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
2983  return dataOperands;
2984 }
2985 
2988  auto dataOperands{
2991  [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
2992  .Default([&](mlir::Operation *) { return nullptr; })};
2993  return dataOperands;
2994 }
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:112
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2247
@ None
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition: OpenACC.cpp:2578
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition: OpenACC.cpp:446
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition: OpenACC.cpp:1797
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition: OpenACC.cpp:748
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition: OpenACC.cpp:1811
static bool isComputeOperation(Operation *op)
Definition: OpenACC.cpp:460
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition: OpenACC.cpp:1155
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:2481
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition: OpenACC.cpp:1166
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition: OpenACC.cpp:1071
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition: OpenACC.cpp:77
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:2635
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition: OpenACC.cpp:1606
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition: OpenACC.cpp:1320
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition: OpenACC.cpp:2366
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:97
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition: OpenACC.cpp:2082
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition: OpenACC.cpp:677
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:1200
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:830
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:121
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:941
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition: OpenACC.cpp:2113
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:2608
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition: OpenACC.cpp:2517
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1054
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:1227
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindName, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:2507
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:1008
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
Definition: OpenACC.cpp:661
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:153
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition: OpenACC.cpp:1625
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition: OpenACC.cpp:536
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition: OpenACC.cpp:985
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:108
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
Definition: OpenACC.cpp:692
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition: OpenACC.cpp:1300
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:83
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition: OpenACC.cpp:1752
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:137
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition: OpenACC.cpp:1238
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition: OpenACC.cpp:173
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition: OpenACC.cpp:758
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition: OpenACC.cpp:2448
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:991
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition: OpenACC.cpp:1346
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
Definition: OpenACC.cpp:641
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition: OpenACC.h:67
#define ACC_DATA_ENTRY_OPS
Definition: OpenACC.h:43
#define ACC_DATA_EXIT_OPS
Definition: OpenACC.h:51
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:215
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:826
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:832
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:125
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
mlir::Value getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtr from a data clause operation.
Definition: OpenACC.cpp:2876
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition: OpenACC.cpp:2959
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition: OpenACC.cpp:2987
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition: OpenACC.cpp:2904
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition: OpenACC.cpp:2977
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition: OpenACC.cpp:2948
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition: OpenACC.cpp:2969
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition: OpenACC.cpp:2919
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition: OpenACC.cpp:2894
mlir::Value getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accPtr from a data clause operation.
Definition: OpenACC.cpp:2886
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:2941
static constexpr StringLiteral getDeclareAttrName()
Used to obtain the attribute name for declare.
Definition: OpenACC.h:140
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:2933
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.