MLIR  20.0.0git
OpenACC.cpp
Go to the documentation of this file.
1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
19 #include "llvm/ADT/SmallSet.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 
22 using namespace mlir;
23 using namespace acc;
24 
25 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
26 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
27 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
28 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
29 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
30 
31 namespace {
32 struct MemRefPointerLikeModel
33  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
34  MemRefType> {
35  Type getElementType(Type pointer) const {
36  return llvm::cast<MemRefType>(pointer).getElementType();
37  }
38 };
39 
40 struct LLVMPointerPointerLikeModel
41  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
42  LLVM::LLVMPointerType> {
43  Type getElementType(Type pointer) const { return Type(); }
44 };
45 } // namespace
46 
47 //===----------------------------------------------------------------------===//
48 // OpenACC operations
49 //===----------------------------------------------------------------------===//
50 
51 void OpenACCDialect::initialize() {
52  addOperations<
53 #define GET_OP_LIST
54 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
55  >();
56  addAttributes<
57 #define GET_ATTRDEF_LIST
58 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
59  >();
60  addTypes<
61 #define GET_TYPEDEF_LIST
62 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
63  >();
64 
65  // By attaching interfaces here, we make the OpenACC dialect dependent on
66  // the other dialects. This is probably better than having dialects like LLVM
67  // and memref be dependent on OpenACC.
68  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
69  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
70  *getContext());
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // device_type support helpers
75 //===----------------------------------------------------------------------===//
76 
77 static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
78  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
79  return true;
80  return false;
81 }
82 
83 static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
84  mlir::acc::DeviceType deviceType) {
85  if (!hasDeviceTypeValues(arrayAttr))
86  return false;
87 
88  for (auto attr : *arrayAttr) {
89  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
90  if (deviceTypeAttr.getValue() == deviceType)
91  return true;
92  }
93 
94  return false;
95 }
96 
98  std::optional<mlir::ArrayAttr> deviceTypes) {
 // Prints the device_type list as `[a, b, ...]` (comma-separated via
 // interleaveComma). Nothing at all is printed when the list is absent, null
 // or empty.
99  if (!hasDeviceTypeValues(deviceTypes))
100  return;
101 
102  p << "[";
103  llvm::interleaveComma(*deviceTypes, p,
104  [&](mlir::Attribute attr) { p << attr; });
105  p << "]";
106 }
107 
108 static std::optional<unsigned> findSegment(ArrayAttr segments,
109  mlir::acc::DeviceType deviceType) {
110  unsigned segmentIdx = 0;
111  for (auto attr : segments) {
112  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
113  if (deviceTypeAttr.getValue() == deviceType)
114  return std::make_optional(segmentIdx);
115  ++segmentIdx;
116  }
117  return std::nullopt;
118 }
119 
/// Returns the sub-range of `range` that belongs to `deviceType`. Operands
/// are laid out segment by segment, one segment per entry of `arrayAttr`,
/// with segment sizes taken from `segments`. Returns an empty range when
/// `arrayAttr` is absent or has no entry for `deviceType`.
121 getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
123  std::optional<llvm::ArrayRef<int32_t>> segments,
124  mlir::acc::DeviceType deviceType) {
125  if (!arrayAttr)
126  return range.take_front(0);
127  if (auto pos = findSegment(*arrayAttr, deviceType)) {
 // Skip the operands of every segment that precedes the matching one.
 // NOTE(review): `segments` is dereferenced without checking that it is
 // engaged; callers presumably guarantee it whenever `arrayAttr` is set.
128  int32_t nbOperandsBefore = 0;
129  for (unsigned i = 0; i < *pos; ++i)
130  nbOperandsBefore += (*segments)[i];
131  return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
132  }
133  return range.take_front(0);
134 }
135 
/// Returns the wait-clause devnum value for `deviceType`, i.e. the first
/// operand of that device type's segment, when `hasWaitDevnum` marks the
/// segment as carrying a devnum. Returns a null Value otherwise.
136 static mlir::Value
137 getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
139  std::optional<llvm::ArrayRef<int32_t>> segments,
140  std::optional<mlir::ArrayAttr> hasWaitDevnum,
141  mlir::acc::DeviceType deviceType) {
142  if (!hasDeviceTypeValues(deviceTypeAttr))
143  return {};
144  if (auto pos = findSegment(*deviceTypeAttr, deviceType))
 // NOTE(review): `hasWaitDevnum->getValue()[*pos]` converts an mlir::Attribute
 // to bool. That tests whether the attribute is non-null; it does not read a
 // BoolAttr's value. getWaitValuesWithoutDevnum instead uses
 // BoolAttr::getValue(). `hasWaitDevnum` is also dereferenced without checking
 // that it is engaged, and `.front()` assumes a non-empty segment. Confirm
 // the intended semantics.
145  if (hasWaitDevnum->getValue()[*pos])
146  return getValuesFromSegments(deviceTypeAttr, operands, segments,
147  deviceType)
148  .front();
149  return {};
150 }
151 
/// Returns the wait operands of `deviceType`'s segment. When `hasWaitDevnum`
/// marks that segment as carrying a devnum, the first value (the devnum) is
/// dropped.
153 getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
155  std::optional<llvm::ArrayRef<int32_t>> segments,
156  std::optional<mlir::ArrayAttr> hasWaitDevnum,
157  mlir::acc::DeviceType deviceType) {
158  auto range =
159  getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
160  if (range.empty())
161  return range;
162  if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
163  if (hasWaitDevnum && *hasWaitDevnum) {
 // NOTE(review): the dyn_cast result is not null-checked before getValue();
 // this assumes every hasWaitDevnum entry is a BoolAttr.
164  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
165  if (boolAttr.getValue())
166  return range.drop_front(1); // first value is devnum
167  }
168  }
169  return range;
170 }
171 
172 template <typename Op>
173 static LogicalResult checkWaitAndAsyncConflict(Op op) {
174  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
175  ++dtypeInt) {
176  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
177 
178  // The async attribute represent the async clause without value. Therefore
179  // the attribute and operand cannot appear at the same time.
180  if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
181  op.hasAsyncOnly(dtype))
182  return op.emitError("async attribute cannot appear with asyncOperand");
183 
184  // The wait attribute represent the wait clause without values. Therefore
185  // the attribute and operands cannot appear at the same time.
186  if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
187  op.hasWaitOnly(dtype))
188  return op.emitError("wait attribute cannot appear with waitOperands");
189  }
190  return success();
191 }
192 
193 //===----------------------------------------------------------------------===//
194 // DataBoundsOp
195 //===----------------------------------------------------------------------===//
196 LogicalResult acc::DataBoundsOp::verify() {
197  auto extent = getExtent();
198  auto upperbound = getUpperbound();
199  if (!extent && !upperbound)
200  return emitError("expected extent or upperbound.");
201  return success();
202 }
203 
204 //===----------------------------------------------------------------------===//
205 // PrivateOp
206 //===----------------------------------------------------------------------===//
207 LogicalResult acc::PrivateOp::verify() {
208  if (getDataClause() != acc::DataClause::acc_private)
209  return emitError(
210  "data clause associated with private operation must match its intent");
211  return success();
212 }
213 
214 //===----------------------------------------------------------------------===//
215 // FirstprivateOp
216 //===----------------------------------------------------------------------===//
217 LogicalResult acc::FirstprivateOp::verify() {
218  if (getDataClause() != acc::DataClause::acc_firstprivate)
219  return emitError("data clause associated with firstprivate operation must "
220  "match its intent");
221  return success();
222 }
223 
224 //===----------------------------------------------------------------------===//
225 // ReductionOp
226 //===----------------------------------------------------------------------===//
227 LogicalResult acc::ReductionOp::verify() {
228  if (getDataClause() != acc::DataClause::acc_reduction)
229  return emitError("data clause associated with reduction operation must "
230  "match its intent");
231  return success();
232 }
233 
234 //===----------------------------------------------------------------------===//
235 // DevicePtrOp
236 //===----------------------------------------------------------------------===//
237 LogicalResult acc::DevicePtrOp::verify() {
238  if (getDataClause() != acc::DataClause::acc_deviceptr)
239  return emitError("data clause associated with deviceptr operation must "
240  "match its intent");
241  return success();
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // PresentOp
246 //===----------------------------------------------------------------------===//
247 LogicalResult acc::PresentOp::verify() {
248  if (getDataClause() != acc::DataClause::acc_present)
249  return emitError(
250  "data clause associated with present operation must match its intent");
251  return success();
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // CopyinOp
256 //===----------------------------------------------------------------------===//
257 LogicalResult acc::CopyinOp::verify() {
258  // Test for all clauses this operation can be decomposed from:
259  if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
260  getDataClause() != acc::DataClause::acc_copyin_readonly &&
261  getDataClause() != acc::DataClause::acc_copy &&
262  getDataClause() != acc::DataClause::acc_reduction)
263  return emitError(
264  "data clause associated with copyin operation must match its intent"
265  " or specify original clause this operation was decomposed from");
266  return success();
267 }
268 
269 bool acc::CopyinOp::isCopyinReadonly() {
270  return getDataClause() == acc::DataClause::acc_copyin_readonly;
271 }
272 
273 //===----------------------------------------------------------------------===//
274 // CreateOp
275 //===----------------------------------------------------------------------===//
276 LogicalResult acc::CreateOp::verify() {
277  // Test for all clauses this operation can be decomposed from:
278  if (getDataClause() != acc::DataClause::acc_create &&
279  getDataClause() != acc::DataClause::acc_create_zero &&
280  getDataClause() != acc::DataClause::acc_copyout &&
281  getDataClause() != acc::DataClause::acc_copyout_zero)
282  return emitError(
283  "data clause associated with create operation must match its intent"
284  " or specify original clause this operation was decomposed from");
285  return success();
286 }
287 
288 bool acc::CreateOp::isCreateZero() {
289  // The zero modifier is encoded in the data clause.
290  return getDataClause() == acc::DataClause::acc_create_zero ||
291  getDataClause() == acc::DataClause::acc_copyout_zero;
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // NoCreateOp
296 //===----------------------------------------------------------------------===//
297 LogicalResult acc::NoCreateOp::verify() {
298  if (getDataClause() != acc::DataClause::acc_no_create)
299  return emitError("data clause associated with no_create operation must "
300  "match its intent");
301  return success();
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // AttachOp
306 //===----------------------------------------------------------------------===//
307 LogicalResult acc::AttachOp::verify() {
308  if (getDataClause() != acc::DataClause::acc_attach)
309  return emitError(
310  "data clause associated with attach operation must match its intent");
311  return success();
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // DeclareDeviceResidentOp
316 //===----------------------------------------------------------------------===//
317 
318 LogicalResult acc::DeclareDeviceResidentOp::verify() {
319  if (getDataClause() != acc::DataClause::acc_declare_device_resident)
320  return emitError("data clause associated with device_resident operation "
321  "must match its intent");
322  return success();
323 }
324 
325 //===----------------------------------------------------------------------===//
326 // DeclareLinkOp
327 //===----------------------------------------------------------------------===//
328 
329 LogicalResult acc::DeclareLinkOp::verify() {
330  if (getDataClause() != acc::DataClause::acc_declare_link)
331  return emitError(
332  "data clause associated with link operation must match its intent");
333  return success();
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // CopyoutOp
338 //===----------------------------------------------------------------------===//
339 LogicalResult acc::CopyoutOp::verify() {
340  // Test for all clauses this operation can be decomposed from:
341  if (getDataClause() != acc::DataClause::acc_copyout &&
342  getDataClause() != acc::DataClause::acc_copyout_zero &&
343  getDataClause() != acc::DataClause::acc_copy &&
344  getDataClause() != acc::DataClause::acc_reduction)
345  return emitError(
346  "data clause associated with copyout operation must match its intent"
347  " or specify original clause this operation was decomposed from");
348  if (!getVarPtr() || !getAccPtr())
349  return emitError("must have both host and device pointers");
350  return success();
351 }
352 
353 bool acc::CopyoutOp::isCopyoutZero() {
354  return getDataClause() == acc::DataClause::acc_copyout_zero;
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // DeleteOp
359 //===----------------------------------------------------------------------===//
360 LogicalResult acc::DeleteOp::verify() {
361  // Test for all clauses this operation can be decomposed from:
362  if (getDataClause() != acc::DataClause::acc_delete &&
363  getDataClause() != acc::DataClause::acc_create &&
364  getDataClause() != acc::DataClause::acc_create_zero &&
365  getDataClause() != acc::DataClause::acc_copyin &&
366  getDataClause() != acc::DataClause::acc_copyin_readonly &&
367  getDataClause() != acc::DataClause::acc_present &&
368  getDataClause() != acc::DataClause::acc_declare_device_resident &&
369  getDataClause() != acc::DataClause::acc_declare_link)
370  return emitError(
371  "data clause associated with delete operation must match its intent"
372  " or specify original clause this operation was decomposed from");
373  if (!getAccPtr())
374  return emitError("must have device pointer");
375  return success();
376 }
377 
378 //===----------------------------------------------------------------------===//
379 // DetachOp
380 //===----------------------------------------------------------------------===//
381 LogicalResult acc::DetachOp::verify() {
382  // Test for all clauses this operation can be decomposed from:
383  if (getDataClause() != acc::DataClause::acc_detach &&
384  getDataClause() != acc::DataClause::acc_attach)
385  return emitError(
386  "data clause associated with detach operation must match its intent"
387  " or specify original clause this operation was decomposed from");
388  if (!getAccPtr())
389  return emitError("must have device pointer");
390  return success();
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // HostOp
395 //===----------------------------------------------------------------------===//
396 LogicalResult acc::UpdateHostOp::verify() {
397  // Test for all clauses this operation can be decomposed from:
398  if (getDataClause() != acc::DataClause::acc_update_host &&
399  getDataClause() != acc::DataClause::acc_update_self)
400  return emitError(
401  "data clause associated with host operation must match its intent"
402  " or specify original clause this operation was decomposed from");
403  if (!getVarPtr() || !getAccPtr())
404  return emitError("must have both host and device pointers");
405  return success();
406 }
407 
408 //===----------------------------------------------------------------------===//
409 // DeviceOp
410 //===----------------------------------------------------------------------===//
411 LogicalResult acc::UpdateDeviceOp::verify() {
412  // Test for all clauses this operation can be decomposed from:
413  if (getDataClause() != acc::DataClause::acc_update_device)
414  return emitError(
415  "data clause associated with device operation must match its intent"
416  " or specify original clause this operation was decomposed from");
417  return success();
418 }
419 
420 //===----------------------------------------------------------------------===//
421 // UseDeviceOp
422 //===----------------------------------------------------------------------===//
423 LogicalResult acc::UseDeviceOp::verify() {
424  // Test for all clauses this operation can be decomposed from:
425  if (getDataClause() != acc::DataClause::acc_use_device)
426  return emitError(
427  "data clause associated with use_device operation must match its intent"
428  " or specify original clause this operation was decomposed from");
429  return success();
430 }
431 
432 //===----------------------------------------------------------------------===//
433 // CacheOp
434 //===----------------------------------------------------------------------===//
435 LogicalResult acc::CacheOp::verify() {
436  // Test for all clauses this operation can be decomposed from:
437  if (getDataClause() != acc::DataClause::acc_cache &&
438  getDataClause() != acc::DataClause::acc_cache_readonly)
439  return emitError(
440  "data clause associated with cache operation must match its intent"
441  " or specify original clause this operation was decomposed from");
442  return success();
443 }
444 
445 template <typename StructureOp>
446 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
447  unsigned nRegions = 1) {
448 
449  SmallVector<Region *, 2> regions;
450  for (unsigned i = 0; i < nRegions; ++i)
451  regions.push_back(state.addRegion());
452 
453  for (Region *region : regions)
454  if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
455  return failure();
456 
457  return success();
458 }
459 
460 static bool isComputeOperation(Operation *op) {
461  return isa<acc::ParallelOp, acc::LoopOp>(op);
462 }
463 
464 namespace {
465 /// Pattern for operations without a region: erases the operation when its
466 /// `ifCond` is a constant false, and drops the `ifCond` operand when it is a
467 /// constant true.
468 template <typename OpTy>
469 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
471 
472  LogicalResult matchAndRewrite(OpTy op,
473  PatternRewriter &rewriter) const override {
474  // Early return if there is no condition.
475  Value ifCond = op.getIfCond();
476  if (!ifCond)
477  return failure();
478 
 // Only conditions that match a constant IntegerAttr are folded; any
 // non-zero value counts as true.
479  IntegerAttr constAttr;
480  if (!matchPattern(ifCond, m_Constant(&constAttr)))
481  return failure();
482  if (constAttr.getInt())
483  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
484  else
485  rewriter.eraseOp(op);
486 
487  return success();
488  }
489 };
490 
491 /// Replaces the given op with the contents of the given single-block region,
492 /// using the operands of the block terminator to replace operation results.
/// `blockArgs` are passed to inlineBlockBefore as replacements for the
/// block's arguments.
/// NOTE(review): the assert message says "single-region block", but the
/// condition actually checks for a single-block region.
493 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
494  Region &region, ValueRange blockArgs = {}) {
495  assert(llvm::hasSingleElement(region) && "expected single-region block");
496  Block *block = &region.front();
497  Operation *terminator = block->getTerminator();
 // The terminator's operands become the op's replacement values; the
 // terminator itself is erased only after the replacement is done.
498  ValueRange results = terminator->getOperands();
499  rewriter.inlineBlockBefore(block, op, blockArgs);
500  rewriter.replaceOp(op, results);
501  rewriter.eraseOp(terminator);
502 }
503 
504 /// Pattern for operations with a region: when `ifCond` is a constant false,
505 /// the operation is replaced by the inlined contents of its region; when it
506 /// is a constant true, the `ifCond` operand is dropped.
507 template <typename OpTy>
508 struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
510 
511  LogicalResult matchAndRewrite(OpTy op,
512  PatternRewriter &rewriter) const override {
513  // Early return if there is no condition.
514  Value ifCond = op.getIfCond();
515  if (!ifCond)
516  return failure();
517 
 // Only conditions that match a constant IntegerAttr are folded; any
 // non-zero value counts as true.
518  IntegerAttr constAttr;
519  if (!matchPattern(ifCond, m_Constant(&constAttr)))
520  return failure();
521  if (constAttr.getInt())
522  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
523  else
524  replaceOpWithRegion(rewriter, op, op.getRegion());
525 
526  return success();
527  }
528 };
529 
530 } // namespace
531 
532 //===----------------------------------------------------------------------===//
533 // PrivateRecipeOp
534 //===----------------------------------------------------------------------===//
535 
536 static LogicalResult verifyInitLikeSingleArgRegion(
537  Operation *op, Region &region, StringRef regionType, StringRef regionName,
538  Type type, bool verifyYield, bool optional = false) {
539  if (optional && region.empty())
540  return success();
541 
542  if (region.empty())
543  return op->emitOpError() << "expects non-empty " << regionName << " region";
544  Block &firstBlock = region.front();
545  if (firstBlock.getNumArguments() < 1 ||
546  firstBlock.getArgument(0).getType() != type)
547  return op->emitOpError() << "expects " << regionName
548  << " region first "
549  "argument of the "
550  << regionType << " type";
551 
552  if (verifyYield) {
553  for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
554  if (yieldOp.getOperands().size() != 1 ||
555  yieldOp.getOperands().getTypes()[0] != type)
556  return op->emitOpError() << "expects " << regionName
557  << " region to "
558  "yield a value of the "
559  << regionType << " type";
560  }
561  }
562  return success();
563 }
564 
/// Verifies the private recipe: the init region must be non-empty and take
/// the privatization type as its first argument. The destroy region is
/// optional but, when present, must satisfy the same argument requirement.
565 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
566  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
567  "privatization", "init", getType(),
568  /*verifyYield=*/false)))
569  return failure();
 // Destroy region check (optional=true accepts an empty region).
571  *this, getDestroyRegion(), "privatization", "destroy", getType(),
572  /*verifyYield=*/false, /*optional=*/true)))
573  return failure();
574  return success();
575 }
576 
577 //===----------------------------------------------------------------------===//
578 // FirstprivateRecipeOp
579 //===----------------------------------------------------------------------===//
580 
581 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
582  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
583  "privatization", "init", getType(),
584  /*verifyYield=*/false)))
585  return failure();
586 
587  if (getCopyRegion().empty())
588  return emitOpError() << "expects non-empty copy region";
589 
590  Block &firstBlock = getCopyRegion().front();
591  if (firstBlock.getNumArguments() < 2 ||
592  firstBlock.getArgument(0).getType() != getType())
593  return emitOpError() << "expects copy region with two arguments of the "
594  "privatization type";
595 
596  if (getDestroyRegion().empty())
597  return success();
598 
599  if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
600  "privatization", "destroy",
601  getType(), /*verifyYield=*/false)))
602  return failure();
603 
604  return success();
605 }
606 
607 //===----------------------------------------------------------------------===//
608 // ReductionRecipeOp
609 //===----------------------------------------------------------------------===//
610 
611 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
612  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
613  "init", getType(),
614  /*verifyYield=*/false)))
615  return failure();
616 
617  if (getCombinerRegion().empty())
618  return emitOpError() << "expects non-empty combiner region";
619 
620  Block &reductionBlock = getCombinerRegion().front();
621  if (reductionBlock.getNumArguments() < 2 ||
622  reductionBlock.getArgument(0).getType() != getType() ||
623  reductionBlock.getArgument(1).getType() != getType())
624  return emitOpError() << "expects combiner region with the first two "
625  << "arguments of the reduction type";
626 
627  for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
628  if (yieldOp.getOperands().size() != 1 ||
629  yieldOp.getOperands().getTypes()[0] != getType())
630  return emitOpError() << "expects combiner region to yield a value "
631  "of the reduction type";
632  }
633 
634  return success();
635 }
636 
637 //===----------------------------------------------------------------------===//
638 // Custom parser and printer verifier for private clause
639 //===----------------------------------------------------------------------===//
640 
/// Parses a comma-separated list of `attr -> %operand : type` entries.
/// Operands and types are appended to `operands`/`types`, and the parsed
/// attributes are packed into the `symbols` ArrayAttr.
641 static ParseResult parseSymOperandList(
642  mlir::OpAsmParser &parser,
644  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
 // `attributes` (declared on a line not visible in this listing) collects
 // the leading attribute of each entry.
646  if (failed(parser.parseCommaSeparatedList([&]() {
647  if (parser.parseAttribute(attributes.emplace_back()) ||
648  parser.parseArrow() ||
649  parser.parseOperand(operands.emplace_back()) ||
650  parser.parseColonType(types.emplace_back()))
651  return failure();
652  return success();
653  })))
654  return failure();
655  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
656  attributes.end());
657  symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
658  return success();
659 }
660 
662  mlir::OperandRange operands,
663  mlir::TypeRange types,
664  std::optional<mlir::ArrayAttr> attributes) {
 // Prints `attr -> %operand : type` entries separated by commas. The type is
 // taken from the operand itself; the `types` parameter is not used here.
 // NOTE(review): `attributes` is dereferenced without checking that it is
 // engaged, and llvm::zip stops at the shorter of the two ranges.
665  llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
666  p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
667  << std::get<1>(it).getType();
668  });
669 }
670 
671 //===----------------------------------------------------------------------===//
672 // ParallelOp
673 //===----------------------------------------------------------------------===//
674 
675 /// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
676 template <typename Op>
677 static LogicalResult checkDataOperands(Op op,
678  const mlir::ValueRange &operands) {
679  for (mlir::Value operand : operands)
680  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
681  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
682  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
683  operand.getDefiningOp()))
684  return op.emitError(
685  "expect data entry/exit operation or acc.getdeviceptr "
686  "as defining op");
687  return success();
688 }
689 
/// Verifies a list of operands paired positionally with symbol references:
///  - if there are operands, `attributes` must be present and of equal size;
///    if there are none, `attributes` must be absent;
///  - no operand may appear twice;
///  - each symbol must resolve (lookupNearestSymbolFrom) to an `Op`;
///  - when `checkOperandType` is set and the declaration has a type, the
///    operand type must equal it.
/// `operandName` and `symbolName` are used only in diagnostics.
690 template <typename Op>
691 static LogicalResult
692 checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
693  mlir::OperandRange operands, llvm::StringRef operandName,
694  llvm::StringRef symbolName, bool checkOperandType = true) {
695  if (!operands.empty()) {
696  if (!attributes || attributes->size() != operands.size())
697  return op->emitOpError()
698  << "expected as many " << symbolName << " symbol reference as "
699  << operandName << " operands";
700  } else {
701  if (attributes)
702  return op->emitOpError()
703  << "unexpected " << symbolName << " symbol reference";
704  return success();
705  }
706 
 // `set` (declared on a line not visible in this listing) records the
 // operands already seen, to detect duplicates.
708  for (auto args : llvm::zip(operands, *attributes)) {
709  mlir::Value operand = std::get<0>(args);
710 
711  if (!set.insert(operand).second)
712  return op->emitOpError()
713  << operandName << " operand appears more than once";
714 
715  mlir::Type varType = operand.getType();
716  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
717  auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
718  if (!decl)
719  return op->emitOpError()
720  << "expected symbol reference " << symbolRef << " to point to a "
721  << operandName << " declaration";
722 
723  if (checkOperandType && decl.getType() && decl.getType() != varType)
724  return op->emitOpError() << "expected " << operandName << " (" << varType
725  << ") to be the same type as " << operandName
726  << " declaration (" << decl.getType() << ")";
727  }
728 
729  return success();
730 }
731 
732 unsigned ParallelOp::getNumDataOperands() {
733  return getReductionOperands().size() + getGangPrivateOperands().size() +
734  getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
735 }
736 
737 Value ParallelOp::getDataOperand(unsigned i) {
738  unsigned numOptional = getAsyncOperands().size();
739  numOptional += getNumGangs().size();
740  numOptional += getNumWorkers().size();
741  numOptional += getVectorLength().size();
742  numOptional += getIfCond() ? 1 : 0;
743  numOptional += getSelfCond() ? 1 : 0;
744  return getOperand(getWaitOperands().size() + numOptional + i);
745 }
746 
747 template <typename Op>
748 static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
749  ArrayAttr deviceTypes,
750  llvm::StringRef keyword) {
751  if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
752  return op.emitOpError() << keyword << " operands count must match "
753  << keyword << " device_type count";
754  return success();
755 }
756 
/// Verifies a device_type-segmented operand list. Nothing is checked when
/// `segments` is absent. Otherwise:
///  - every segment size must be at most `maxInSegment` (0 means no limit);
///  - the segment sizes must sum to the number of `operands`;
///  - there must be exactly one segment per entry in `deviceTypes`.
/// NOTE(review): `deviceTypes` is dereferenced without a null check whenever
/// `segments` is present, and a negative segment size would wrap when added
/// to the size_t accumulator. Confirm that callers rule both out.
757 template <typename Op>
759  Op op, OperandRange operands, DenseI32ArrayAttr segments,
760  ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
761  std::size_t numOperandsInSegments = 0;
762 
763  if (!segments)
764  return success();
765 
766  for (auto segCount : segments.asArrayRef()) {
767  if (maxInSegment != 0 && segCount > maxInSegment)
768  return op.emitOpError() << keyword << " expects a maximum of "
769  << maxInSegment << " values per segment";
770  numOperandsInSegments += segCount;
771  }
772  if (numOperandsInSegments != operands.size())
773  return op.emitOpError()
774  << keyword << " operand count does not match count in segments";
775  if (deviceTypes.getValue().size() != (size_t)segments.size())
776  return op.emitOpError()
777  << keyword << " segment count does not match device_type count";
778  return success();
779 }
780 
/// Verifies ParallelOp: the private/reduction recipe symbol lists, the
/// per-device_type operand and segment counts, async/wait attribute versus
/// operand conflicts, and the defining ops of the data clause operands.
781 LogicalResult acc::ParallelOp::verify() {
 // Private and reduction operands need matching recipe symbols. Operand
 // types are not compared against the recipe type (checkOperandType=false).
782  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
783  *this, getPrivatizations(), getGangPrivateOperands(), "private",
784  "privatizations", /*checkOperandType=*/false)))
785  return failure();
786  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
787  *this, getReductionRecipes(), getReductionOperands(), "reduction",
788  "reductions", false)))
789  return failure();
 // num_gangs is segmented per device_type with at most 3 values per segment.
 // The opening line of this call is missing from the listing; the arguments
 // match verifyDeviceTypeAndSegmentCountMatch. The wait operands below are
 // checked the same way, with no per-segment limit.
790 
792  *this, getNumGangs(), getNumGangsSegmentsAttr(),
793  getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
794  return failure();
795 
797  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
798  getWaitOperandsDeviceTypeAttr(), "wait")))
799  return failure();
800 
 // num_workers, vector_length and async carry one operand per device_type.
801  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
802  getNumWorkersDeviceTypeAttr(),
803  "num_workers")))
804  return failure();
805 
806  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
807  getVectorLengthDeviceTypeAttr(),
808  "vector_length")))
809  return failure();
810 
811  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
812  getAsyncOperandsDeviceTypeAttr(),
813  "async")))
814  return failure();
815 
816  if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
817  return failure();
818 
819  return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
820 }
821 
/// Returns the operand at the position of `deviceType` within `arrayAttr`
/// (one operand per device_type entry), or a null Value when `arrayAttr` is
/// absent or has no entry for `deviceType`.
/// NOTE(review): `range[*pos]` is not bounds-checked; this relies on the
/// operand and device_type counts having been verified to match.
822 static mlir::Value
823 getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
825  mlir::acc::DeviceType deviceType) {
826  if (!arrayAttr)
827  return {};
828  if (auto pos = findSegment(*arrayAttr, deviceType))
829  return range[*pos];
830  return {};
831 }
832 
833 bool acc::ParallelOp::hasAsyncOnly() {
834  return hasAsyncOnly(mlir::acc::DeviceType::None);
835 }
836 
837 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
838  return hasDeviceType(getAsyncOnly(), deviceType);
839 }
840 
841 mlir::Value acc::ParallelOp::getAsyncValue() {
842  return getAsyncValue(mlir::acc::DeviceType::None);
843 }
844 
/// Selects, among the async operands, the one associated with `deviceType`.
mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
                                     getAsyncOperands(), deviceType);
}
849 
850 mlir::Value acc::ParallelOp::getNumWorkersValue() {
851  return getNumWorkersValue(mlir::acc::DeviceType::None);
852 }
853 
acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
  // Returns the num_workers operand whose entry in the num_workers
  // device_type list matches `deviceType`, or a null value if there is none.
  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
                                     deviceType);
}
859 
860 mlir::Value acc::ParallelOp::getVectorLengthValue() {
861  return getVectorLengthValue(mlir::acc::DeviceType::None);
862 }
863 
acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
  // Returns the vector_length operand whose entry in the vector_length
  // device_type list matches `deviceType`, or a null value if there is none.
  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
                                     getVectorLength(), deviceType);
}
869 
870 mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
871  return getNumGangsValues(mlir::acc::DeviceType::None);
872 }
873 
ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
  // num_gangs operands are grouped into segments, one per device type; pick
  // the segment that belongs to `deviceType`.
  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
                               getNumGangsSegments(), deviceType);
}
879 
880 bool acc::ParallelOp::hasWaitOnly() {
881  return hasWaitOnly(mlir::acc::DeviceType::None);
882 }
883 
884 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
885  return hasDeviceType(getWaitOnly(), deviceType);
886 }
887 
888 mlir::Operation::operand_range ParallelOp::getWaitValues() {
889  return getWaitValues(mlir::acc::DeviceType::None);
890 }
891 
ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  // Selects the wait operands for `deviceType` from the per-device-type
  // segments; the hasWaitDevnum flags are passed along so that the helper can
  // account for a leading devnum operand in a segment.
      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
      getHasWaitDevnum(), deviceType);
}
898 
899 mlir::Value ParallelOp::getWaitDevnum() {
900  return getWaitDevnum(mlir::acc::DeviceType::None);
901 }
902 
903 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
904  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
905  getWaitOperandsSegments(), getHasWaitDevnum(),
906  deviceType);
907 }
908 
909 void ParallelOp::build(mlir::OpBuilder &odsBuilder,
910  mlir::OperationState &odsState,
911  mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
912  mlir::ValueRange vectorLength,
913  mlir::ValueRange asyncOperands,
914  mlir::ValueRange waitOperands, mlir::Value ifCond,
915  mlir::Value selfCond, mlir::ValueRange reductionOperands,
916  mlir::ValueRange gangPrivateOperands,
917  mlir::ValueRange gangFirstPrivateOperands,
918  mlir::ValueRange dataClauseOperands) {
919 
920  ParallelOp::build(
921  odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
922  /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
923  /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
924  /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
925  /*numGangsDeviceType=*/nullptr, numWorkers,
926  /*numWorkersDeviceType=*/nullptr, vectorLength,
927  /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
928  /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
929  gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
930  /*firstprivatizations=*/nullptr, dataClauseOperands,
931  /*defaultAttr=*/nullptr, /*combined=*/nullptr);
932 }
933 
/// Parses a comma-separated list of groups of the form
///   `{ %operand : type, ... }` optionally followed by `[device_type_attr]`.
/// For each group, it appends the operands and types, records the group size
/// in `segments` and records the group's device type in `deviceTypes`.
static ParseResult parseNumGangs(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::DenseI32ArrayAttr &segments) {

  do {
    if (failed(parser.parseLBrace()))
      return failure();

    // Remember how many operands existed before this group so the group size
    // can be computed after parsing it.
    int32_t crtOperandsSize = operands.size();
    if (failed(parser.parseCommaSeparatedList(
            if (parser.parseOperand(operands.emplace_back()) ||
                parser.parseColonType(types.emplace_back()))
              return failure();
            return success();
          })))
      return failure();
    seg.push_back(operands.size() - crtOperandsSize);

    if (failed(parser.parseRBrace()))
      return failure();

    if (succeeded(parser.parseOptionalLSquare())) {
      if (parser.parseAttribute(attributes.emplace_back()) ||
          parser.parseRSquare())
        return failure();
    } else {
      // No `[...]` suffix: record a default DeviceTypeAttr for this group
      // (presumably DeviceType::None, as in parseDeviceTypeOperands — confirm).
      attributes.push_back(mlir::acc::DeviceTypeAttr::get(
    }
  } while (succeeded(parser.parseOptionalComma()));

  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);

  return success();
}
977 
  // Prints ` [<attr>]` for any device type other than DeviceType::None. None
  // is the default the parsers assume when no `[...]` is given, so it is
  // omitted here.
  // NOTE(review): `deviceTypeAttr` comes from dyn_cast and is dereferenced
  // without a null check; a non-DeviceTypeAttr `attr` would crash here.
  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
  if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
    p << " [" << attr << "]";
}
983 
    mlir::OperandRange operands, mlir::TypeRange types,
    std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::DenseI32ArrayAttr> segments) {
  // Emits one `{%v : type, ...}` group per device type, separated by commas.
  // `opIdx` walks `operands` sequentially across groups; group i consumes
  // `segments[i]` operands. `types` is unused: each type comes from the
  // operand itself.
  // NOTE(review): `deviceTypes` and `segments` are dereferenced
  // unconditionally, so both are assumed present.
  unsigned opIdx = 0;
  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
    p << "{";
    llvm::interleaveComma(
        llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
          p << operands[opIdx] << " : " << operands[opIdx].getType();
          ++opIdx;
        });
    p << "}";
    printSingleDeviceType(p, it.value());
  });
}
1000 
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::DenseI32ArrayAttr &segments) {

  // Grammar: `{ %v : type, ... } [device_type]?` repeated, comma-separated.
  // Each group's operand count goes into `segments` and its device type into
  // `deviceTypes`.
  do {
    if (failed(parser.parseLBrace()))
      return failure();

    int32_t crtOperandsSize = operands.size();

    if (failed(parser.parseCommaSeparatedList(
            if (parser.parseOperand(operands.emplace_back()) ||
                parser.parseColonType(types.emplace_back()))
              return failure();
            return success();
          })))
      return failure();

    seg.push_back(operands.size() - crtOperandsSize);

    if (failed(parser.parseRBrace()))
      return failure();

    if (succeeded(parser.parseOptionalLSquare())) {
      if (parser.parseAttribute(attributes.emplace_back()) ||
          parser.parseRSquare())
        return failure();
    } else {
      // No explicit device type: record a default DeviceTypeAttr (presumably
      // DeviceType::None — confirm).
      attributes.push_back(mlir::acc::DeviceTypeAttr::get(
    }
  } while (succeeded(parser.parseOptionalComma()));

  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);

  return success();
}
1046 
    mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::DenseI32ArrayAttr> segments) {
  // Emits the inverse of the segmented parser: one `{%v : type, ...}` group
  // per device type, consuming `segments[i]` operands for group i.
  // NOTE(review): assumes `deviceTypes` and `segments` are both present.
  unsigned opIdx = 0;
  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
    p << "{";
    llvm::interleaveComma(
        llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
          p << operands[opIdx] << " : " << operands[opIdx].getType();
          ++opIdx;
        });
    p << "}";
    printSingleDeviceType(p, it.value());
  });
}
1063 
/// Parses the operand list of a `wait` clause:
///   - bare keyword (no `(`): records a single keyword-only device type and
///     returns;
///   - `( [dt, ...] , {devnum: %v : t, %w : t} [dt], ... )`: an optional
///     keyword-only device type list, then one or more brace groups. Each
///     group may start with `devnum:` and may be followed by `[device_type]`.
/// Outputs are the per-group device types, segment sizes and devnum flags,
/// plus the keyword-only device types.
static ParseResult parseWaitClause(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
    mlir::ArrayAttr &keywordOnly) {
  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;

  bool needCommaBeforeOperands = false;

  // Keyword only
  if (failed(parser.parseOptionalLParen())) {
    keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
    keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
    return success();
  }

  // Parse keyword only attributes
  if (succeeded(parser.parseOptionalLSquare())) {
    if (failed(parser.parseCommaSeparatedList([&]() {
          if (parser.parseAttribute(keywordAttrs.emplace_back()))
            return failure();
          return success();
        })))
      return failure();
    if (parser.parseRSquare())
      return failure();
    needCommaBeforeOperands = true;
  }

  // NOTE(review): after a keyword-only `[...]` list, a comma and at least one
  // `{...}` group are mandatory. printWaitClause emits ", " only when both
  // lists are non-empty, so a keyword-only-only form such as `([#dt])` may
  // not round-trip — confirm.
  if (needCommaBeforeOperands && failed(parser.parseComma()))
    return failure();

  do {
    if (failed(parser.parseLBrace()))
      return failure();

    int32_t crtOperandsSize = operands.size();

    // An optional leading `devnum:` marks this group as carrying a devnum
    // operand; record the flag for every group either way.
    if (succeeded(parser.parseOptionalKeyword("devnum"))) {
      if (failed(parser.parseColon()))
        return failure();
      devnum.push_back(BoolAttr::get(parser.getContext(), true));
    } else {
      devnum.push_back(BoolAttr::get(parser.getContext(), false));
    }

    if (failed(parser.parseCommaSeparatedList(
            if (parser.parseOperand(operands.emplace_back()) ||
                parser.parseColonType(types.emplace_back()))
              return failure();
            return success();
          })))
      return failure();

    seg.push_back(operands.size() - crtOperandsSize);

    if (failed(parser.parseRBrace()))
      return failure();

    if (succeeded(parser.parseOptionalLSquare())) {
      if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
          parser.parseRSquare())
        return failure();
    } else {
      deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
    }
  } while (succeeded(parser.parseOptionalComma()));

  if (failed(parser.parseRParen()))
    return failure();

  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
  hasDevNum = ArrayAttr::get(parser.getContext(), devnum);

  return success();
}
1147 
1148 static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
1149  if (!hasDeviceTypeValues(attrs))
1150  return false;
1151  if (attrs->size() != 1)
1152  return false;
1153  if (auto deviceTypeAttr =
1154  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1155  return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
1156  return false;
1157 }
1158 
    mlir::OperandRange operands, mlir::TypeRange types,
    std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::DenseI32ArrayAttr> segments,
    std::optional<mlir::ArrayAttr> hasDevNum,
    std::optional<mlir::ArrayAttr> keywordOnly) {

  // A bare `wait` (no operands, keyword-only list is just `none`) prints
  // nothing after the keyword.
  if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
    return;

  p << "(";

  printDeviceTypes(p, keywordOnly);
  if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
    p << ", ";

  // One `{[devnum: ]%v : t, ...}` group per device type; `opIdx` walks
  // `operands` across groups using `segments`.
  // NOTE(review): `deviceTypes`, `segments` and `hasDevNum` are dereferenced
  // unconditionally, unlike printDeviceTypeOperands which first checks
  // hasDeviceTypeValues — confirm they are always present here.
  unsigned opIdx = 0;
  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
    p << "{";
    auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
    if (boolAttr && boolAttr.getValue())
      p << "devnum: ";
    llvm::interleaveComma(
        llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
          p << operands[opIdx] << " : " << operands[opIdx].getType();
          ++opIdx;
        });
    p << "}";
    printSingleDeviceType(p, it.value());
  });

  p << ")";
}
1192 
/// Parses a comma-separated list of `%operand : type [device_type]?`. Each
/// operand without an explicit `[...]` gets DeviceType::None. One device type
/// attribute is produced per operand.
static ParseResult parseDeviceTypeOperands(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseOperand(operands.emplace_back()) ||
            parser.parseColonType(types.emplace_back()))
          return failure();
        if (succeeded(parser.parseOptionalLSquare())) {
          if (parser.parseAttribute(attributes.emplace_back()) ||
              parser.parseRSquare())
            return failure();
        } else {
          attributes.push_back(mlir::acc::DeviceTypeAttr::get(
              parser.getContext(), mlir::acc::DeviceType::None));
        }
        return success();
      })))
    return failure();
  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
  return success();
}
1218 
/// Prints `%operand : type [device_type]` entries, pairing each operand with
/// the device type at the same index. Prints nothing when there are no
/// device type values. The `None` device type suffix is omitted by
/// printSingleDeviceType.
static void
    mlir::OperandRange operands, mlir::TypeRange types,
    std::optional<mlir::ArrayAttr> deviceTypes) {
  if (!hasDeviceTypeValues(deviceTypes))
    return;
  // llvm::zip stops at the shorter of the two ranges.
  llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
    p << std::get<1>(it) << " : " << std::get<1>(it).getType();
    printSingleDeviceType(p, std::get<0>(it));
  });
}
1230 
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::ArrayAttr &keywordOnlyDeviceType) {

  // Grammar:
  //   <nothing>                                   -> keyword only
  //   `( [dt, ...] , %v : t [dt], ... )`           -> keyword-only list + operands
  //   `( %v : t [dt], ... )`                       -> operands only
  llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
  bool needCommaBeforeOperands = false;

  if (failed(parser.parseOptionalLParen())) {
    // Keyword only
    keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
    keywordOnlyDeviceType =
        ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
    return success();
  }

  // Parse keyword only attributes
  if (succeeded(parser.parseOptionalLSquare())) {
    // Parse keyword only attributes
    if (failed(parser.parseCommaSeparatedList([&]() {
          if (parser.parseAttribute(
                  keywordOnlyDeviceTypeAttributes.emplace_back()))
            return failure();
          return success();
        })))
      return failure();
    if (parser.parseRSquare())
      return failure();
    needCommaBeforeOperands = true;
  }

  if (needCommaBeforeOperands && failed(parser.parseComma()))
    return failure();

  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseOperand(operands.emplace_back()) ||
            parser.parseColonType(types.emplace_back()))
          return failure();
        if (succeeded(parser.parseOptionalLSquare())) {
          if (parser.parseAttribute(attributes.emplace_back()) ||
              parser.parseRSquare())
            return failure();
        } else {
          attributes.push_back(mlir::acc::DeviceTypeAttr::get(
              parser.getContext(), mlir::acc::DeviceType::None));
        }
        return success();
      })))
    return failure();

  if (failed(parser.parseRParen()))
    return failure();

  // NOTE(review): on this parenthesized path the parsed
  // `keywordOnlyDeviceTypeAttributes` are never stored into
  // `keywordOnlyDeviceType`, so a `[dt, ...]` keyword-only list appears to be
  // dropped. parseWaitClause and parseGangClause do assign their keyword-only
  // outputs at the end — likely a missing assignment here; confirm.
  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
  return success();
}
1292 
    mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {

  // Bare keyword form: no operands and the keyword-only list is just `none`.
  if (operands.begin() == operands.end() &&
      hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
    return;
  }

  // Otherwise print `(` keyword-only types, a separating comma if operands
  // follow, the `%v : t [dt]` operands, then `)`.
  p << "(";
  printDeviceTypes(p, keywordOnlyDeviceTypes);
  if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
      hasDeviceTypeValues(deviceTypes))
    p << ", ";
  printDeviceTypeOperands(p, op, operands, types, deviceTypes);
  p << ")";
}
1311 
/// Parses an optional `combined(kernels|parallel|serial)` marker into a
/// CombinedConstructsTypeAttr (KernelsLoop / ParallelLoop / SerialLoop).
/// Absence of the `combined` keyword is not an error; an unknown construct
/// name inside the parentheses is.
static ParseResult
    mlir::acc::CombinedConstructsTypeAttr &attr) {
  if (succeeded(parser.parseOptionalKeyword("combined"))) {
    if (parser.parseLParen())
      return failure();
    if (succeeded(parser.parseOptionalKeyword("kernels"))) {
          parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
    } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
          parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
    } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
          parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
    } else {
      parser.emitError(parser.getCurrentLocation(),
                       "expected compute construct name");
      return failure();
    }
    if (parser.parseRParen())
      return failure();
  }
  return success();
}
1337 
/// Prints `combined(kernels|parallel|serial)` for a present
/// CombinedConstructsTypeAttr; prints nothing when the attribute is null.
static void
    mlir::acc::CombinedConstructsTypeAttr attr) {
  if (attr) {
    switch (attr.getValue()) {
    case mlir::acc::CombinedConstructsType::KernelsLoop:
      p << "combined(kernels)";
      break;
    case mlir::acc::CombinedConstructsType::ParallelLoop:
      p << "combined(parallel)";
      break;
    case mlir::acc::CombinedConstructsType::SerialLoop:
      p << "combined(serial)";
      break;
    };
  }
}
1355 
1356 //===----------------------------------------------------------------------===//
1357 // SerialOp
1358 //===----------------------------------------------------------------------===//
1359 
1360 unsigned SerialOp::getNumDataOperands() {
1361  return getReductionOperands().size() + getGangPrivateOperands().size() +
1362  getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
1363 }
1364 
1365 Value SerialOp::getDataOperand(unsigned i) {
1366  unsigned numOptional = getAsyncOperands().size();
1367  numOptional += getIfCond() ? 1 : 0;
1368  numOptional += getSelfCond() ? 1 : 0;
1369  return getOperand(getWaitOperands().size() + numOptional + i);
1370 }
1371 
1372 bool acc::SerialOp::hasAsyncOnly() {
1373  return hasAsyncOnly(mlir::acc::DeviceType::None);
1374 }
1375 
1376 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1377  return hasDeviceType(getAsyncOnly(), deviceType);
1378 }
1379 
1380 mlir::Value acc::SerialOp::getAsyncValue() {
1381  return getAsyncValue(mlir::acc::DeviceType::None);
1382 }
1383 
/// Selects, among the async operands, the one associated with `deviceType`.
mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
                                     getAsyncOperands(), deviceType);
}
1388 
1389 bool acc::SerialOp::hasWaitOnly() {
1390  return hasWaitOnly(mlir::acc::DeviceType::None);
1391 }
1392 
1393 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1394  return hasDeviceType(getWaitOnly(), deviceType);
1395 }
1396 
1397 mlir::Operation::operand_range SerialOp::getWaitValues() {
1398  return getWaitValues(mlir::acc::DeviceType::None);
1399 }
1400 
SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  // Selects the wait operands for `deviceType` from the per-device-type
  // segments; the hasWaitDevnum flags are passed along so that the helper can
  // account for a leading devnum operand in a segment.
      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
      getHasWaitDevnum(), deviceType);
}
1407 
1408 mlir::Value SerialOp::getWaitDevnum() {
1409  return getWaitDevnum(mlir::acc::DeviceType::None);
1410 }
1411 
1412 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1413  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1414  getWaitOperandsSegments(), getHasWaitDevnum(),
1415  deviceType);
1416 }
1417 
/// Verifies an acc.serial operation. In order, it checks:
///   - the private and reduction recipe symbol lists;
///   - the segmented wait operands against their segment and device_type
///     attributes;
///   - that the async operand count matches its device_type list;
///   - wait/async conflicts;
///   - the data clause operands.
LogicalResult acc::SerialOp::verify() {
  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
          *this, getPrivatizations(), getGangPrivateOperands(), "private",
          "privatizations", /*checkOperandType=*/false)))
    return failure();
  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
          *this, getReductionRecipes(), getReductionOperands(), "reduction",
          "reductions", /*checkOperandType=*/false)))
    return failure();

      *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
      getWaitOperandsDeviceTypeAttr(), "wait")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
                                        getAsyncOperandsDeviceTypeAttr(),
                                        "async")))
    return failure();

  if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
    return failure();

  return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
}
1443 
1444 //===----------------------------------------------------------------------===//
1445 // KernelsOp
1446 //===----------------------------------------------------------------------===//
1447 
1448 unsigned KernelsOp::getNumDataOperands() {
1449  return getDataClauseOperands().size();
1450 }
1451 
1452 Value KernelsOp::getDataOperand(unsigned i) {
1453  unsigned numOptional = getAsyncOperands().size();
1454  numOptional += getWaitOperands().size();
1455  numOptional += getNumGangs().size();
1456  numOptional += getNumWorkers().size();
1457  numOptional += getVectorLength().size();
1458  numOptional += getIfCond() ? 1 : 0;
1459  numOptional += getSelfCond() ? 1 : 0;
1460  return getOperand(numOptional + i);
1461 }
1462 
1463 bool acc::KernelsOp::hasAsyncOnly() {
1464  return hasAsyncOnly(mlir::acc::DeviceType::None);
1465 }
1466 
1467 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1468  return hasDeviceType(getAsyncOnly(), deviceType);
1469 }
1470 
1471 mlir::Value acc::KernelsOp::getAsyncValue() {
1472  return getAsyncValue(mlir::acc::DeviceType::None);
1473 }
1474 
/// Selects, among the async operands, the one associated with `deviceType`.
mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
                                     getAsyncOperands(), deviceType);
}
1479 
1480 mlir::Value acc::KernelsOp::getNumWorkersValue() {
1481  return getNumWorkersValue(mlir::acc::DeviceType::None);
1482 }
1483 
acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
  // Returns the num_workers operand whose entry in the num_workers
  // device_type list matches `deviceType`, or a null value if there is none.
  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
                                     deviceType);
}
1489 
1490 mlir::Value acc::KernelsOp::getVectorLengthValue() {
1491  return getVectorLengthValue(mlir::acc::DeviceType::None);
1492 }
1493 
acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
  // Returns the vector_length operand whose entry in the vector_length
  // device_type list matches `deviceType`, or a null value if there is none.
  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
                                     getVectorLength(), deviceType);
}
1499 
1500 mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
1501  return getNumGangsValues(mlir::acc::DeviceType::None);
1502 }
1503 
KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
  // num_gangs operands are grouped into segments, one per device type; pick
  // the segment that belongs to `deviceType`.
  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
                               getNumGangsSegments(), deviceType);
}
1509 
1510 bool acc::KernelsOp::hasWaitOnly() {
1511  return hasWaitOnly(mlir::acc::DeviceType::None);
1512 }
1513 
1514 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1515  return hasDeviceType(getWaitOnly(), deviceType);
1516 }
1517 
1518 mlir::Operation::operand_range KernelsOp::getWaitValues() {
1519  return getWaitValues(mlir::acc::DeviceType::None);
1520 }
1521 
KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  // Selects the wait operands for `deviceType` from the per-device-type
  // segments; the hasWaitDevnum flags are passed along so that the helper can
  // account for a leading devnum operand in a segment.
      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
      getHasWaitDevnum(), deviceType);
}
1528 
1529 mlir::Value KernelsOp::getWaitDevnum() {
1530  return getWaitDevnum(mlir::acc::DeviceType::None);
1531 }
1532 
1533 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1534  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1535  getWaitOperandsSegments(), getHasWaitDevnum(),
1536  deviceType);
1537 }
1538 
/// Verifies an acc.kernels operation. In order, it checks:
///   - the segmented num_gangs and wait operands against their segment and
///     device_type attributes;
///   - that num_workers, vector_length and async operand counts match their
///     device_type lists;
///   - wait/async conflicts;
///   - the data clause operands.
LogicalResult acc::KernelsOp::verify() {
  // NOTE(review): the trailing `3` presumably caps the number of num_gangs
  // values per device_type (OpenACC allows up to three) — confirm.
      *this, getNumGangs(), getNumGangsSegmentsAttr(),
      getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
    return failure();

      *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
      getWaitOperandsDeviceTypeAttr(), "wait")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
                                        getNumWorkersDeviceTypeAttr(),
                                        "num_workers")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
                                        getVectorLengthDeviceTypeAttr(),
                                        "vector_length")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
                                        getAsyncOperandsDeviceTypeAttr(),
                                        "async")))
    return failure();

  if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
    return failure();

  return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
}
1570 
1571 //===----------------------------------------------------------------------===//
1572 // HostDataOp
1573 //===----------------------------------------------------------------------===//
1574 
1575 LogicalResult acc::HostDataOp::verify() {
1576  if (getDataClauseOperands().empty())
1577  return emitError("at least one operand must appear on the host_data "
1578  "operation");
1579 
1580  for (mlir::Value operand : getDataClauseOperands())
1581  if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
1582  return emitError("expect data entry operation as defining op");
1583  return success();
1584 }
1585 
1586 void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
1587  MLIRContext *context) {
1588  results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
1589 }
1590 
1591 //===----------------------------------------------------------------------===//
1592 // LoopOp
1593 //===----------------------------------------------------------------------===//
1594 
/// Parses an optional `keyword = %operand : type` gang entry. If `keyword` is
/// present, this function:
///   - appends the operand and type;
///   - appends `gangArgType` to `attributes`;
///   - sets both `needCommaBetweenValues` and `newValue` to true.
/// If `keyword` is absent, nothing is consumed and success is returned.
static ParseResult parseGangValue(
    OpAsmParser &parser, llvm::StringRef keyword,
    llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
    bool &needCommaBetweenValues, bool &newValue) {
  if (succeeded(parser.parseOptionalKeyword(keyword))) {
    if (parser.parseEqual())
      return failure();
    if (parser.parseOperand(operands.emplace_back()) ||
        parser.parseColonType(types.emplace_back()))
      return failure();
    attributes.push_back(gangArgType);
    needCommaBetweenValues = true;
    newValue = true;
  }
  return success();
}
1613 
/// Parses the `gang` clause of acc.loop:
///   - bare `gang` (no `(`): records a single gang-only device type;
///   - `( [dt, ...] , {num=%v : t, dim=..., static=...} [dt], ... )`: an
///     optional gang-only device type list, then brace groups of
///     `num`/`dim`/`static` values, each group optionally followed by
///     `[device_type]`.
/// Outputs:
///   - `gangArgType`: the arg-type kind of each operand;
///   - `deviceType` and `segments`: one device type and one size per group;
///   - `gangOnlyDeviceType`: the gang-only device types.
static ParseResult parseGangClause(
    OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
    mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
    mlir::ArrayAttr &gangOnlyDeviceType) {
  llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
  llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
  llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
  bool needCommaBetweenValues = false;
  bool needCommaBeforeOperands = false;

  if (failed(parser.parseOptionalLParen())) {
    // Gang only keyword
    gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
    gangOnlyDeviceType =
        ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
    return success();
  }

  // Parse gang only attributes
  if (succeeded(parser.parseOptionalLSquare())) {
    // Parse gang only attributes
    if (failed(parser.parseCommaSeparatedList([&]() {
          if (parser.parseAttribute(
                  gangOnlyDeviceTypeAttributes.emplace_back()))
            return failure();
          return success();
        })))
      return failure();
    if (parser.parseRSquare())
      return failure();
    needCommaBeforeOperands = true;
  }

  auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
                                                mlir::acc::GangArgType::Num);
  auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
                                                mlir::acc::GangArgType::Dim);
  auto argStatic = mlir::acc::GangArgTypeAttr::get(
      parser.getContext(), mlir::acc::GangArgType::Static);

  do {
    // After a gang-only `[...]` list, `continue` jumps to the do-while
    // condition, which consumes the separating comma (or ends the loop if
    // there is none).
    if (needCommaBeforeOperands) {
      needCommaBeforeOperands = false;
      continue;
    }

    if (failed(parser.parseLBrace()))
      return failure();

    int32_t crtOperandsSize = gangOperands.size();
    // NOTE(review): `needCommaBetweenValues` is set by parseGangValue and is
    // never reset between `{...}` groups. In a second group, the first value
    // would then require a leading comma and the inner loop would break
    // immediately — confirm multi-group input round-trips.
    while (true) {
      bool newValue = false;
      bool needValue = false;
      if (needCommaBetweenValues) {
        if (succeeded(parser.parseOptionalComma()))
          needValue = true; // expect a new value after comma.
        else
          break;
      }

      if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
                                gangOperands, gangOperandsType,
                                gangArgTypeAttributes, argNum,
                                needCommaBetweenValues, newValue)))
        return failure();
      if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
                                gangOperands, gangOperandsType,
                                gangArgTypeAttributes, argDim,
                                needCommaBetweenValues, newValue)))
        return failure();
      if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
                                gangOperands, gangOperandsType,
                                gangArgTypeAttributes, argStatic,
                                needCommaBetweenValues, newValue)))
        return failure();

      if (!newValue && needValue) {
        parser.emitError(parser.getCurrentLocation(),
                         "new value expected after comma");
        return failure();
      }

      if (!newValue)
        break;
    }

    // NOTE(review): this checks the cumulative operand list, not the current
    // group, so an empty `{}` after a non-empty group is not rejected here.
    if (gangOperands.empty())
      return parser.emitError(
          parser.getCurrentLocation(),
          "expect at least one of num, dim or static values");

    if (failed(parser.parseRBrace()))
      return failure();

    if (succeeded(parser.parseOptionalLSquare())) {
      if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
          parser.parseRSquare())
        return failure();
    } else {
      deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
    }

    seg.push_back(gangOperands.size() - crtOperandsSize);

  } while (succeeded(parser.parseOptionalComma()));

  if (failed(parser.parseRParen()))
    return failure();

  llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
                                               gangArgTypeAttributes.end());
  gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);

      gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
  gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);

  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
  return success();
}
1740 
    mlir::OperandRange operands, mlir::TypeRange types,
    std::optional<mlir::ArrayAttr> gangArgTypes,
    std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::DenseI32ArrayAttr> segments,
    std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {

  // Bare `gang`: no operands and the gang-only list is just `none`.
  if (operands.begin() == operands.end() &&
      hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
    return;
  }

  p << "(";

  printDeviceTypes(p, gangOnlyDeviceTypes);

  if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
      hasDeviceTypeValues(deviceTypes))
    p << ", ";

  // One `{kind=%v : t, ...}` group per device type. `opIdx` indexes both
  // `operands` and `gangArgTypes`, which are parallel arrays.
  // NOTE(review): `gangArgTypeAttr` comes from dyn_cast and is used without a
  // null check.
  if (hasDeviceTypeValues(deviceTypes)) {
    unsigned opIdx = 0;
    llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
      p << "{";
      llvm::interleaveComma(
          llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
            auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
                (*gangArgTypes)[opIdx]);
            if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
              p << LoopOp::getGangNumKeyword();
            else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
              p << LoopOp::getGangDimKeyword();
            else if (gangArgTypeAttr.getValue() ==
                     mlir::acc::GangArgType::Static)
              p << LoopOp::getGangStaticKeyword();
            p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
            ++opIdx;
          });
      p << "}";
      printSingleDeviceType(p, it.value());
    });
  }
  p << ")";
}
1785 
    std::optional<mlir::ArrayAttr> segments,
    llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
  // Returns true as soon as a device type in `segments` is already in
  // `deviceTypes`; otherwise inserts each one. `deviceTypes` is an in/out
  // accumulator, so repeated calls detect duplicates across several
  // attributes. An absent `segments` yields false.
  // NOTE(review): the dyn_cast result is dereferenced without a null check.
  if (!segments)
    return false;
  for (auto attr : *segments) {
    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
    if (deviceTypes.contains(deviceTypeAttr.getValue()))
      return true;
    deviceTypes.insert(deviceTypeAttr.getValue());
  }
  return false;
}
1799 
1800 /// Check for duplicates in the DeviceType array attribute.
1801 LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
1802  llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
1803  if (!deviceTypes)
1804  return success();
1805  for (auto attr : deviceTypes) {
1806  auto deviceTypeAttr =
1807  mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
1808  if (!deviceTypeAttr)
1809  return failure();
1810  if (crtDeviceTypes.contains(deviceTypeAttr.getValue()))
1811  return failure();
1812  crtDeviceTypes.insert(deviceTypeAttr.getValue());
1813  }
1814  return success();
1815 }
1816 
/// Verifies acc.loop. Checks, in order:
/// - inclusiveUpperbound/upperbound sizes;
/// - collapse, gang, worker, vector and tile device_type bookkeeping;
/// - mutual exclusion of auto/independent/seq;
/// - seq versus gang/worker/vector;
/// - the private/reduction recipe symbol lists;
/// - the combined-construct kind;
/// - a non-empty region.
/// The first failing check emits the diagnostic.
LogicalResult acc::LoopOp::verify() {
  // When inclusiveUpperbound is present it must have one entry per upper bound.
  if (!getUpperbound().empty() && getInclusiveUpperbound() &&
      (getUpperbound().size() != getInclusiveUpperbound()->size()))
    return emitError() << "inclusiveUpperbound size is expected to be the same"
                       << " as upperbound size";

  // Check collapse
  // NOTE(review): "must be define" in this diagnostic is a typo for "must be
  // defined"; left unchanged here because tests may match the text.
  if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
    return emitOpError() << "collapse device_type attr must be define when"
                         << " collapse attr is present";

  // Each collapse value is paired positionally with one device_type entry.
  if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
      getCollapseAttr().getValue().size() !=
          getCollapseDeviceTypeAttr().getValue().size())
    return emitOpError() << "collapse attribute count must match collapse"
                         << " device_type count";
  if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
    return emitOpError()
           << "duplicate device_type found in collapseDeviceType attribute";

  // Check gang
  // Every gang operand needs a matching gangOperandsArgType entry.
  if (!getGangOperands().empty()) {
    if (!getGangOperandsArgType())
      return emitOpError() << "gangOperandsArgType attribute must be defined"
                           << " when gang operands are present";

    if (getGangOperands().size() !=
        getGangOperandsArgTypeAttr().getValue().size())
      return emitOpError() << "gangOperandsArgType attribute count must match"
                           << " gangOperands count";
  }
  if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
    return emitOpError() << "duplicate device_type found in gang attribute";

      *this, getGangOperands(), getGangOperandsSegmentsAttr(),
      getGangOperandsDeviceTypeAttr(), "gang")))
    return failure();

  // Check worker
  if (failed(checkDeviceTypes(getWorkerAttr())))
    return emitOpError() << "duplicate device_type found in worker attribute";
  if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
    return emitOpError() << "duplicate device_type found in "
                            "workerNumOperandsDeviceType attribute";
  if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
                                        getWorkerNumOperandsDeviceTypeAttr(),
                                        "worker")))
    return failure();

  // Check vector
  if (failed(checkDeviceTypes(getVectorAttr())))
    return emitOpError() << "duplicate device_type found in vector attribute";
  if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
    return emitOpError() << "duplicate device_type found in "
                            "vectorOperandsDeviceType attribute";
  if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
                                        getVectorOperandsDeviceTypeAttr(),
                                        "vector")))
    return failure();

      *this, getTileOperands(), getTileOperandsSegmentsAttr(),
      getTileOperandsDeviceTypeAttr(), "tile")))
    return failure();

  // auto, independent and seq attribute are mutually exclusive.
  // A single set is threaded through all three calls, so a device type that
  // appears in more than one of these attributes is reported.
  // NOTE(review): only the first name in this diagnostic is quoted; the
  // quoting is inconsistent but left unchanged.
  llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
  if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
      hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
      hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
    return emitError() << "only one of \"" << acc::LoopOp::getAutoAttrStrName()
                       << "\", " << getIndependentAttrName() << ", "
                       << getSeqAttrName()
                       << " can be present at the same time";
  }

  // Gang, worker and vector are incompatible with seq.
  // For each device type that carries seq, reject any gang/worker/vector
  // keyword or operand registered for that same device type.
  // NOTE(review): dyn_cast result is not null-checked; assumes every element
  // of the seq attribute is a DeviceTypeAttr.
  if (getSeqAttr()) {
    for (auto attr : getSeqAttr()) {
      auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
      if (hasVector(deviceTypeAttr.getValue()) ||
          getVectorValue(deviceTypeAttr.getValue()) ||
          hasWorker(deviceTypeAttr.getValue()) ||
          getWorkerValue(deviceTypeAttr.getValue()) ||
          hasGang(deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Num,
                       deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Dim,
                       deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Static,
                       deviceTypeAttr.getValue()))
        return emitError()
               << "gang, worker or vector cannot appear with the seq attr";
    }
  }

  // checkSymOperandList (defined elsewhere) presumably checks that each
  // private/reduction operand references a recipe symbol of the given op kind.
  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
          *this, getPrivatizations(), getPrivateOperands(), "private",
          "privatizations", false)))
    return failure();

  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
          *this, getReductionRecipes(), getReductionOperands(), "reduction",
          "reductions", false)))
    return failure();

  // Only the loop-flavored combined constructs are accepted on acc.loop.
  if (getCombined().has_value() &&
      (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
       getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
       getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
    return emitError("unexpected combined constructs attribute");
  }

  // Check non-empty body().
  if (getRegion().empty())
    return emitError("expected non-empty body.");

  return success();
}
1937 
1938 unsigned LoopOp::getNumDataOperands() {
1939  return getReductionOperands().size() + getPrivateOperands().size();
1940 }
1941 
1942 Value LoopOp::getDataOperand(unsigned i) {
1943  unsigned numOptional =
1944  getLowerbound().size() + getUpperbound().size() + getStep().size();
1945  numOptional += getGangOperands().size();
1946  numOptional += getVectorOperands().size();
1947  numOptional += getWorkerNumOperands().size();
1948  numOptional += getTileOperands().size();
1949  numOptional += getCacheOperands().size();
1950  return getOperand(numOptional + i);
1951 }
1952 
1953 bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
1954 
1955 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
1956  return hasDeviceType(getAuto_(), deviceType);
1957 }
1958 
1959 bool LoopOp::hasIndependent() {
1960  return hasIndependent(mlir::acc::DeviceType::None);
1961 }
1962 
1963 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
1964  return hasDeviceType(getIndependent(), deviceType);
1965 }
1966 
1967 bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
1968 
1969 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
1970  return hasDeviceType(getSeq(), deviceType);
1971 }
1972 
1973 mlir::Value LoopOp::getVectorValue() {
1974  return getVectorValue(mlir::acc::DeviceType::None);
1975 }
1976 
1977 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
1978  return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
1979  getVectorOperands(), deviceType);
1980 }
1981 
1982 bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
1983 
1984 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
1985  return hasDeviceType(getVector(), deviceType);
1986 }
1987 
1988 mlir::Value LoopOp::getWorkerValue() {
1989  return getWorkerValue(mlir::acc::DeviceType::None);
1990 }
1991 
1992 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
1993  return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
1994  getWorkerNumOperands(), deviceType);
1995 }
1996 
1997 bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
1998 
1999 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2000  return hasDeviceType(getWorker(), deviceType);
2001 }
2002 
/// Tile operands for the device_type-agnostic entry (DeviceType::None).
mlir::Operation::operand_range LoopOp::getTileValues() {
  return getTileValues(mlir::acc::DeviceType::None);
}

LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
  // Presumably selects, via getValuesFromSegments, the slice of tileOperands
  // whose segment (sized by tileOperandsSegments) is tagged with deviceType.
  return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
                               getTileOperandsSegments(), deviceType);
}
2012 
2013 std::optional<int64_t> LoopOp::getCollapseValue() {
2014  return getCollapseValue(mlir::acc::DeviceType::None);
2015 }
2016 
2017 std::optional<int64_t>
2018 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2019  if (!getCollapseAttr())
2020  return std::nullopt;
2021  if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2022  auto intAttr =
2023  mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2024  return intAttr.getValue().getZExtValue();
2025  }
2026  return std::nullopt;
2027 }
2028 
/// Gang operand of kind `gangArgType` for the device_type-agnostic entry.
mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
  return getGangValue(gangArgType, mlir::acc::DeviceType::None);
}

/// Returns the first gang operand of kind `gangArgType` (num/dim/static) in
/// the segment tagged with `deviceType`, or a null Value if there is none.
mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
                                 mlir::acc::DeviceType deviceType) {
  if (getGangOperands().empty())
    return {};
  if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
    // Segments are stored back to back; summing the sizes of the earlier
    // segments gives the index of this device type's first operand.
    int32_t nbOperandsBefore = 0;
    for (unsigned i = 0; i < *pos; ++i)
      nbOperandsBefore += (*getGangOperandsSegments())[i];
        getGangOperands()
            .drop_front(nbOperandsBefore)
            .take_front((*getGangOperandsSegments())[*pos]);

    // gangOperandsArgType is indexed by absolute operand position, so start
    // at the segment's first operand and advance in step with `values`.
    int32_t argTypeIdx = nbOperandsBefore;
    for (auto value : values) {
      auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
          (*getGangOperandsArgType())[argTypeIdx]);
      if (gangArgTypeAttr.getValue() == gangArgType)
        return value;
      ++argTypeIdx;
    }
  }
  return {};
}
2057 
2058 bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2059 
2060 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2061  return hasDeviceType(getGang(), deviceType);
2062 }
2063 
2064 llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
2065  return {&getRegion()};
2066 }
2067 
/// Parses an optional loop-control header followed by the loop region:
///
///   loop-control ::= (`control` `(` ssa-id-and-type-list `)` `=`
///                     `(` ssa-use-list `:` type-list `)` `to`
///                     `(` ssa-use-list `:` type-list `)` `step`
///                     `(` ssa-use-list `:` type-list `)`)?
///                    region
///
/// Each bound list must contain exactly one operand per induction variable.
/// The induction variables become the region's entry block arguments. Without
/// the control keyword, the region is parsed with no entry arguments.
ParseResult
    SmallVectorImpl<Type> &lowerboundType,
    SmallVectorImpl<Type> &upperboundType,
    SmallVectorImpl<Type> &stepType) {

  // Stays empty when the optional control header is absent.
  SmallVector<OpAsmParser::Argument> inductionVars;
  if (succeeded(
          parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
    if (parser.parseLParen() ||
        parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
                                 /*allowType=*/true) ||
        parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
        parser.parseOperandList(lowerbound, inductionVars.size(),
        parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
        parser.parseKeyword("to") || parser.parseLParen() ||
        parser.parseOperandList(upperbound, inductionVars.size(),
        parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
        parser.parseKeyword("step") || parser.parseLParen() ||
        parser.parseOperandList(step, inductionVars.size(),
        parser.parseColonTypeList(stepType) || parser.parseRParen())
      return failure();
  }
  // The parsed induction variables are bound as the region's block arguments.
  return parser.parseRegion(region, inductionVars);
}
2103 
2105  ValueRange lowerbound, TypeRange lowerboundType,
2106  ValueRange upperbound, TypeRange upperboundType,
2107  ValueRange steps, TypeRange stepType) {
2108  ValueRange regionArgs = region.front().getArguments();
2109  if (!regionArgs.empty()) {
2110  p << acc::LoopOp::getControlKeyword() << "(";
2111  llvm::interleaveComma(regionArgs, p,
2112  [&p](Value v) { p << v << " : " << v.getType(); });
2113  p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
2114  << upperbound << " : " << upperboundType << ") " << " step (" << steps
2115  << " : " << stepType << ") ";
2116  }
2117  p.printRegion(region, /*printEntryBlockArgs=*/false);
2118 }
2119 
2120 //===----------------------------------------------------------------------===//
2121 // DataOp
2122 //===----------------------------------------------------------------------===//
2123 
2124 LogicalResult acc::DataOp::verify() {
2125  // 2.6.5. Data Construct restriction
2126  // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
2127  // attach, or default clause must appear on a data construct.
2128  if (getOperands().empty() && !getDefaultAttr())
2129  return emitError("at least one operand or the default attribute "
2130  "must appear on the data operation");
2131 
2132  for (mlir::Value operand : getDataClauseOperands())
2133  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2134  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2135  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2136  operand.getDefiningOp()))
2137  return emitError("expect data entry/exit operation or acc.getdeviceptr "
2138  "as defining op");
2139 
2140  if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this)))
2141  return failure();
2142 
2143  return success();
2144 }
2145 
2146 unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
2147 
2148 Value DataOp::getDataOperand(unsigned i) {
2149  unsigned numOptional = getIfCond() ? 1 : 0;
2150  numOptional += getAsyncOperands().size() ? 1 : 0;
2151  numOptional += getWaitOperands().size();
2152  return getOperand(numOptional + i);
2153 }
2154 
/// Whether `asyncOnly` lists the device_type-agnostic entry (DeviceType::None).
/// NOTE(review): presumably models an `async` clause without a value.
bool acc::DataOp::hasAsyncOnly() {
  return hasAsyncOnly(mlir::acc::DeviceType::None);
}

/// Whether the asyncOnly attribute lists `deviceType` (via hasDeviceType).
bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
  return hasDeviceType(getAsyncOnly(), deviceType);
}

/// Async operand for the device_type-agnostic entry (DeviceType::None).
mlir::Value DataOp::getAsyncValue() {
  return getAsyncValue(mlir::acc::DeviceType::None);
}

mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
                                     getAsyncOperands(), deviceType);
}

/// Whether `waitOnly` lists the device_type-agnostic entry (DeviceType::None).
/// NOTE(review): presumably models a `wait` clause without values.
bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }

/// Whether the waitOnly attribute lists `deviceType` (via hasDeviceType).
bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
  return hasDeviceType(getWaitOnly(), deviceType);
}

/// Wait operands for the device_type-agnostic entry (DeviceType::None).
mlir::Operation::operand_range DataOp::getWaitValues() {
  return getWaitValues(mlir::acc::DeviceType::None);
}

DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
      getHasWaitDevnum(), deviceType);
}

/// Wait devnum operand for the device_type-agnostic entry (DeviceType::None).
mlir::Value DataOp::getWaitDevnum() {
  return getWaitDevnum(mlir::acc::DeviceType::None);
}

/// Looks up the wait devnum value for `deviceType` in the wait operand
/// segments. getHasWaitDevnum() is forwarded; presumably it flags which
/// segments start with a devnum -- confirm in getWaitDevnumValue.
mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
                            getWaitOperandsSegments(), getHasWaitDevnum(),
                            deviceType);
}
2197 }
2198 
2199 //===----------------------------------------------------------------------===//
2200 // ExitDataOp
2201 //===----------------------------------------------------------------------===//
2202 
2203 LogicalResult acc::ExitDataOp::verify() {
2204  // 2.6.6. Data Exit Directive restriction
2205  // At least one copyout, delete, or detach clause must appear on an exit data
2206  // directive.
2207  if (getDataClauseOperands().empty())
2208  return emitError("at least one operand must be present in dataOperands on "
2209  "the exit data operation");
2210 
2211  // The async attribute represent the async clause without value. Therefore the
2212  // attribute and operand cannot appear at the same time.
2213  if (getAsyncOperand() && getAsync())
2214  return emitError("async attribute cannot appear with asyncOperand");
2215 
2216  // The wait attribute represent the wait clause without values. Therefore the
2217  // attribute and operands cannot appear at the same time.
2218  if (!getWaitOperands().empty() && getWait())
2219  return emitError("wait attribute cannot appear with waitOperands");
2220 
2221  if (getWaitDevnum() && getWaitOperands().empty())
2222  return emitError("wait_devnum cannot appear without waitOperands");
2223 
2224  return success();
2225 }
2226 
2227 unsigned ExitDataOp::getNumDataOperands() {
2228  return getDataClauseOperands().size();
2229 }
2230 
2231 Value ExitDataOp::getDataOperand(unsigned i) {
2232  unsigned numOptional = getIfCond() ? 1 : 0;
2233  numOptional += getAsyncOperand() ? 1 : 0;
2234  numOptional += getWaitDevnum() ? 1 : 0;
2235  return getOperand(getWaitOperands().size() + numOptional + i);
2236 }
2237 
2238 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2239  MLIRContext *context) {
2240  results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
2241 }
2242 
2243 //===----------------------------------------------------------------------===//
2244 // EnterDataOp
2245 //===----------------------------------------------------------------------===//
2246 
2247 LogicalResult acc::EnterDataOp::verify() {
2248  // 2.6.6. Data Enter Directive restriction
2249  // At least one copyin, create, or attach clause must appear on an enter data
2250  // directive.
2251  if (getDataClauseOperands().empty())
2252  return emitError("at least one operand must be present in dataOperands on "
2253  "the enter data operation");
2254 
2255  // The async attribute represent the async clause without value. Therefore the
2256  // attribute and operand cannot appear at the same time.
2257  if (getAsyncOperand() && getAsync())
2258  return emitError("async attribute cannot appear with asyncOperand");
2259 
2260  // The wait attribute represent the wait clause without values. Therefore the
2261  // attribute and operands cannot appear at the same time.
2262  if (!getWaitOperands().empty() && getWait())
2263  return emitError("wait attribute cannot appear with waitOperands");
2264 
2265  if (getWaitDevnum() && getWaitOperands().empty())
2266  return emitError("wait_devnum cannot appear without waitOperands");
2267 
2268  for (mlir::Value operand : getDataClauseOperands())
2269  if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
2270  operand.getDefiningOp()))
2271  return emitError("expect data entry operation as defining op");
2272 
2273  return success();
2274 }
2275 
2276 unsigned EnterDataOp::getNumDataOperands() {
2277  return getDataClauseOperands().size();
2278 }
2279 
2280 Value EnterDataOp::getDataOperand(unsigned i) {
2281  unsigned numOptional = getIfCond() ? 1 : 0;
2282  numOptional += getAsyncOperand() ? 1 : 0;
2283  numOptional += getWaitDevnum() ? 1 : 0;
2284  return getOperand(getWaitOperands().size() + numOptional + i);
2285 }
2286 
2287 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2288  MLIRContext *context) {
2289  results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
2290 }
2291 
2292 //===----------------------------------------------------------------------===//
2293 // AtomicReadOp
2294 //===----------------------------------------------------------------------===//
2295 
2296 LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
2297 
2298 //===----------------------------------------------------------------------===//
2299 // AtomicWriteOp
2300 //===----------------------------------------------------------------------===//
2301 
2302 LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
2303 
2304 //===----------------------------------------------------------------------===//
2305 // AtomicUpdateOp
2306 //===----------------------------------------------------------------------===//
2307 
2308 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2309  PatternRewriter &rewriter) {
2310  if (op.isNoOp()) {
2311  rewriter.eraseOp(op);
2312  return success();
2313  }
2314 
2315  if (Value writeVal = op.getWriteOpVal()) {
2316  rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
2317  return success();
2318  }
2319 
2320  return failure();
2321 }
2322 
2323 LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
2324 
2325 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
2326 
2327 //===----------------------------------------------------------------------===//
2328 // AtomicCaptureOp
2329 //===----------------------------------------------------------------------===//
2330 
2331 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2332  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2333  return op;
2334  return dyn_cast<AtomicReadOp>(getSecondOp());
2335 }
2336 
2337 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2338  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2339  return op;
2340  return dyn_cast<AtomicWriteOp>(getSecondOp());
2341 }
2342 
2343 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2344  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2345  return op;
2346  return dyn_cast<AtomicUpdateOp>(getSecondOp());
2347 }
2348 
2349 LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
2350 
2351 //===----------------------------------------------------------------------===//
2352 // DeclareEnterOp
2353 //===----------------------------------------------------------------------===//
2354 
2355 template <typename Op>
2356 static LogicalResult
2358  bool requireAtLeastOneOperand = true) {
2359  if (operands.empty() && requireAtLeastOneOperand)
2360  return emitError(
2361  op->getLoc(),
2362  "at least one operand must appear on the declare operation");
2363 
2364  for (mlir::Value operand : operands) {
2365  if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2366  acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
2367  acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
2368  operand.getDefiningOp()))
2369  return op.emitError(
2370  "expect valid declare data entry operation or acc.getdeviceptr "
2371  "as defining op");
2372 
2373  mlir::Value varPtr{getVarPtr(operand.getDefiningOp())};
2374  assert(varPtr && "declare operands can only be data entry operations which "
2375  "must have varPtr");
2376  std::optional<mlir::acc::DataClause> dataClauseOptional{
2377  getDataClause(operand.getDefiningOp())};
2378  assert(dataClauseOptional.has_value() &&
2379  "declare operands can only be data entry operations which must have "
2380  "dataClause");
2381 
2382  // If varPtr has no defining op - there is nothing to check further.
2383  if (!varPtr.getDefiningOp())
2384  continue;
2385 
2386  // Check that the varPtr has a declare attribute.
2387  auto declareAttribute{
2388  varPtr.getDefiningOp()->getAttr(mlir::acc::getDeclareAttrName())};
2389  if (!declareAttribute)
2390  return op.emitError(
2391  "expect declare attribute on variable in declare operation");
2392 
2393  auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute);
2394  if (declAttr.getDataClause().getValue() != dataClauseOptional.value())
2395  return op.emitError(
2396  "expect matching declare attribute on variable in declare operation");
2397 
2398  // If the variable is marked with implicit attribute, the matching declare
2399  // data action must also be marked implicit. The reverse is not checked
2400  // since implicit data action may be inserted to do actions like updating
2401  // device copy, in which case the variable is not necessarily implicitly
2402  // declare'd.
2403  if (declAttr.getImplicit() &&
2404  declAttr.getImplicit() != acc::getImplicitFlag(operand.getDefiningOp()))
2405  return op.emitError(
2406  "implicitness must match between declare op and flag on variable");
2407  }
2408 
2409  return success();
2410 }
2411 
2412 LogicalResult acc::DeclareEnterOp::verify() {
2413  return checkDeclareOperands(*this, this->getDataClauseOperands());
2414 }
2415 
2416 //===----------------------------------------------------------------------===//
2417 // DeclareExitOp
2418 //===----------------------------------------------------------------------===//
2419 
2420 LogicalResult acc::DeclareExitOp::verify() {
2421  if (getToken())
2422  return checkDeclareOperands(*this, this->getDataClauseOperands(),
2423  /*requireAtLeastOneOperand=*/false);
2424  return checkDeclareOperands(*this, this->getDataClauseOperands());
2425 }
2426 
2427 //===----------------------------------------------------------------------===//
2428 // DeclareOp
2429 //===----------------------------------------------------------------------===//
2430 
2431 LogicalResult acc::DeclareOp::verify() {
2432  return checkDeclareOperands(*this, this->getDataClauseOperands());
2433 }
2434 
2435 //===----------------------------------------------------------------------===//
2436 // RoutineOp
2437 //===----------------------------------------------------------------------===//
2438 
2439 static unsigned getParallelismForDeviceType(acc::RoutineOp op,
2440  acc::DeviceType dtype) {
2441  unsigned parallelism = 0;
2442  parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
2443  parallelism += op.hasWorker(dtype) ? 1 : 0;
2444  parallelism += op.hasVector(dtype) ? 1 : 0;
2445  parallelism += op.hasSeq(dtype) ? 1 : 0;
2446  return parallelism;
2447 }
2448 
2449 LogicalResult acc::RoutineOp::verify() {
2450  unsigned baseParallelism =
2452 
2453  if (baseParallelism > 1)
2454  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2455  "be present at the same time";
2456 
2457  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
2458  ++dtypeInt) {
2459  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
2460  if (dtype == acc::DeviceType::None)
2461  continue;
2462  unsigned parallelism = getParallelismForDeviceType(*this, dtype);
2463 
2464  if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
2465  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2466  "be present at the same time";
2467  }
2468 
2469  return success();
2470 }
2471 
2472 static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
2473  mlir::ArrayAttr &deviceTypes) {
2474  llvm::SmallVector<mlir::Attribute> bindNameAttrs;
2475  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
2476 
2477  if (failed(parser.parseCommaSeparatedList([&]() {
2478  if (parser.parseAttribute(bindNameAttrs.emplace_back()))
2479  return failure();
2480  if (failed(parser.parseOptionalLSquare())) {
2481  deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2482  parser.getContext(), mlir::acc::DeviceType::None));
2483  } else {
2484  if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2485  parser.parseRSquare())
2486  return failure();
2487  }
2488  return success();
2489  })))
2490  return failure();
2491 
2492  bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
2493  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2494 
2495  return success();
2496 }
2497 
    std::optional<mlir::ArrayAttr> bindName,
    std::optional<mlir::ArrayAttr> deviceTypes) {
  // Prints `name [device-type-suffix]` pairs, comma separated. bindName[i] is
  // paired positionally with deviceTypes[i]. Both optionals are dereferenced
  // unconditionally, so this assumes the printer is only invoked when the
  // attributes are present -- NOTE(review): confirm against the assembly
  // format.
  llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
                        [&](const auto &pair) {
                          p << std::get<0>(pair);
                          printSingleDeviceType(p, std::get<1>(pair));
                        });
}
2507 
2508 static ParseResult parseRoutineGangClause(OpAsmParser &parser,
2509  mlir::ArrayAttr &gang,
2510  mlir::ArrayAttr &gangDim,
2511  mlir::ArrayAttr &gangDimDeviceTypes) {
2512 
2513  llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
2514  gangDimDeviceTypeAttrs;
2515  bool needCommaBeforeOperands = false;
2516 
2517  // Gang keyword only
2518  if (failed(parser.parseOptionalLParen())) {
2519  gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2521  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2522  return success();
2523  }
2524 
2525  // Parse keyword only attributes
2526  if (succeeded(parser.parseOptionalLSquare())) {
2527  if (failed(parser.parseCommaSeparatedList([&]() {
2528  if (parser.parseAttribute(gangAttrs.emplace_back()))
2529  return failure();
2530  return success();
2531  })))
2532  return failure();
2533  if (parser.parseRSquare())
2534  return failure();
2535  needCommaBeforeOperands = true;
2536  }
2537 
2538  if (needCommaBeforeOperands && failed(parser.parseComma()))
2539  return failure();
2540 
2541  if (failed(parser.parseCommaSeparatedList([&]() {
2542  if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
2543  parser.parseColon() ||
2544  parser.parseAttribute(gangDimAttrs.emplace_back()))
2545  return failure();
2546  if (succeeded(parser.parseOptionalLSquare())) {
2547  if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
2548  parser.parseRSquare())
2549  return failure();
2550  } else {
2551  gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2552  parser.getContext(), mlir::acc::DeviceType::None));
2553  }
2554  return success();
2555  })))
2556  return failure();
2557 
2558  if (failed(parser.parseRParen()))
2559  return failure();
2560 
2561  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2562  gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
2563  gangDimDeviceTypes =
2564  ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
2565 
2566  return success();
2567 }
2568 
    std::optional<mlir::ArrayAttr> gang,
    std::optional<mlir::ArrayAttr> gangDim,
    std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {

  // Prints the parenthesized part of the routine gang clause:
  //   `(` [keyword-only device types] [`, `] dim: value[device-type], ... `)`
  // Nothing is printed when there are no dim values and the only keyword-only
  // entry is DeviceType::None.
  // NOTE(review): with keyword-only device types other than None and no dim
  // values this prints `([...])`; confirm parseRoutineGangClause accepts that
  // form.
  if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
      gang->size() == 1) {
    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
      return;
  }

  p << "(";

  printDeviceTypes(p, gang);

  if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
    p << ", ";

  // gangDim[i] is paired positionally with gangDimDeviceTypes[i].
  if (hasDeviceTypeValues(gangDimDeviceTypes))
    llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
                          [&](const auto &pair) {
                            p << acc::RoutineOp::getGangDimKeyword() << ": ";
                            p << std::get<0>(pair);
                            printSingleDeviceType(p, std::get<1>(pair));
                          });

  p << ")";
}
2598 
/// Parses an optional `([device-type, ...])` suffix into `deviceTypes`.
///
/// Without a `(` this is the keyword-only form. It yields a single entry,
/// presumably DeviceType::None as in the other keyword-only parsers --
/// confirm.
/// NOTE(review): when `(` is present but `[` is not, this returns success with
/// an empty array and without consuming `)`. Confirm whether that input is
/// reachable.
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
                                            mlir::ArrayAttr &deviceTypes) {
  // Keyword only
  if (failed(parser.parseOptionalLParen())) {
    attributes.push_back(mlir::acc::DeviceTypeAttr::get(
    deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
    return success();
  }

  // Parse device type attributes
  if (succeeded(parser.parseOptionalLSquare())) {
    if (failed(parser.parseCommaSeparatedList([&]() {
          if (parser.parseAttribute(attributes.emplace_back()))
            return failure();
          return success();
        })))
      return failure();
    // The list is closed by `]` and the enclosing `)`.
    if (parser.parseRSquare() || parser.parseRParen())
      return failure();
  }
  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
  return success();
}
2624 
static void
    std::optional<mlir::ArrayAttr> deviceTypes) {

  // Prints `([dt, dt, ...])`. Nothing is printed when the attribute is absent
  // or empty, or when its only entry is DeviceType::None (the keyword-only
  // form).
  if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
    auto deviceTypeAttr =
        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
      return;
  }

  if (!hasDeviceTypeValues(deviceTypes))
    return;

  p << "([";
  llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
    auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
    p << dTypeAttr;
  });
  p << "])";
}
2646 
2647 bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2648 
2649 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
2650  return hasDeviceType(getWorker(), deviceType);
2651 }
2652 
2653 bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
2654 
2655 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
2656  return hasDeviceType(getVector(), deviceType);
2657 }
2658 
2659 bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
2660 
2661 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
2662  return hasDeviceType(getSeq(), deviceType);
2663 }
2664 
2665 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
2666  return getBindNameValue(mlir::acc::DeviceType::None);
2667 }
2668 
2669 std::optional<llvm::StringRef>
2670 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
2671  if (!hasDeviceTypeValues(getBindNameDeviceType()))
2672  return std::nullopt;
2673  if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
2674  auto attr = (*getBindName())[*pos];
2675  auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
2676  return stringAttr.getValue();
2677  }
2678  return std::nullopt;
2679 }
2680 
2681 bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2682 
2683 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
2684  return hasDeviceType(getGang(), deviceType);
2685 }
2686 
2687 std::optional<int64_t> RoutineOp::getGangDimValue() {
2688  return getGangDimValue(mlir::acc::DeviceType::None);
2689 }
2690 
2691 std::optional<int64_t>
2692 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
2693  if (!hasDeviceTypeValues(getGangDimDeviceType()))
2694  return std::nullopt;
2695  if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
2696  auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
2697  return intAttr.getInt();
2698  }
2699  return std::nullopt;
2700 }
2701 
2702 //===----------------------------------------------------------------------===//
2703 // InitOp
2704 //===----------------------------------------------------------------------===//
2705 
2706 LogicalResult acc::InitOp::verify() {
2707  Operation *currOp = *this;
2708  while ((currOp = currOp->getParentOp()))
2709  if (isComputeOperation(currOp))
2710  return emitOpError("cannot be nested in a compute operation");
2711  return success();
2712 }
2713 
2714 //===----------------------------------------------------------------------===//
2715 // ShutdownOp
2716 //===----------------------------------------------------------------------===//
2717 
2718 LogicalResult acc::ShutdownOp::verify() {
2719  Operation *currOp = *this;
2720  while ((currOp = currOp->getParentOp()))
2721  if (isComputeOperation(currOp))
2722  return emitOpError("cannot be nested in a compute operation");
2723  return success();
2724 }
2725 
2726 //===----------------------------------------------------------------------===//
2727 // SetOp
2728 //===----------------------------------------------------------------------===//
2729 
2730 LogicalResult acc::SetOp::verify() {
2731  Operation *currOp = *this;
2732  while ((currOp = currOp->getParentOp()))
2733  if (isComputeOperation(currOp))
2734  return emitOpError("cannot be nested in a compute operation");
2735  if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
2736  return emitOpError("at least one default_async, device_num, or device_type "
2737  "operand must appear");
2738  return success();
2739 }
2740 
2741 //===----------------------------------------------------------------------===//
2742 // UpdateOp
2743 //===----------------------------------------------------------------------===//
2744 
2745 LogicalResult acc::UpdateOp::verify() {
2746  // At least one of host or device should have a value.
2747  if (getDataClauseOperands().empty())
2748  return emitError("at least one value must be present in dataOperands");
2749 
2750  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
2751  getAsyncOperandsDeviceTypeAttr(),
2752  "async")))
2753  return failure();
2754 
2756  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2757  getWaitOperandsDeviceTypeAttr(), "wait")))
2758  return failure();
2759 
2760  if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this)))
2761  return failure();
2762 
2763  for (mlir::Value operand : getDataClauseOperands())
2764  if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
2765  operand.getDefiningOp()))
2766  return emitError("expect data entry/exit operation or acc.getdeviceptr "
2767  "as defining op");
2768 
2769  return success();
2770 }
2771 
2772 unsigned UpdateOp::getNumDataOperands() {
2773  return getDataClauseOperands().size();
2774 }
2775 
2776 Value UpdateOp::getDataOperand(unsigned i) {
2777  unsigned numOptional = getAsyncOperands().size();
2778  numOptional += getIfCond() ? 1 : 0;
2779  return getOperand(getWaitOperands().size() + numOptional + i);
2780 }
2781 
2782 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
2783  MLIRContext *context) {
2784  results.add<RemoveConstantIfCondition<UpdateOp>>(context);
2785 }
2786 
2787 bool UpdateOp::hasAsyncOnly() {
2788  return hasAsyncOnly(mlir::acc::DeviceType::None);
2789 }
2790 
2791 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2792  return hasDeviceType(getAsync(), deviceType);
2793 }
2794 
2795 mlir::Value UpdateOp::getAsyncValue() {
2796  return getAsyncValue(mlir::acc::DeviceType::None);
2797 }
2798 
2799 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2801  return {};
2802 
2803  if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
2804  return getAsyncOperands()[*pos];
2805 
2806  return {};
2807 }
2808 
2809 bool UpdateOp::hasWaitOnly() {
2810  return hasWaitOnly(mlir::acc::DeviceType::None);
2811 }
2812 
2813 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2814  return hasDeviceType(getWaitOnly(), deviceType);
2815 }
2816 
2817 mlir::Operation::operand_range UpdateOp::getWaitValues() {
2818  return getWaitValues(mlir::acc::DeviceType::None);
2819 }
2820 
2822 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2824  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2825  getHasWaitDevnum(), deviceType);
2826 }
2827 
2828 mlir::Value UpdateOp::getWaitDevnum() {
2829  return getWaitDevnum(mlir::acc::DeviceType::None);
2830 }
2831 
2832 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2833  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2834  getWaitOperandsSegments(), getHasWaitDevnum(),
2835  deviceType);
2836 }
2837 
2838 //===----------------------------------------------------------------------===//
2839 // WaitOp
2840 //===----------------------------------------------------------------------===//
2841 
2842 LogicalResult acc::WaitOp::verify() {
2843  // The async attribute represent the async clause without value. Therefore the
2844  // attribute and operand cannot appear at the same time.
2845  if (getAsyncOperand() && getAsync())
2846  return emitError("async attribute cannot appear with asyncOperand");
2847 
2848  if (getWaitDevnum() && getWaitOperands().empty())
2849  return emitError("wait_devnum cannot appear without waitOperands");
2850 
2851  return success();
2852 }
2853 
2854 #define GET_OP_CLASSES
2855 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
2856 
2857 #define GET_ATTRDEF_CLASSES
2858 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
2859 
2860 #define GET_TYPEDEF_CLASSES
2861 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
2862 
2863 //===----------------------------------------------------------------------===//
2864 // acc dialect utilities
2865 //===----------------------------------------------------------------------===//
2866 
2868  auto varPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
2869  .Case<ACC_DATA_ENTRY_OPS>(
2870  [&](auto entry) { return entry.getVarPtr(); })
2871  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
2872  [&](auto exit) { return exit.getVarPtr(); })
2873  .Default([&](mlir::Operation *) { return mlir::Value(); })};
2874  return varPtr;
2875 }
2876 
2878  auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
2880  [&](auto dataClause) { return dataClause.getAccPtr(); })
2881  .Default([&](mlir::Operation *) { return mlir::Value(); })};
2882  return accPtr;
2883 }
2884 
2886  auto varPtrPtr{
2888  .Case<ACC_DATA_ENTRY_OPS>(
2889  [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
2890  .Default([&](mlir::Operation *) { return mlir::Value(); })};
2891  return varPtrPtr;
2892 }
2893 
2898  accDataClauseOp)
2899  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
2901  dataClause.getBounds().begin(), dataClause.getBounds().end());
2902  })
2903  .Default([&](mlir::Operation *) {
2905  })};
2906  return bounds;
2907 }
2908 
2912  accDataClauseOp)
2913  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
2915  dataClause.getAsyncOperands().begin(),
2916  dataClause.getAsyncOperands().end());
2917  })
2918  .Default([&](mlir::Operation *) {
2920  });
2921 }
2922 
2923 mlir::ArrayAttr
2926  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
2927  return dataClause.getAsyncOperandsDeviceTypeAttr();
2928  })
2929  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
2930 }
2931 
2932 mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
2935  [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
2936  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
2937 }
2938 
2939 std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
2940  auto name{
2942  .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
2943  .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
2944  return {};
2945  })};
2946  return name;
2947 }
2948 
2949 std::optional<mlir::acc::DataClause>
2951  auto dataClause{
2953  accDataEntryOp)
2954  .Case<ACC_DATA_ENTRY_OPS>(
2955  [&](auto entry) { return entry.getDataClause(); })
2956  .Default([&](mlir::Operation *) { return std::nullopt; })};
2957  return dataClause;
2958 }
2959 
2961  auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
2962  .Case<ACC_DATA_ENTRY_OPS>(
2963  [&](auto entry) { return entry.getImplicit(); })
2964  .Default([&](mlir::Operation *) { return false; })};
2965  return implicit;
2966 }
2967 
2969  auto dataOperands{
2972  [&](auto entry) { return entry.getDataClauseOperands(); })
2973  .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
2974  return dataOperands;
2975 }
2976 
2979  auto dataOperands{
2982  [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
2983  .Default([&](mlir::Operation *) { return nullptr; })};
2984  return dataOperands;
2985 }
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:112
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2166
@ None
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition: OpenACC.cpp:2569
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition: OpenACC.cpp:446
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition: OpenACC.cpp:1786
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition: OpenACC.cpp:748
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition: OpenACC.cpp:1801
static bool isComputeOperation(Operation *op)
Definition: OpenACC.cpp:460
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition: OpenACC.cpp:1148
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:2472
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition: OpenACC.cpp:1159
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition: OpenACC.cpp:1064
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition: OpenACC.cpp:77
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:2626
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition: OpenACC.cpp:1595
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition: OpenACC.cpp:1313
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition: OpenACC.cpp:2357
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:97
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition: OpenACC.cpp:2073
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition: OpenACC.cpp:677
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:1193
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:823
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:121
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:934
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition: OpenACC.cpp:2104
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:2599
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition: OpenACC.cpp:2508
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1047
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:1220
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindName, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:2498
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:1001
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
Definition: OpenACC.cpp:661
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:153
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition: OpenACC.cpp:1614
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition: OpenACC.cpp:536
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition: OpenACC.cpp:978
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:108
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
Definition: OpenACC.cpp:692
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition: OpenACC.cpp:1293
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:83
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition: OpenACC.cpp:1741
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:137
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition: OpenACC.cpp:1231
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition: OpenACC.cpp:173
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition: OpenACC.cpp:758
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition: OpenACC.cpp:2439
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:984
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition: OpenACC.cpp:1339
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
Definition: OpenACC.cpp:641
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition: OpenACC.h:67
#define ACC_DATA_ENTRY_OPS
Definition: OpenACC.h:43
#define ACC_DATA_EXIT_OPS
Definition: OpenACC.h:51
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:211
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
mlir::Value getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtr from a data clause operation.
Definition: OpenACC.cpp:2867
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition: OpenACC.cpp:2950
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition: OpenACC.cpp:2978
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition: OpenACC.cpp:2895
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition: OpenACC.cpp:2968
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition: OpenACC.cpp:2939
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition: OpenACC.cpp:2960
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition: OpenACC.cpp:2910
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition: OpenACC.cpp:2885
mlir::Value getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accPtr from a data clause operation.
Definition: OpenACC.cpp:2877
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:2932
static constexpr StringLiteral getDeclareAttrName()
Used to obtain the attribute name for declare.
Definition: OpenACC.h:140
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:2924
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.