MLIR  19.0.0git
OpenACC.cpp
Go to the documentation of this file.
1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
19 #include "llvm/ADT/SmallSet.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 
22 using namespace mlir;
23 using namespace acc;
24 
25 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
26 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
27 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
28 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
29 
30 namespace {
31 struct MemRefPointerLikeModel
32  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
33  MemRefType> {
34  Type getElementType(Type pointer) const {
35  return llvm::cast<MemRefType>(pointer).getElementType();
36  }
37 };
38 
39 struct LLVMPointerPointerLikeModel
40  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
41  LLVM::LLVMPointerType> {
42  Type getElementType(Type pointer) const { return Type(); }
43 };
44 } // namespace
45 
46 //===----------------------------------------------------------------------===//
47 // OpenACC operations
48 //===----------------------------------------------------------------------===//
49 
50 void OpenACCDialect::initialize() {
51  addOperations<
52 #define GET_OP_LIST
53 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
54  >();
55  addAttributes<
56 #define GET_ATTRDEF_LIST
57 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
58  >();
59  addTypes<
60 #define GET_TYPEDEF_LIST
61 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
62  >();
63 
64  // By attaching interfaces here, we make the OpenACC dialect dependent on
65  // the other dialects. This is probably better than having dialects like LLVM
66  // and memref be dependent on OpenACC.
67  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
68  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
69  *getContext());
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // device_type support helpers
74 //===----------------------------------------------------------------------===//
75 
76 static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
77  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
78  return true;
79  return false;
80 }
81 
82 static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
83  mlir::acc::DeviceType deviceType) {
84  if (!hasDeviceTypeValues(arrayAttr))
85  return false;
86 
87  for (auto attr : *arrayAttr) {
88  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
89  if (deviceTypeAttr.getValue() == deviceType)
90  return true;
91  }
92 
93  return false;
94 }
95 
97  std::optional<mlir::ArrayAttr> deviceTypes) {
98  if (!hasDeviceTypeValues(deviceTypes))
99  return;
100 
101  p << "[";
102  llvm::interleaveComma(*deviceTypes, p,
103  [&](mlir::Attribute attr) { p << attr; });
104  p << "]";
105 }
106 
107 static std::optional<unsigned> findSegment(ArrayAttr segments,
108  mlir::acc::DeviceType deviceType) {
109  unsigned segmentIdx = 0;
110  for (auto attr : segments) {
111  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
112  if (deviceTypeAttr.getValue() == deviceType)
113  return std::make_optional(segmentIdx);
114  ++segmentIdx;
115  }
116  return std::nullopt;
117 }
118 
120 getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
122  std::optional<llvm::ArrayRef<int32_t>> segments,
123  mlir::acc::DeviceType deviceType) {
124  if (!arrayAttr)
125  return range.take_front(0);
126  if (auto pos = findSegment(*arrayAttr, deviceType)) {
127  int32_t nbOperandsBefore = 0;
128  for (unsigned i = 0; i < *pos; ++i)
129  nbOperandsBefore += (*segments)[i];
130  return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
131  }
132  return range.take_front(0);
133 }
134 
135 static mlir::Value
136 getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
138  std::optional<llvm::ArrayRef<int32_t>> segments,
139  std::optional<mlir::ArrayAttr> hasWaitDevnum,
140  mlir::acc::DeviceType deviceType) {
141  if (!hasDeviceTypeValues(deviceTypeAttr))
142  return {};
143  if (auto pos = findSegment(*deviceTypeAttr, deviceType))
144  if (hasWaitDevnum->getValue()[*pos])
145  return getValuesFromSegments(deviceTypeAttr, operands, segments,
146  deviceType)
147  .front();
148  return {};
149 }
150 
152 getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
154  std::optional<llvm::ArrayRef<int32_t>> segments,
155  std::optional<mlir::ArrayAttr> hasWaitDevnum,
156  mlir::acc::DeviceType deviceType) {
157  auto range =
158  getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
159  if (range.empty())
160  return range;
161  if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
162  if (hasWaitDevnum && *hasWaitDevnum) {
163  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
164  if (boolAttr.getValue())
165  return range.drop_front(1); // first value is devnum
166  }
167  }
168  return range;
169 }
170 
171 template <typename Op>
173  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
174  ++dtypeInt) {
175  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
176 
177  // The async attribute represent the async clause without value. Therefore
178  // the attribute and operand cannot appear at the same time.
179  if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
180  op.hasAsyncOnly(dtype))
181  return op.emitError("async attribute cannot appear with asyncOperand");
182 
183  // The wait attribute represent the wait clause without values. Therefore
184  // the attribute and operands cannot appear at the same time.
185  if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
186  op.hasWaitOnly(dtype))
187  return op.emitError("wait attribute cannot appear with waitOperands");
188  }
189  return success();
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // DataBoundsOp
194 //===----------------------------------------------------------------------===//
196  auto extent = getExtent();
197  auto upperbound = getUpperbound();
198  if (!extent && !upperbound)
199  return emitError("expected extent or upperbound.");
200  return success();
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // PrivateOp
205 //===----------------------------------------------------------------------===//
207  if (getDataClause() != acc::DataClause::acc_private)
208  return emitError(
209  "data clause associated with private operation must match its intent");
210  return success();
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // FirstprivateOp
215 //===----------------------------------------------------------------------===//
217  if (getDataClause() != acc::DataClause::acc_firstprivate)
218  return emitError("data clause associated with firstprivate operation must "
219  "match its intent");
220  return success();
221 }
222 
223 //===----------------------------------------------------------------------===//
224 // ReductionOp
225 //===----------------------------------------------------------------------===//
227  if (getDataClause() != acc::DataClause::acc_reduction)
228  return emitError("data clause associated with reduction operation must "
229  "match its intent");
230  return success();
231 }
232 
233 //===----------------------------------------------------------------------===//
234 // DevicePtrOp
235 //===----------------------------------------------------------------------===//
237  if (getDataClause() != acc::DataClause::acc_deviceptr)
238  return emitError("data clause associated with deviceptr operation must "
239  "match its intent");
240  return success();
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // PresentOp
245 //===----------------------------------------------------------------------===//
247  if (getDataClause() != acc::DataClause::acc_present)
248  return emitError(
249  "data clause associated with present operation must match its intent");
250  return success();
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // CopyinOp
255 //===----------------------------------------------------------------------===//
257  // Test for all clauses this operation can be decomposed from:
258  if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
259  getDataClause() != acc::DataClause::acc_copyin_readonly &&
260  getDataClause() != acc::DataClause::acc_copy &&
261  getDataClause() != acc::DataClause::acc_reduction)
262  return emitError(
263  "data clause associated with copyin operation must match its intent"
264  " or specify original clause this operation was decomposed from");
265  return success();
266 }
267 
268 bool acc::CopyinOp::isCopyinReadonly() {
269  return getDataClause() == acc::DataClause::acc_copyin_readonly;
270 }
271 
272 //===----------------------------------------------------------------------===//
273 // CreateOp
274 //===----------------------------------------------------------------------===//
276  // Test for all clauses this operation can be decomposed from:
277  if (getDataClause() != acc::DataClause::acc_create &&
278  getDataClause() != acc::DataClause::acc_create_zero &&
279  getDataClause() != acc::DataClause::acc_copyout &&
280  getDataClause() != acc::DataClause::acc_copyout_zero)
281  return emitError(
282  "data clause associated with create operation must match its intent"
283  " or specify original clause this operation was decomposed from");
284  return success();
285 }
286 
287 bool acc::CreateOp::isCreateZero() {
288  // The zero modifier is encoded in the data clause.
289  return getDataClause() == acc::DataClause::acc_create_zero ||
290  getDataClause() == acc::DataClause::acc_copyout_zero;
291 }
292 
293 //===----------------------------------------------------------------------===//
294 // NoCreateOp
295 //===----------------------------------------------------------------------===//
297  if (getDataClause() != acc::DataClause::acc_no_create)
298  return emitError("data clause associated with no_create operation must "
299  "match its intent");
300  return success();
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // AttachOp
305 //===----------------------------------------------------------------------===//
307  if (getDataClause() != acc::DataClause::acc_attach)
308  return emitError(
309  "data clause associated with attach operation must match its intent");
310  return success();
311 }
312 
313 //===----------------------------------------------------------------------===//
314 // DeclareDeviceResidentOp
315 //===----------------------------------------------------------------------===//
316 
318  if (getDataClause() != acc::DataClause::acc_declare_device_resident)
319  return emitError("data clause associated with device_resident operation "
320  "must match its intent");
321  return success();
322 }
323 
324 //===----------------------------------------------------------------------===//
325 // DeclareLinkOp
326 //===----------------------------------------------------------------------===//
327 
329  if (getDataClause() != acc::DataClause::acc_declare_link)
330  return emitError(
331  "data clause associated with link operation must match its intent");
332  return success();
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // CopyoutOp
337 //===----------------------------------------------------------------------===//
339  // Test for all clauses this operation can be decomposed from:
340  if (getDataClause() != acc::DataClause::acc_copyout &&
341  getDataClause() != acc::DataClause::acc_copyout_zero &&
342  getDataClause() != acc::DataClause::acc_copy &&
343  getDataClause() != acc::DataClause::acc_reduction)
344  return emitError(
345  "data clause associated with copyout operation must match its intent"
346  " or specify original clause this operation was decomposed from");
347  if (!getVarPtr() || !getAccPtr())
348  return emitError("must have both host and device pointers");
349  return success();
350 }
351 
352 bool acc::CopyoutOp::isCopyoutZero() {
353  return getDataClause() == acc::DataClause::acc_copyout_zero;
354 }
355 
356 //===----------------------------------------------------------------------===//
357 // DeleteOp
358 //===----------------------------------------------------------------------===//
360  // Test for all clauses this operation can be decomposed from:
361  if (getDataClause() != acc::DataClause::acc_delete &&
362  getDataClause() != acc::DataClause::acc_create &&
363  getDataClause() != acc::DataClause::acc_create_zero &&
364  getDataClause() != acc::DataClause::acc_copyin &&
365  getDataClause() != acc::DataClause::acc_copyin_readonly &&
366  getDataClause() != acc::DataClause::acc_present &&
367  getDataClause() != acc::DataClause::acc_declare_device_resident &&
368  getDataClause() != acc::DataClause::acc_declare_link)
369  return emitError(
370  "data clause associated with delete operation must match its intent"
371  " or specify original clause this operation was decomposed from");
372  if (!getAccPtr())
373  return emitError("must have device pointer");
374  return success();
375 }
376 
377 //===----------------------------------------------------------------------===//
378 // DetachOp
379 //===----------------------------------------------------------------------===//
381  // Test for all clauses this operation can be decomposed from:
382  if (getDataClause() != acc::DataClause::acc_detach &&
383  getDataClause() != acc::DataClause::acc_attach)
384  return emitError(
385  "data clause associated with detach operation must match its intent"
386  " or specify original clause this operation was decomposed from");
387  if (!getAccPtr())
388  return emitError("must have device pointer");
389  return success();
390 }
391 
392 //===----------------------------------------------------------------------===//
393 // HostOp
394 //===----------------------------------------------------------------------===//
396  // Test for all clauses this operation can be decomposed from:
397  if (getDataClause() != acc::DataClause::acc_update_host &&
398  getDataClause() != acc::DataClause::acc_update_self)
399  return emitError(
400  "data clause associated with host operation must match its intent"
401  " or specify original clause this operation was decomposed from");
402  if (!getVarPtr() || !getAccPtr())
403  return emitError("must have both host and device pointers");
404  return success();
405 }
406 
407 //===----------------------------------------------------------------------===//
408 // DeviceOp
409 //===----------------------------------------------------------------------===//
411  // Test for all clauses this operation can be decomposed from:
412  if (getDataClause() != acc::DataClause::acc_update_device)
413  return emitError(
414  "data clause associated with device operation must match its intent"
415  " or specify original clause this operation was decomposed from");
416  return success();
417 }
418 
419 //===----------------------------------------------------------------------===//
420 // UseDeviceOp
421 //===----------------------------------------------------------------------===//
423  // Test for all clauses this operation can be decomposed from:
424  if (getDataClause() != acc::DataClause::acc_use_device)
425  return emitError(
426  "data clause associated with use_device operation must match its intent"
427  " or specify original clause this operation was decomposed from");
428  return success();
429 }
430 
431 //===----------------------------------------------------------------------===//
432 // CacheOp
433 //===----------------------------------------------------------------------===//
435  // Test for all clauses this operation can be decomposed from:
436  if (getDataClause() != acc::DataClause::acc_cache &&
437  getDataClause() != acc::DataClause::acc_cache_readonly)
438  return emitError(
439  "data clause associated with cache operation must match its intent"
440  " or specify original clause this operation was decomposed from");
441  return success();
442 }
443 
444 template <typename StructureOp>
446  unsigned nRegions = 1) {
447 
448  SmallVector<Region *, 2> regions;
449  for (unsigned i = 0; i < nRegions; ++i)
450  regions.push_back(state.addRegion());
451 
452  for (Region *region : regions)
453  if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
454  return failure();
455 
456  return success();
457 }
458 
459 static bool isComputeOperation(Operation *op) {
460  return isa<acc::ParallelOp, acc::LoopOp>(op);
461 }
462 
463 namespace {
464 /// Pattern to remove operation without region that have constant false `ifCond`
465 /// and remove the condition from the operation if the `ifCond` is a true
466 /// constant.
467 template <typename OpTy>
468 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
470 
471  LogicalResult matchAndRewrite(OpTy op,
472  PatternRewriter &rewriter) const override {
473  // Early return if there is no condition.
474  Value ifCond = op.getIfCond();
475  if (!ifCond)
476  return failure();
477 
478  IntegerAttr constAttr;
479  if (!matchPattern(ifCond, m_Constant(&constAttr)))
480  return failure();
481  if (constAttr.getInt())
482  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
483  else
484  rewriter.eraseOp(op);
485 
486  return success();
487  }
488 };
489 
490 /// Replaces the given op with the contents of the given single-block region,
491 /// using the operands of the block terminator to replace operation results.
492 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
493  Region &region, ValueRange blockArgs = {}) {
494  assert(llvm::hasSingleElement(region) && "expected single-region block");
495  Block *block = &region.front();
496  Operation *terminator = block->getTerminator();
497  ValueRange results = terminator->getOperands();
498  rewriter.inlineBlockBefore(block, op, blockArgs);
499  rewriter.replaceOp(op, results);
500  rewriter.eraseOp(terminator);
501 }
502 
503 /// Pattern to remove operation with region that have constant false `ifCond`
504 /// and remove the condition from the operation if the `ifCond` is constant
505 /// true.
506 template <typename OpTy>
507 struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
509 
510  LogicalResult matchAndRewrite(OpTy op,
511  PatternRewriter &rewriter) const override {
512  // Early return if there is no condition.
513  Value ifCond = op.getIfCond();
514  if (!ifCond)
515  return failure();
516 
517  IntegerAttr constAttr;
518  if (!matchPattern(ifCond, m_Constant(&constAttr)))
519  return failure();
520  if (constAttr.getInt())
521  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
522  else
523  replaceOpWithRegion(rewriter, op, op.getRegion());
524 
525  return success();
526  }
527 };
528 
529 } // namespace
530 
531 //===----------------------------------------------------------------------===//
532 // PrivateRecipeOp
533 //===----------------------------------------------------------------------===//
534 
536  Operation *op, Region &region, StringRef regionType, StringRef regionName,
537  Type type, bool verifyYield, bool optional = false) {
538  if (optional && region.empty())
539  return success();
540 
541  if (region.empty())
542  return op->emitOpError() << "expects non-empty " << regionName << " region";
543  Block &firstBlock = region.front();
544  if (firstBlock.getNumArguments() < 1 ||
545  firstBlock.getArgument(0).getType() != type)
546  return op->emitOpError() << "expects " << regionName
547  << " region first "
548  "argument of the "
549  << regionType << " type";
550 
551  if (verifyYield) {
552  for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
553  if (yieldOp.getOperands().size() != 1 ||
554  yieldOp.getOperands().getTypes()[0] != type)
555  return op->emitOpError() << "expects " << regionName
556  << " region to "
557  "yield a value of the "
558  << regionType << " type";
559  }
560  }
561  return success();
562 }
563 
564 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
565  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
566  "privatization", "init", getType(),
567  /*verifyYield=*/false)))
568  return failure();
570  *this, getDestroyRegion(), "privatization", "destroy", getType(),
571  /*verifyYield=*/false, /*optional=*/true)))
572  return failure();
573  return success();
574 }
575 
576 //===----------------------------------------------------------------------===//
577 // FirstprivateRecipeOp
578 //===----------------------------------------------------------------------===//
579 
580 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
581  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
582  "privatization", "init", getType(),
583  /*verifyYield=*/false)))
584  return failure();
585 
586  if (getCopyRegion().empty())
587  return emitOpError() << "expects non-empty copy region";
588 
589  Block &firstBlock = getCopyRegion().front();
590  if (firstBlock.getNumArguments() < 2 ||
591  firstBlock.getArgument(0).getType() != getType())
592  return emitOpError() << "expects copy region with two arguments of the "
593  "privatization type";
594 
595  if (getDestroyRegion().empty())
596  return success();
597 
598  if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
599  "privatization", "destroy",
600  getType(), /*verifyYield=*/false)))
601  return failure();
602 
603  return success();
604 }
605 
606 //===----------------------------------------------------------------------===//
607 // ReductionRecipeOp
608 //===----------------------------------------------------------------------===//
609 
610 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
611  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
612  "init", getType(),
613  /*verifyYield=*/false)))
614  return failure();
615 
616  if (getCombinerRegion().empty())
617  return emitOpError() << "expects non-empty combiner region";
618 
619  Block &reductionBlock = getCombinerRegion().front();
620  if (reductionBlock.getNumArguments() < 2 ||
621  reductionBlock.getArgument(0).getType() != getType() ||
622  reductionBlock.getArgument(1).getType() != getType())
623  return emitOpError() << "expects combiner region with the first two "
624  << "arguments of the reduction type";
625 
626  for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
627  if (yieldOp.getOperands().size() != 1 ||
628  yieldOp.getOperands().getTypes()[0] != getType())
629  return emitOpError() << "expects combiner region to yield a value "
630  "of the reduction type";
631  }
632 
633  return success();
634 }
635 
636 //===----------------------------------------------------------------------===//
637 // Custom parser and printer verifier for private clause
638 //===----------------------------------------------------------------------===//
639 
641  mlir::OpAsmParser &parser,
643  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
645  if (failed(parser.parseCommaSeparatedList([&]() {
646  if (parser.parseAttribute(attributes.emplace_back()) ||
647  parser.parseArrow() ||
648  parser.parseOperand(operands.emplace_back()) ||
649  parser.parseColonType(types.emplace_back()))
650  return failure();
651  return success();
652  })))
653  return failure();
654  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
655  attributes.end());
656  symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
657  return success();
658 }
659 
661  mlir::OperandRange operands,
662  mlir::TypeRange types,
663  std::optional<mlir::ArrayAttr> attributes) {
664  llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
665  p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
666  << std::get<1>(it).getType();
667  });
668 }
669 
670 //===----------------------------------------------------------------------===//
671 // ParallelOp
672 //===----------------------------------------------------------------------===//
673 
674 /// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
675 template <typename Op>
677  const mlir::ValueRange &operands) {
678  for (mlir::Value operand : operands)
679  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
680  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
681  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
682  operand.getDefiningOp()))
683  return op.emitError(
684  "expect data entry/exit operation or acc.getdeviceptr "
685  "as defining op");
686  return success();
687 }
688 
689 template <typename Op>
690 static LogicalResult
691 checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
692  mlir::OperandRange operands, llvm::StringRef operandName,
693  llvm::StringRef symbolName, bool checkOperandType = true) {
694  if (!operands.empty()) {
695  if (!attributes || attributes->size() != operands.size())
696  return op->emitOpError()
697  << "expected as many " << symbolName << " symbol reference as "
698  << operandName << " operands";
699  } else {
700  if (attributes)
701  return op->emitOpError()
702  << "unexpected " << symbolName << " symbol reference";
703  return success();
704  }
705 
707  for (auto args : llvm::zip(operands, *attributes)) {
708  mlir::Value operand = std::get<0>(args);
709 
710  if (!set.insert(operand).second)
711  return op->emitOpError()
712  << operandName << " operand appears more than once";
713 
714  mlir::Type varType = operand.getType();
715  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
716  auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
717  if (!decl)
718  return op->emitOpError()
719  << "expected symbol reference " << symbolRef << " to point to a "
720  << operandName << " declaration";
721 
722  if (checkOperandType && decl.getType() && decl.getType() != varType)
723  return op->emitOpError() << "expected " << operandName << " (" << varType
724  << ") to be the same type as " << operandName
725  << " declaration (" << decl.getType() << ")";
726  }
727 
728  return success();
729 }
730 
731 unsigned ParallelOp::getNumDataOperands() {
732  return getReductionOperands().size() + getGangPrivateOperands().size() +
733  getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
734 }
735 
736 Value ParallelOp::getDataOperand(unsigned i) {
737  unsigned numOptional = getAsyncOperands().size();
738  numOptional += getNumGangs().size();
739  numOptional += getNumWorkers().size();
740  numOptional += getVectorLength().size();
741  numOptional += getIfCond() ? 1 : 0;
742  numOptional += getSelfCond() ? 1 : 0;
743  return getOperand(getWaitOperands().size() + numOptional + i);
744 }
745 
746 template <typename Op>
748  ArrayAttr deviceTypes,
749  llvm::StringRef keyword) {
750  if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
751  return op.emitOpError() << keyword << " operands count must match "
752  << keyword << " device_type count";
753  return success();
754 }
755 
756 template <typename Op>
758  Op op, OperandRange operands, DenseI32ArrayAttr segments,
759  ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
760  std::size_t numOperandsInSegments = 0;
761 
762  if (!segments)
763  return success();
764 
765  for (auto segCount : segments.asArrayRef()) {
766  if (maxInSegment != 0 && segCount > maxInSegment)
767  return op.emitOpError() << keyword << " expects a maximum of "
768  << maxInSegment << " values per segment";
769  numOperandsInSegments += segCount;
770  }
771  if (numOperandsInSegments != operands.size())
772  return op.emitOpError()
773  << keyword << " operand count does not match count in segments";
774  if (deviceTypes.getValue().size() != (size_t)segments.size())
775  return op.emitOpError()
776  << keyword << " segment count does not match device_type count";
777  return success();
778 }
779 
781  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
782  *this, getPrivatizations(), getGangPrivateOperands(), "private",
783  "privatizations", /*checkOperandType=*/false)))
784  return failure();
785  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
786  *this, getReductionRecipes(), getReductionOperands(), "reduction",
787  "reductions", false)))
788  return failure();
789 
791  *this, getNumGangs(), getNumGangsSegmentsAttr(),
792  getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
793  return failure();
794 
796  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
797  getWaitOperandsDeviceTypeAttr(), "wait")))
798  return failure();
799 
800  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
801  getNumWorkersDeviceTypeAttr(),
802  "num_workers")))
803  return failure();
804 
805  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
806  getVectorLengthDeviceTypeAttr(),
807  "vector_length")))
808  return failure();
809 
810  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
811  getAsyncOperandsDeviceTypeAttr(),
812  "async")))
813  return failure();
814 
815  if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
816  return failure();
817 
818  return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
819 }
820 
821 static mlir::Value
822 getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
824  mlir::acc::DeviceType deviceType) {
825  if (!arrayAttr)
826  return {};
827  if (auto pos = findSegment(*arrayAttr, deviceType))
828  return range[*pos];
829  return {};
830 }
831 
832 bool acc::ParallelOp::hasAsyncOnly() {
833  return hasAsyncOnly(mlir::acc::DeviceType::None);
834 }
835 
836 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
837  return hasDeviceType(getAsyncOnly(), deviceType);
838 }
839 
840 mlir::Value acc::ParallelOp::getAsyncValue() {
841  return getAsyncValue(mlir::acc::DeviceType::None);
842 }
843 
844 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
845  return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
846  getAsyncOperands(), deviceType);
847 }
848 
849 mlir::Value acc::ParallelOp::getNumWorkersValue() {
850  return getNumWorkersValue(mlir::acc::DeviceType::None);
851 }
852 
854 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
855  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
856  deviceType);
857 }
858 
859 mlir::Value acc::ParallelOp::getVectorLengthValue() {
860  return getVectorLengthValue(mlir::acc::DeviceType::None);
861 }
862 
864 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
865  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
866  getVectorLength(), deviceType);
867 }
868 
869 mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
870  return getNumGangsValues(mlir::acc::DeviceType::None);
871 }
872 
874 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
875  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
876  getNumGangsSegments(), deviceType);
877 }
878 
879 bool acc::ParallelOp::hasWaitOnly() {
880  return hasWaitOnly(mlir::acc::DeviceType::None);
881 }
882 
883 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
884  return hasDeviceType(getWaitOnly(), deviceType);
885 }
886 
887 mlir::Operation::operand_range ParallelOp::getWaitValues() {
888  return getWaitValues(mlir::acc::DeviceType::None);
889 }
890 
892 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
894  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
895  getHasWaitDevnum(), deviceType);
896 }
897 
898 mlir::Value ParallelOp::getWaitDevnum() {
899  return getWaitDevnum(mlir::acc::DeviceType::None);
900 }
901 
902 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
903  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
904  getWaitOperandsSegments(), getHasWaitDevnum(),
905  deviceType);
906 }
907 
909  mlir::OpAsmParser &parser,
911  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
912  mlir::DenseI32ArrayAttr &segments) {
915 
916  do {
917  if (failed(parser.parseLBrace()))
918  return failure();
919 
920  int32_t crtOperandsSize = operands.size();
921  if (failed(parser.parseCommaSeparatedList(
923  if (parser.parseOperand(operands.emplace_back()) ||
924  parser.parseColonType(types.emplace_back()))
925  return failure();
926  return success();
927  })))
928  return failure();
929  seg.push_back(operands.size() - crtOperandsSize);
930 
931  if (failed(parser.parseRBrace()))
932  return failure();
933 
934  if (succeeded(parser.parseOptionalLSquare())) {
935  if (parser.parseAttribute(attributes.emplace_back()) ||
936  parser.parseRSquare())
937  return failure();
938  } else {
939  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
941  }
942  } while (succeeded(parser.parseOptionalComma()));
943 
944  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
945  attributes.end());
946  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
947  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
948 
949  return success();
950 }
951 
953  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
954  if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
955  p << " [" << attr << "]";
956 }
957 
959  mlir::OperandRange operands, mlir::TypeRange types,
960  std::optional<mlir::ArrayAttr> deviceTypes,
961  std::optional<mlir::DenseI32ArrayAttr> segments) {
962  unsigned opIdx = 0;
963  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
964  p << "{";
965  llvm::interleaveComma(
966  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
967  p << operands[opIdx] << " : " << operands[opIdx].getType();
968  ++opIdx;
969  });
970  p << "}";
971  printSingleDeviceType(p, it.value());
972  });
973 }
974 
976  mlir::OpAsmParser &parser,
978  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
979  mlir::DenseI32ArrayAttr &segments) {
982 
983  do {
984  if (failed(parser.parseLBrace()))
985  return failure();
986 
987  int32_t crtOperandsSize = operands.size();
988 
989  if (failed(parser.parseCommaSeparatedList(
991  if (parser.parseOperand(operands.emplace_back()) ||
992  parser.parseColonType(types.emplace_back()))
993  return failure();
994  return success();
995  })))
996  return failure();
997 
998  seg.push_back(operands.size() - crtOperandsSize);
999 
1000  if (failed(parser.parseRBrace()))
1001  return failure();
1002 
1003  if (succeeded(parser.parseOptionalLSquare())) {
1004  if (parser.parseAttribute(attributes.emplace_back()) ||
1005  parser.parseRSquare())
1006  return failure();
1007  } else {
1008  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1010  }
1011  } while (succeeded(parser.parseOptionalComma()));
1012 
1013  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1014  attributes.end());
1015  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1016  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1017 
1018  return success();
1019 }
1020 
1023  mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1024  std::optional<mlir::DenseI32ArrayAttr> segments) {
1025  unsigned opIdx = 0;
1026  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1027  p << "{";
1028  llvm::interleaveComma(
1029  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1030  p << operands[opIdx] << " : " << operands[opIdx].getType();
1031  ++opIdx;
1032  });
1033  p << "}";
1034  printSingleDeviceType(p, it.value());
1035  });
1036 }
1037 
1039  mlir::OpAsmParser &parser,
1041  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1042  mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
1043  mlir::ArrayAttr &keywordOnly) {
1044  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
1046 
1047  bool needCommaBeforeOperands = false;
1048 
1049  // Keyword only
1050  if (failed(parser.parseOptionalLParen())) {
1051  keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1053  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1054  return success();
1055  }
1056 
1057  // Parse keyword only attributes
1058  if (succeeded(parser.parseOptionalLSquare())) {
1059  if (failed(parser.parseCommaSeparatedList([&]() {
1060  if (parser.parseAttribute(keywordAttrs.emplace_back()))
1061  return failure();
1062  return success();
1063  })))
1064  return failure();
1065  if (parser.parseRSquare())
1066  return failure();
1067  needCommaBeforeOperands = true;
1068  }
1069 
1070  if (needCommaBeforeOperands && failed(parser.parseComma()))
1071  return failure();
1072 
1073  do {
1074  if (failed(parser.parseLBrace()))
1075  return failure();
1076 
1077  int32_t crtOperandsSize = operands.size();
1078 
1079  if (succeeded(parser.parseOptionalKeyword("devnum"))) {
1080  if (failed(parser.parseColon()))
1081  return failure();
1082  devnum.push_back(BoolAttr::get(parser.getContext(), true));
1083  } else {
1084  devnum.push_back(BoolAttr::get(parser.getContext(), false));
1085  }
1086 
1087  if (failed(parser.parseCommaSeparatedList(
1089  if (parser.parseOperand(operands.emplace_back()) ||
1090  parser.parseColonType(types.emplace_back()))
1091  return failure();
1092  return success();
1093  })))
1094  return failure();
1095 
1096  seg.push_back(operands.size() - crtOperandsSize);
1097 
1098  if (failed(parser.parseRBrace()))
1099  return failure();
1100 
1101  if (succeeded(parser.parseOptionalLSquare())) {
1102  if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
1103  parser.parseRSquare())
1104  return failure();
1105  } else {
1106  deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1108  }
1109  } while (succeeded(parser.parseOptionalComma()));
1110 
1111  if (failed(parser.parseRParen()))
1112  return failure();
1113 
1114  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
1115  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1116  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1117  hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
1118 
1119  return success();
1120 }
1121 
1122 static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
1123  if (!hasDeviceTypeValues(attrs))
1124  return false;
1125  if (attrs->size() != 1)
1126  return false;
1127  if (auto deviceTypeAttr =
1128  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1129  return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
1130  return false;
1131 }
1132 
1134  mlir::OperandRange operands, mlir::TypeRange types,
1135  std::optional<mlir::ArrayAttr> deviceTypes,
1136  std::optional<mlir::DenseI32ArrayAttr> segments,
1137  std::optional<mlir::ArrayAttr> hasDevNum,
1138  std::optional<mlir::ArrayAttr> keywordOnly) {
1139 
1140  if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
1141  return;
1142 
1143  p << "(";
1144 
1145  printDeviceTypes(p, keywordOnly);
1146  if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
1147  p << ", ";
1148 
1149  unsigned opIdx = 0;
1150  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1151  p << "{";
1152  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1153  if (boolAttr && boolAttr.getValue())
1154  p << "devnum: ";
1155  llvm::interleaveComma(
1156  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1157  p << operands[opIdx] << " : " << operands[opIdx].getType();
1158  ++opIdx;
1159  });
1160  p << "}";
1161  printSingleDeviceType(p, it.value());
1162  });
1163 
1164  p << ")";
1165 }
1166 
1168  mlir::OpAsmParser &parser,
1170  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
1172  if (failed(parser.parseCommaSeparatedList([&]() {
1173  if (parser.parseOperand(operands.emplace_back()) ||
1174  parser.parseColonType(types.emplace_back()))
1175  return failure();
1176  if (succeeded(parser.parseOptionalLSquare())) {
1177  if (parser.parseAttribute(attributes.emplace_back()) ||
1178  parser.parseRSquare())
1179  return failure();
1180  } else {
1181  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1182  parser.getContext(), mlir::acc::DeviceType::None));
1183  }
1184  return success();
1185  })))
1186  return failure();
1187  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1188  attributes.end());
1189  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1190  return success();
1191 }
1192 
1193 static void
1195  mlir::OperandRange operands, mlir::TypeRange types,
1196  std::optional<mlir::ArrayAttr> deviceTypes) {
1197  if (!hasDeviceTypeValues(deviceTypes))
1198  return;
1199  llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
1200  p << std::get<1>(it) << " : " << std::get<1>(it).getType();
1201  printSingleDeviceType(p, std::get<0>(it));
1202  });
1203 }
1204 
1206  mlir::OpAsmParser &parser,
1208  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1209  mlir::ArrayAttr &keywordOnlyDeviceType) {
1210 
1211  llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
1212  bool needCommaBeforeOperands = false;
1213 
1214  if (failed(parser.parseOptionalLParen())) {
1215  // Keyword only
1216  keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1218  keywordOnlyDeviceType =
1219  ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
1220  return success();
1221  }
1222 
1223  // Parse keyword only attributes
1224  if (succeeded(parser.parseOptionalLSquare())) {
1225  // Parse keyword only attributes
1226  if (failed(parser.parseCommaSeparatedList([&]() {
1227  if (parser.parseAttribute(
1228  keywordOnlyDeviceTypeAttributes.emplace_back()))
1229  return failure();
1230  return success();
1231  })))
1232  return failure();
1233  if (parser.parseRSquare())
1234  return failure();
1235  needCommaBeforeOperands = true;
1236  }
1237 
1238  if (needCommaBeforeOperands && failed(parser.parseComma()))
1239  return failure();
1240 
1242  if (failed(parser.parseCommaSeparatedList([&]() {
1243  if (parser.parseOperand(operands.emplace_back()) ||
1244  parser.parseColonType(types.emplace_back()))
1245  return failure();
1246  if (succeeded(parser.parseOptionalLSquare())) {
1247  if (parser.parseAttribute(attributes.emplace_back()) ||
1248  parser.parseRSquare())
1249  return failure();
1250  } else {
1251  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1252  parser.getContext(), mlir::acc::DeviceType::None));
1253  }
1254  return success();
1255  })))
1256  return failure();
1257 
1258  if (failed(parser.parseRParen()))
1259  return failure();
1260 
1261  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1262  attributes.end());
1263  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1264  return success();
1265 }
1266 
1269  mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1270  std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1271 
1272  if (operands.begin() == operands.end() &&
1273  hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
1274  return;
1275  }
1276 
1277  p << "(";
1278  printDeviceTypes(p, keywordOnlyDeviceTypes);
1279  if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
1280  hasDeviceTypeValues(deviceTypes))
1281  p << ", ";
1282  printDeviceTypeOperands(p, op, operands, types, deviceTypes);
1283  p << ")";
1284 }
1285 
1286 static ParseResult
1288  mlir::acc::CombinedConstructsTypeAttr &attr) {
1289  if (succeeded(parser.parseOptionalKeyword("combined"))) {
1290  if (parser.parseLParen())
1291  return failure();
1292  if (succeeded(parser.parseOptionalKeyword("kernels"))) {
1294  parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1295  } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
1297  parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1298  } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
1300  parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1301  } else {
1302  parser.emitError(parser.getCurrentLocation(),
1303  "expected compute construct name");
1304  return failure();
1305  }
1306  if (parser.parseRParen())
1307  return failure();
1308  }
1309  return success();
1310 }
1311 
1312 static void
1314  mlir::acc::CombinedConstructsTypeAttr attr) {
1315  if (attr) {
1316  switch (attr.getValue()) {
1317  case mlir::acc::CombinedConstructsType::KernelsLoop:
1318  p << "combined(kernels)";
1319  break;
1320  case mlir::acc::CombinedConstructsType::ParallelLoop:
1321  p << "combined(parallel)";
1322  break;
1323  case mlir::acc::CombinedConstructsType::SerialLoop:
1324  p << "combined(serial)";
1325  break;
1326  };
1327  }
1328 }
1329 
1330 //===----------------------------------------------------------------------===//
1331 // SerialOp
1332 //===----------------------------------------------------------------------===//
1333 
1334 unsigned SerialOp::getNumDataOperands() {
1335  return getReductionOperands().size() + getGangPrivateOperands().size() +
1336  getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
1337 }
1338 
1339 Value SerialOp::getDataOperand(unsigned i) {
1340  unsigned numOptional = getAsyncOperands().size();
1341  numOptional += getIfCond() ? 1 : 0;
1342  numOptional += getSelfCond() ? 1 : 0;
1343  return getOperand(getWaitOperands().size() + numOptional + i);
1344 }
1345 
1346 bool acc::SerialOp::hasAsyncOnly() {
1347  return hasAsyncOnly(mlir::acc::DeviceType::None);
1348 }
1349 
1350 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1351  return hasDeviceType(getAsyncOnly(), deviceType);
1352 }
1353 
1354 mlir::Value acc::SerialOp::getAsyncValue() {
1355  return getAsyncValue(mlir::acc::DeviceType::None);
1356 }
1357 
1358 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1359  return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
1360  getAsyncOperands(), deviceType);
1361 }
1362 
1363 bool acc::SerialOp::hasWaitOnly() {
1364  return hasWaitOnly(mlir::acc::DeviceType::None);
1365 }
1366 
1367 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1368  return hasDeviceType(getWaitOnly(), deviceType);
1369 }
1370 
1371 mlir::Operation::operand_range SerialOp::getWaitValues() {
1372  return getWaitValues(mlir::acc::DeviceType::None);
1373 }
1374 
1376 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1378  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1379  getHasWaitDevnum(), deviceType);
1380 }
1381 
1382 mlir::Value SerialOp::getWaitDevnum() {
1383  return getWaitDevnum(mlir::acc::DeviceType::None);
1384 }
1385 
1386 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1387  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1388  getWaitOperandsSegments(), getHasWaitDevnum(),
1389  deviceType);
1390 }
1391 
1393  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1394  *this, getPrivatizations(), getGangPrivateOperands(), "private",
1395  "privatizations", /*checkOperandType=*/false)))
1396  return failure();
1397  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1398  *this, getReductionRecipes(), getReductionOperands(), "reduction",
1399  "reductions", false)))
1400  return failure();
1401 
1403  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1404  getWaitOperandsDeviceTypeAttr(), "wait")))
1405  return failure();
1406 
1407  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1408  getAsyncOperandsDeviceTypeAttr(),
1409  "async")))
1410  return failure();
1411 
1412  if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
1413  return failure();
1414 
1415  return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
1416 }
1417 
1418 //===----------------------------------------------------------------------===//
1419 // KernelsOp
1420 //===----------------------------------------------------------------------===//
1421 
1422 unsigned KernelsOp::getNumDataOperands() {
1423  return getDataClauseOperands().size();
1424 }
1425 
1426 Value KernelsOp::getDataOperand(unsigned i) {
1427  unsigned numOptional = getAsyncOperands().size();
1428  numOptional += getWaitOperands().size();
1429  numOptional += getNumGangs().size();
1430  numOptional += getNumWorkers().size();
1431  numOptional += getVectorLength().size();
1432  numOptional += getIfCond() ? 1 : 0;
1433  numOptional += getSelfCond() ? 1 : 0;
1434  return getOperand(numOptional + i);
1435 }
1436 
1437 bool acc::KernelsOp::hasAsyncOnly() {
1438  return hasAsyncOnly(mlir::acc::DeviceType::None);
1439 }
1440 
1441 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1442  return hasDeviceType(getAsyncOnly(), deviceType);
1443 }
1444 
1445 mlir::Value acc::KernelsOp::getAsyncValue() {
1446  return getAsyncValue(mlir::acc::DeviceType::None);
1447 }
1448 
1449 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1450  return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
1451  getAsyncOperands(), deviceType);
1452 }
1453 
1454 mlir::Value acc::KernelsOp::getNumWorkersValue() {
1455  return getNumWorkersValue(mlir::acc::DeviceType::None);
1456 }
1457 
1459 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1460  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1461  deviceType);
1462 }
1463 
1464 mlir::Value acc::KernelsOp::getVectorLengthValue() {
1465  return getVectorLengthValue(mlir::acc::DeviceType::None);
1466 }
1467 
1469 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1470  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1471  getVectorLength(), deviceType);
1472 }
1473 
1474 mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
1475  return getNumGangsValues(mlir::acc::DeviceType::None);
1476 }
1477 
1479 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1480  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1481  getNumGangsSegments(), deviceType);
1482 }
1483 
1484 bool acc::KernelsOp::hasWaitOnly() {
1485  return hasWaitOnly(mlir::acc::DeviceType::None);
1486 }
1487 
1488 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1489  return hasDeviceType(getWaitOnly(), deviceType);
1490 }
1491 
1492 mlir::Operation::operand_range KernelsOp::getWaitValues() {
1493  return getWaitValues(mlir::acc::DeviceType::None);
1494 }
1495 
1497 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1499  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1500  getHasWaitDevnum(), deviceType);
1501 }
1502 
1503 mlir::Value KernelsOp::getWaitDevnum() {
1504  return getWaitDevnum(mlir::acc::DeviceType::None);
1505 }
1506 
1507 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1508  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1509  getWaitOperandsSegments(), getHasWaitDevnum(),
1510  deviceType);
1511 }
1512 
1515  *this, getNumGangs(), getNumGangsSegmentsAttr(),
1516  getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1517  return failure();
1518 
1520  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1521  getWaitOperandsDeviceTypeAttr(), "wait")))
1522  return failure();
1523 
1524  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1525  getNumWorkersDeviceTypeAttr(),
1526  "num_workers")))
1527  return failure();
1528 
1529  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1530  getVectorLengthDeviceTypeAttr(),
1531  "vector_length")))
1532  return failure();
1533 
1534  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1535  getAsyncOperandsDeviceTypeAttr(),
1536  "async")))
1537  return failure();
1538 
1539  if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
1540  return failure();
1541 
1542  return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
1543 }
1544 
1545 //===----------------------------------------------------------------------===//
1546 // HostDataOp
1547 //===----------------------------------------------------------------------===//
1548 
1550  if (getDataClauseOperands().empty())
1551  return emitError("at least one operand must appear on the host_data "
1552  "operation");
1553 
1554  for (mlir::Value operand : getDataClauseOperands())
1555  if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
1556  return emitError("expect data entry operation as defining op");
1557  return success();
1558 }
1559 
1560 void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
1561  MLIRContext *context) {
1562  results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
1563 }
1564 
1565 //===----------------------------------------------------------------------===//
1566 // LoopOp
1567 //===----------------------------------------------------------------------===//
1568 
1570  OpAsmParser &parser, llvm::StringRef keyword,
1573  llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
1574  bool &needCommaBetweenValues, bool &newValue) {
1575  if (succeeded(parser.parseOptionalKeyword(keyword))) {
1576  if (parser.parseEqual())
1577  return failure();
1578  if (parser.parseOperand(operands.emplace_back()) ||
1579  parser.parseColonType(types.emplace_back()))
1580  return failure();
1581  attributes.push_back(gangArgType);
1582  needCommaBetweenValues = true;
1583  newValue = true;
1584  }
1585  return success();
1586 }
1587 
1589  OpAsmParser &parser,
1591  llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
1592  mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
1593  mlir::ArrayAttr &gangOnlyDeviceType) {
1594  llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
1595  llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
1596  llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
1598  bool needCommaBetweenValues = false;
1599  bool needCommaBeforeOperands = false;
1600 
1601  if (failed(parser.parseOptionalLParen())) {
1602  // Gang only keyword
1603  gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1605  gangOnlyDeviceType =
1606  ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
1607  return success();
1608  }
1609 
1610  // Parse gang only attributes
1611  if (succeeded(parser.parseOptionalLSquare())) {
1612  // Parse gang only attributes
1613  if (failed(parser.parseCommaSeparatedList([&]() {
1614  if (parser.parseAttribute(
1615  gangOnlyDeviceTypeAttributes.emplace_back()))
1616  return failure();
1617  return success();
1618  })))
1619  return failure();
1620  if (parser.parseRSquare())
1621  return failure();
1622  needCommaBeforeOperands = true;
1623  }
1624 
1625  auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
1626  mlir::acc::GangArgType::Num);
1627  auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
1628  mlir::acc::GangArgType::Dim);
1629  auto argStatic = mlir::acc::GangArgTypeAttr::get(
1630  parser.getContext(), mlir::acc::GangArgType::Static);
1631 
1632  do {
1633  if (needCommaBeforeOperands) {
1634  needCommaBeforeOperands = false;
1635  continue;
1636  }
1637 
1638  if (failed(parser.parseLBrace()))
1639  return failure();
1640 
1641  int32_t crtOperandsSize = gangOperands.size();
1642  while (true) {
1643  bool newValue = false;
1644  bool needValue = false;
1645  if (needCommaBetweenValues) {
1646  if (succeeded(parser.parseOptionalComma()))
1647  needValue = true; // expect a new value after comma.
1648  else
1649  break;
1650  }
1651 
1652  if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
1653  gangOperands, gangOperandsType,
1654  gangArgTypeAttributes, argNum,
1655  needCommaBetweenValues, newValue)))
1656  return failure();
1657  if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
1658  gangOperands, gangOperandsType,
1659  gangArgTypeAttributes, argDim,
1660  needCommaBetweenValues, newValue)))
1661  return failure();
1662  if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
1663  gangOperands, gangOperandsType,
1664  gangArgTypeAttributes, argStatic,
1665  needCommaBetweenValues, newValue)))
1666  return failure();
1667 
1668  if (!newValue && needValue) {
1669  parser.emitError(parser.getCurrentLocation(),
1670  "new value expected after comma");
1671  return failure();
1672  }
1673 
1674  if (!newValue)
1675  break;
1676  }
1677 
1678  if (gangOperands.empty())
1679  return parser.emitError(
1680  parser.getCurrentLocation(),
1681  "expect at least one of num, dim or static values");
1682 
1683  if (failed(parser.parseRBrace()))
1684  return failure();
1685 
1686  if (succeeded(parser.parseOptionalLSquare())) {
1687  if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
1688  parser.parseRSquare())
1689  return failure();
1690  } else {
1691  deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1693  }
1694 
1695  seg.push_back(gangOperands.size() - crtOperandsSize);
1696 
1697  } while (succeeded(parser.parseOptionalComma()));
1698 
1699  if (failed(parser.parseRParen()))
1700  return failure();
1701 
1702  llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
1703  gangArgTypeAttributes.end());
1704  gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
1705  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
1706 
1708  gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
1709  gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
1710 
1711  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1712  return success();
1713 }
1714 
1716  mlir::OperandRange operands, mlir::TypeRange types,
1717  std::optional<mlir::ArrayAttr> gangArgTypes,
1718  std::optional<mlir::ArrayAttr> deviceTypes,
1719  std::optional<mlir::DenseI32ArrayAttr> segments,
1720  std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
1721 
1722  if (operands.begin() == operands.end() &&
1723  hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
1724  return;
1725  }
1726 
1727  p << "(";
1728 
1729  printDeviceTypes(p, gangOnlyDeviceTypes);
1730 
1731  if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
1732  hasDeviceTypeValues(deviceTypes))
1733  p << ", ";
1734 
1735  if (hasDeviceTypeValues(deviceTypes)) {
1736  unsigned opIdx = 0;
1737  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1738  p << "{";
1739  llvm::interleaveComma(
1740  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1741  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
1742  (*gangArgTypes)[opIdx]);
1743  if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
1744  p << LoopOp::getGangNumKeyword();
1745  else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
1746  p << LoopOp::getGangDimKeyword();
1747  else if (gangArgTypeAttr.getValue() ==
1748  mlir::acc::GangArgType::Static)
1749  p << LoopOp::getGangStaticKeyword();
1750  p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
1751  ++opIdx;
1752  });
1753  p << "}";
1754  printSingleDeviceType(p, it.value());
1755  });
1756  }
1757  p << ")";
1758 }
1759 
1761  std::optional<mlir::ArrayAttr> segments,
1762  llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
1763  if (!segments)
1764  return false;
1765  for (auto attr : *segments) {
1766  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1767  if (deviceTypes.contains(deviceTypeAttr.getValue()))
1768  return true;
1769  deviceTypes.insert(deviceTypeAttr.getValue());
1770  }
1771  return false;
1772 }
1773 
1774 /// Check for duplicates in the DeviceType array attribute.
1775 LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
1776  llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
1777  if (!deviceTypes)
1778  return success();
1779  for (auto attr : deviceTypes) {
1780  auto deviceTypeAttr =
1781  mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
1782  if (!deviceTypeAttr)
1783  return failure();
1784  if (crtDeviceTypes.contains(deviceTypeAttr.getValue()))
1785  return failure();
1786  crtDeviceTypes.insert(deviceTypeAttr.getValue());
1787  }
1788  return success();
1789 }
1790 
1792  if (!getUpperbound().empty() && getInclusiveUpperbound() &&
1793  (getUpperbound().size() != getInclusiveUpperbound()->size()))
1794  return emitError() << "inclusiveUpperbound size is expected to be the same"
1795  << " as upperbound size";
1796 
1797  // Check collapse
1798  if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
1799  return emitOpError() << "collapse device_type attr must be define when"
1800  << " collapse attr is present";
1801 
1802  if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
1803  getCollapseAttr().getValue().size() !=
1804  getCollapseDeviceTypeAttr().getValue().size())
1805  return emitOpError() << "collapse attribute count must match collapse"
1806  << " device_type count";
1807  if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
1808  return emitOpError()
1809  << "duplicate device_type found in collapseDeviceType attribute";
1810 
1811  // Check gang
1812  if (!getGangOperands().empty()) {
1813  if (!getGangOperandsArgType())
1814  return emitOpError() << "gangOperandsArgType attribute must be defined"
1815  << " when gang operands are present";
1816 
1817  if (getGangOperands().size() !=
1818  getGangOperandsArgTypeAttr().getValue().size())
1819  return emitOpError() << "gangOperandsArgType attribute count must match"
1820  << " gangOperands count";
1821  }
1822  if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
1823  return emitOpError() << "duplicate device_type found in gang attribute";
1824 
1826  *this, getGangOperands(), getGangOperandsSegmentsAttr(),
1827  getGangOperandsDeviceTypeAttr(), "gang")))
1828  return failure();
1829 
1830  // Check worker
1831  if (failed(checkDeviceTypes(getWorkerAttr())))
1832  return emitOpError() << "duplicate device_type found in worker attribute";
1833  if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
1834  return emitOpError() << "duplicate device_type found in "
1835  "workerNumOperandsDeviceType attribute";
1836  if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
1837  getWorkerNumOperandsDeviceTypeAttr(),
1838  "worker")))
1839  return failure();
1840 
1841  // Check vector
1842  if (failed(checkDeviceTypes(getVectorAttr())))
1843  return emitOpError() << "duplicate device_type found in vector attribute";
1844  if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
1845  return emitOpError() << "duplicate device_type found in "
1846  "vectorOperandsDeviceType attribute";
1847  if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
1848  getVectorOperandsDeviceTypeAttr(),
1849  "vector")))
1850  return failure();
1851 
1853  *this, getTileOperands(), getTileOperandsSegmentsAttr(),
1854  getTileOperandsDeviceTypeAttr(), "tile")))
1855  return failure();
1856 
1857  // auto, independent and seq attribute are mutually exclusive.
1858  llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
1859  if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
1860  hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
1861  hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
1862  return emitError() << "only one of \"" << acc::LoopOp::getAutoAttrStrName()
1863  << "\", " << getIndependentAttrName() << ", "
1864  << getSeqAttrName()
1865  << " can be present at the same time";
1866  }
1867 
1868  // Gang, worker and vector are incompatible with seq.
1869  if (getSeqAttr()) {
1870  for (auto attr : getSeqAttr()) {
1871  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1872  if (hasVector(deviceTypeAttr.getValue()) ||
1873  getVectorValue(deviceTypeAttr.getValue()) ||
1874  hasWorker(deviceTypeAttr.getValue()) ||
1875  getWorkerValue(deviceTypeAttr.getValue()) ||
1876  hasGang(deviceTypeAttr.getValue()) ||
1877  getGangValue(mlir::acc::GangArgType::Num,
1878  deviceTypeAttr.getValue()) ||
1879  getGangValue(mlir::acc::GangArgType::Dim,
1880  deviceTypeAttr.getValue()) ||
1881  getGangValue(mlir::acc::GangArgType::Static,
1882  deviceTypeAttr.getValue()))
1883  return emitError()
1884  << "gang, worker or vector cannot appear with the seq attr";
1885  }
1886  }
1887 
1888  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1889  *this, getPrivatizations(), getPrivateOperands(), "private",
1890  "privatizations", false)))
1891  return failure();
1892 
1893  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1894  *this, getReductionRecipes(), getReductionOperands(), "reduction",
1895  "reductions", false)))
1896  return failure();
1897 
1898  if (getCombined().has_value() &&
1899  (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
1900  getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
1901  getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
1902  return emitError("unexpected combined constructs attribute");
1903  }
1904 
1905  // Check non-empty body().
1906  if (getRegion().empty())
1907  return emitError("expected non-empty body.");
1908 
1909  return success();
1910 }
1911 
1912 unsigned LoopOp::getNumDataOperands() {
1913  return getReductionOperands().size() + getPrivateOperands().size();
1914 }
1915 
1916 Value LoopOp::getDataOperand(unsigned i) {
1917  unsigned numOptional =
1918  getLowerbound().size() + getUpperbound().size() + getStep().size();
1919  numOptional += getGangOperands().size();
1920  numOptional += getVectorOperands().size();
1921  numOptional += getWorkerNumOperands().size();
1922  numOptional += getTileOperands().size();
1923  numOptional += getCacheOperands().size();
1924  return getOperand(numOptional + i);
1925 }
1926 
1927 bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
1928 
1929 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
1930  return hasDeviceType(getAuto_(), deviceType);
1931 }
1932 
1933 bool LoopOp::hasIndependent() {
1934  return hasIndependent(mlir::acc::DeviceType::None);
1935 }
1936 
1937 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
1938  return hasDeviceType(getIndependent(), deviceType);
1939 }
1940 
1941 bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
1942 
1943 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
1944  return hasDeviceType(getSeq(), deviceType);
1945 }
1946 
1947 mlir::Value LoopOp::getVectorValue() {
1948  return getVectorValue(mlir::acc::DeviceType::None);
1949 }
1950 
1951 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
1952  return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
1953  getVectorOperands(), deviceType);
1954 }
1955 
1956 bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
1957 
1958 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
1959  return hasDeviceType(getVector(), deviceType);
1960 }
1961 
1962 mlir::Value LoopOp::getWorkerValue() {
1963  return getWorkerValue(mlir::acc::DeviceType::None);
1964 }
1965 
1966 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
1967  return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
1968  getWorkerNumOperands(), deviceType);
1969 }
1970 
1971 bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
1972 
1973 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
1974  return hasDeviceType(getWorker(), deviceType);
1975 }
1976 
1977 mlir::Operation::operand_range LoopOp::getTileValues() {
1978  return getTileValues(mlir::acc::DeviceType::None);
1979 }
1980 
1982 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
1983  return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
1984  getTileOperandsSegments(), deviceType);
1985 }
1986 
1987 std::optional<int64_t> LoopOp::getCollapseValue() {
1988  return getCollapseValue(mlir::acc::DeviceType::None);
1989 }
1990 
1991 std::optional<int64_t>
1992 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
1993  if (!getCollapseAttr())
1994  return std::nullopt;
1995  if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
1996  auto intAttr =
1997  mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
1998  return intAttr.getValue().getZExtValue();
1999  }
2000  return std::nullopt;
2001 }
2002 
2003 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2004  return getGangValue(gangArgType, mlir::acc::DeviceType::None);
2005 }
2006 
2007 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2008  mlir::acc::DeviceType deviceType) {
2009  if (getGangOperands().empty())
2010  return {};
2011  if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
2012  int32_t nbOperandsBefore = 0;
2013  for (unsigned i = 0; i < *pos; ++i)
2014  nbOperandsBefore += (*getGangOperandsSegments())[i];
2016  getGangOperands()
2017  .drop_front(nbOperandsBefore)
2018  .take_front((*getGangOperandsSegments())[*pos]);
2019 
2020  int32_t argTypeIdx = nbOperandsBefore;
2021  for (auto value : values) {
2022  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2023  (*getGangOperandsArgType())[argTypeIdx]);
2024  if (gangArgTypeAttr.getValue() == gangArgType)
2025  return value;
2026  ++argTypeIdx;
2027  }
2028  }
2029  return {};
2030 }
2031 
2032 bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2033 
2034 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2035  return hasDeviceType(getGang(), deviceType);
2036 }
2037 
2038 llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
2039  return {&getRegion()};
2040 }
2041 
2042 /// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
2043 /// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
2044 /// `(` ssa-id-and-type-list `)`
2045 /// region
2049  SmallVectorImpl<Type> &lowerboundType,
2051  SmallVectorImpl<Type> &upperboundType,
2053  SmallVectorImpl<Type> &stepType) {
2054 
2055  SmallVector<OpAsmParser::Argument> inductionVars;
2056  if (succeeded(
2057  parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
2058  if (parser.parseLParen() ||
2059  parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
2060  /*allowType=*/true) ||
2061  parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
2062  parser.parseOperandList(lowerbound, inductionVars.size(),
2064  parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
2065  parser.parseKeyword("to") || parser.parseLParen() ||
2066  parser.parseOperandList(upperbound, inductionVars.size(),
2068  parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
2069  parser.parseKeyword("step") || parser.parseLParen() ||
2070  parser.parseOperandList(step, inductionVars.size(),
2072  parser.parseColonTypeList(stepType) || parser.parseRParen())
2073  return failure();
2074  }
2075  return parser.parseRegion(region, inductionVars);
2076 }
2077 
2079  ValueRange lowerbound, TypeRange lowerboundType,
2080  ValueRange upperbound, TypeRange upperboundType,
2081  ValueRange steps, TypeRange stepType) {
2082  ValueRange regionArgs = region.front().getArguments();
2083  if (!regionArgs.empty()) {
2084  p << acc::LoopOp::getControlKeyword() << "(";
2085  llvm::interleaveComma(regionArgs, p,
2086  [&p](Value v) { p << v << " : " << v.getType(); });
2087  p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
2088  << upperbound << " : " << upperboundType << ") "
2089  << " step (" << steps << " : " << stepType << ") ";
2090  }
2091  p.printRegion(region, /*printEntryBlockArgs=*/false);
2092 }
2093 
2094 //===----------------------------------------------------------------------===//
2095 // DataOp
2096 //===----------------------------------------------------------------------===//
2097 
2099  // 2.6.5. Data Construct restriction
2100  // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
2101  // attach, or default clause must appear on a data construct.
2102  if (getOperands().empty() && !getDefaultAttr())
2103  return emitError("at least one operand or the default attribute "
2104  "must appear on the data operation");
2105 
2106  for (mlir::Value operand : getDataClauseOperands())
2107  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2108  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2109  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2110  operand.getDefiningOp()))
2111  return emitError("expect data entry/exit operation or acc.getdeviceptr "
2112  "as defining op");
2113 
2114  if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this)))
2115  return failure();
2116 
2117  return success();
2118 }
2119 
2120 unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
2121 
2122 Value DataOp::getDataOperand(unsigned i) {
2123  unsigned numOptional = getIfCond() ? 1 : 0;
2124  numOptional += getAsyncOperands().size() ? 1 : 0;
2125  numOptional += getWaitOperands().size();
2126  return getOperand(numOptional + i);
2127 }
2128 
2129 bool acc::DataOp::hasAsyncOnly() {
2130  return hasAsyncOnly(mlir::acc::DeviceType::None);
2131 }
2132 
2133 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2134  return hasDeviceType(getAsyncOnly(), deviceType);
2135 }
2136 
2137 mlir::Value DataOp::getAsyncValue() {
2138  return getAsyncValue(mlir::acc::DeviceType::None);
2139 }
2140 
2141 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2142  return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
2143  getAsyncOperands(), deviceType);
2144 }
2145 
2146 bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
2147 
2148 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2149  return hasDeviceType(getWaitOnly(), deviceType);
2150 }
2151 
2152 mlir::Operation::operand_range DataOp::getWaitValues() {
2153  return getWaitValues(mlir::acc::DeviceType::None);
2154 }
2155 
2157 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2159  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2160  getHasWaitDevnum(), deviceType);
2161 }
2162 
2163 mlir::Value DataOp::getWaitDevnum() {
2164  return getWaitDevnum(mlir::acc::DeviceType::None);
2165 }
2166 
2167 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2168  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2169  getWaitOperandsSegments(), getHasWaitDevnum(),
2170  deviceType);
2171 }
2172 
2173 //===----------------------------------------------------------------------===//
2174 // ExitDataOp
2175 //===----------------------------------------------------------------------===//
2176 
2178  // 2.6.6. Data Exit Directive restriction
2179  // At least one copyout, delete, or detach clause must appear on an exit data
2180  // directive.
2181  if (getDataClauseOperands().empty())
2182  return emitError("at least one operand must be present in dataOperands on "
2183  "the exit data operation");
2184 
2185  // The async attribute represent the async clause without value. Therefore the
2186  // attribute and operand cannot appear at the same time.
2187  if (getAsyncOperand() && getAsync())
2188  return emitError("async attribute cannot appear with asyncOperand");
2189 
2190  // The wait attribute represent the wait clause without values. Therefore the
2191  // attribute and operands cannot appear at the same time.
2192  if (!getWaitOperands().empty() && getWait())
2193  return emitError("wait attribute cannot appear with waitOperands");
2194 
2195  if (getWaitDevnum() && getWaitOperands().empty())
2196  return emitError("wait_devnum cannot appear without waitOperands");
2197 
2198  return success();
2199 }
2200 
2201 unsigned ExitDataOp::getNumDataOperands() {
2202  return getDataClauseOperands().size();
2203 }
2204 
2205 Value ExitDataOp::getDataOperand(unsigned i) {
2206  unsigned numOptional = getIfCond() ? 1 : 0;
2207  numOptional += getAsyncOperand() ? 1 : 0;
2208  numOptional += getWaitDevnum() ? 1 : 0;
2209  return getOperand(getWaitOperands().size() + numOptional + i);
2210 }
2211 
2212 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2213  MLIRContext *context) {
2214  results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
2215 }
2216 
2217 //===----------------------------------------------------------------------===//
2218 // EnterDataOp
2219 //===----------------------------------------------------------------------===//
2220 
2222  // 2.6.6. Data Enter Directive restriction
2223  // At least one copyin, create, or attach clause must appear on an enter data
2224  // directive.
2225  if (getDataClauseOperands().empty())
2226  return emitError("at least one operand must be present in dataOperands on "
2227  "the enter data operation");
2228 
2229  // The async attribute represent the async clause without value. Therefore the
2230  // attribute and operand cannot appear at the same time.
2231  if (getAsyncOperand() && getAsync())
2232  return emitError("async attribute cannot appear with asyncOperand");
2233 
2234  // The wait attribute represent the wait clause without values. Therefore the
2235  // attribute and operands cannot appear at the same time.
2236  if (!getWaitOperands().empty() && getWait())
2237  return emitError("wait attribute cannot appear with waitOperands");
2238 
2239  if (getWaitDevnum() && getWaitOperands().empty())
2240  return emitError("wait_devnum cannot appear without waitOperands");
2241 
2242  for (mlir::Value operand : getDataClauseOperands())
2243  if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
2244  operand.getDefiningOp()))
2245  return emitError("expect data entry operation as defining op");
2246 
2247  return success();
2248 }
2249 
2250 unsigned EnterDataOp::getNumDataOperands() {
2251  return getDataClauseOperands().size();
2252 }
2253 
2254 Value EnterDataOp::getDataOperand(unsigned i) {
2255  unsigned numOptional = getIfCond() ? 1 : 0;
2256  numOptional += getAsyncOperand() ? 1 : 0;
2257  numOptional += getWaitDevnum() ? 1 : 0;
2258  return getOperand(getWaitOperands().size() + numOptional + i);
2259 }
2260 
2261 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2262  MLIRContext *context) {
2263  results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
2264 }
2265 
2266 //===----------------------------------------------------------------------===//
2267 // AtomicReadOp
2268 //===----------------------------------------------------------------------===//
2269 
2270 LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
2271 
2272 //===----------------------------------------------------------------------===//
2273 // AtomicWriteOp
2274 //===----------------------------------------------------------------------===//
2275 
2276 LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
2277 
2278 //===----------------------------------------------------------------------===//
2279 // AtomicUpdateOp
2280 //===----------------------------------------------------------------------===//
2281 
2282 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2283  PatternRewriter &rewriter) {
2284  if (op.isNoOp()) {
2285  rewriter.eraseOp(op);
2286  return success();
2287  }
2288 
2289  if (Value writeVal = op.getWriteOpVal()) {
2290  rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
2291  return success();
2292  }
2293 
2294  return failure();
2295 }
2296 
2297 LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
2298 
2299 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
2300 
2301 //===----------------------------------------------------------------------===//
2302 // AtomicCaptureOp
2303 //===----------------------------------------------------------------------===//
2304 
2305 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2306  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2307  return op;
2308  return dyn_cast<AtomicReadOp>(getSecondOp());
2309 }
2310 
2311 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2312  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2313  return op;
2314  return dyn_cast<AtomicWriteOp>(getSecondOp());
2315 }
2316 
2317 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2318  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2319  return op;
2320  return dyn_cast<AtomicUpdateOp>(getSecondOp());
2321 }
2322 
2323 LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
2324 
2325 //===----------------------------------------------------------------------===//
2326 // DeclareEnterOp
2327 //===----------------------------------------------------------------------===//
2328 
2329 template <typename Op>
2330 static LogicalResult
2332  bool requireAtLeastOneOperand = true) {
2333  if (operands.empty() && requireAtLeastOneOperand)
2334  return emitError(
2335  op->getLoc(),
2336  "at least one operand must appear on the declare operation");
2337 
2338  for (mlir::Value operand : operands) {
2339  if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2340  acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
2341  acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
2342  operand.getDefiningOp()))
2343  return op.emitError(
2344  "expect valid declare data entry operation or acc.getdeviceptr "
2345  "as defining op");
2346 
2347  mlir::Value varPtr{getVarPtr(operand.getDefiningOp())};
2348  assert(varPtr && "declare operands can only be data entry operations which "
2349  "must have varPtr");
2350  std::optional<mlir::acc::DataClause> dataClauseOptional{
2351  getDataClause(operand.getDefiningOp())};
2352  assert(dataClauseOptional.has_value() &&
2353  "declare operands can only be data entry operations which must have "
2354  "dataClause");
2355 
2356  // If varPtr has no defining op - there is nothing to check further.
2357  if (!varPtr.getDefiningOp())
2358  continue;
2359 
2360  // Check that the varPtr has a declare attribute.
2361  auto declareAttribute{
2362  varPtr.getDefiningOp()->getAttr(mlir::acc::getDeclareAttrName())};
2363  if (!declareAttribute)
2364  return op.emitError(
2365  "expect declare attribute on variable in declare operation");
2366 
2367  auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute);
2368  if (declAttr.getDataClause().getValue() != dataClauseOptional.value())
2369  return op.emitError(
2370  "expect matching declare attribute on variable in declare operation");
2371 
2372  // If the variable is marked with implicit attribute, the matching declare
2373  // data action must also be marked implicit. The reverse is not checked
2374  // since implicit data action may be inserted to do actions like updating
2375  // device copy, in which case the variable is not necessarily implicitly
2376  // declare'd.
2377  if (declAttr.getImplicit() &&
2378  declAttr.getImplicit() != acc::getImplicitFlag(operand.getDefiningOp()))
2379  return op.emitError(
2380  "implicitness must match between declare op and flag on variable");
2381  }
2382 
2383  return success();
2384 }
2385 
2387  return checkDeclareOperands(*this, this->getDataClauseOperands());
2388 }
2389 
2390 //===----------------------------------------------------------------------===//
2391 // DeclareExitOp
2392 //===----------------------------------------------------------------------===//
2393 
2395  if (getToken())
2396  return checkDeclareOperands(*this, this->getDataClauseOperands(),
2397  /*requireAtLeastOneOperand=*/false);
2398  return checkDeclareOperands(*this, this->getDataClauseOperands());
2399 }
2400 
2401 //===----------------------------------------------------------------------===//
2402 // DeclareOp
2403 //===----------------------------------------------------------------------===//
2404 
2406  return checkDeclareOperands(*this, this->getDataClauseOperands());
2407 }
2408 
2409 //===----------------------------------------------------------------------===//
2410 // RoutineOp
2411 //===----------------------------------------------------------------------===//
2412 
2413 static unsigned getParallelismForDeviceType(acc::RoutineOp op,
2414  acc::DeviceType dtype) {
2415  unsigned parallelism = 0;
2416  parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
2417  parallelism += op.hasWorker(dtype) ? 1 : 0;
2418  parallelism += op.hasVector(dtype) ? 1 : 0;
2419  parallelism += op.hasSeq(dtype) ? 1 : 0;
2420  return parallelism;
2421 }
2422 
2424  unsigned baseParallelism =
2426 
2427  if (baseParallelism > 1)
2428  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2429  "be present at the same time";
2430 
2431  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
2432  ++dtypeInt) {
2433  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
2434  if (dtype == acc::DeviceType::None)
2435  continue;
2436  unsigned parallelism = getParallelismForDeviceType(*this, dtype);
2437 
2438  if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
2439  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2440  "be present at the same time";
2441  }
2442 
2443  return success();
2444 }
2445 
2446 static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
2447  mlir::ArrayAttr &deviceTypes) {
2448  llvm::SmallVector<mlir::Attribute> bindNameAttrs;
2449  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
2450 
2451  if (failed(parser.parseCommaSeparatedList([&]() {
2452  if (parser.parseAttribute(bindNameAttrs.emplace_back()))
2453  return failure();
2454  if (failed(parser.parseOptionalLSquare())) {
2455  deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2456  parser.getContext(), mlir::acc::DeviceType::None));
2457  } else {
2458  if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2459  parser.parseRSquare())
2460  return failure();
2461  }
2462  return success();
2463  })))
2464  return failure();
2465 
2466  bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
2467  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2468 
2469  return success();
2470 }
2471 
2473  std::optional<mlir::ArrayAttr> bindName,
2474  std::optional<mlir::ArrayAttr> deviceTypes) {
2475  llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
2476  [&](const auto &pair) {
2477  p << std::get<0>(pair);
2478  printSingleDeviceType(p, std::get<1>(pair));
2479  });
2480 }
2481 
2483  mlir::ArrayAttr &gang,
2484  mlir::ArrayAttr &gangDim,
2485  mlir::ArrayAttr &gangDimDeviceTypes) {
2486 
2487  llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
2488  gangDimDeviceTypeAttrs;
2489  bool needCommaBeforeOperands = false;
2490 
2491  // Gang keyword only
2492  if (failed(parser.parseOptionalLParen())) {
2493  gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2495  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2496  return success();
2497  }
2498 
2499  // Parse keyword only attributes
2500  if (succeeded(parser.parseOptionalLSquare())) {
2501  if (failed(parser.parseCommaSeparatedList([&]() {
2502  if (parser.parseAttribute(gangAttrs.emplace_back()))
2503  return failure();
2504  return success();
2505  })))
2506  return failure();
2507  if (parser.parseRSquare())
2508  return failure();
2509  needCommaBeforeOperands = true;
2510  }
2511 
2512  if (needCommaBeforeOperands && failed(parser.parseComma()))
2513  return failure();
2514 
2515  if (failed(parser.parseCommaSeparatedList([&]() {
2516  if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
2517  parser.parseColon() ||
2518  parser.parseAttribute(gangDimAttrs.emplace_back()))
2519  return failure();
2520  if (succeeded(parser.parseOptionalLSquare())) {
2521  if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
2522  parser.parseRSquare())
2523  return failure();
2524  } else {
2525  gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2526  parser.getContext(), mlir::acc::DeviceType::None));
2527  }
2528  return success();
2529  })))
2530  return failure();
2531 
2532  if (failed(parser.parseRParen()))
2533  return failure();
2534 
2535  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2536  gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
2537  gangDimDeviceTypes =
2538  ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
2539 
2540  return success();
2541 }
2542 
2544  std::optional<mlir::ArrayAttr> gang,
2545  std::optional<mlir::ArrayAttr> gangDim,
2546  std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
2547 
2548  if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
2549  gang->size() == 1) {
2550  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
2551  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
2552  return;
2553  }
2554 
2555  p << "(";
2556 
2557  printDeviceTypes(p, gang);
2558 
2559  if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
2560  p << ", ";
2561 
2562  if (hasDeviceTypeValues(gangDimDeviceTypes))
2563  llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
2564  [&](const auto &pair) {
2565  p << acc::RoutineOp::getGangDimKeyword() << ": ";
2566  p << std::get<0>(pair);
2567  printSingleDeviceType(p, std::get<1>(pair));
2568  });
2569 
2570  p << ")";
2571 }
2572 
2574  mlir::ArrayAttr &deviceTypes) {
2576  // Keyword only
2577  if (failed(parser.parseOptionalLParen())) {
2578  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2580  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
2581  return success();
2582  }
2583 
2584  // Parse device type attributes
2585  if (succeeded(parser.parseOptionalLSquare())) {
2586  if (failed(parser.parseCommaSeparatedList([&]() {
2587  if (parser.parseAttribute(attributes.emplace_back()))
2588  return failure();
2589  return success();
2590  })))
2591  return failure();
2592  if (parser.parseRSquare() || parser.parseRParen())
2593  return failure();
2594  }
2595  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
2596  return success();
2597 }
2598 
2599 static void
2601  std::optional<mlir::ArrayAttr> deviceTypes) {
2602 
2603  if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
2604  auto deviceTypeAttr =
2605  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
2606  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
2607  return;
2608  }
2609 
2610  if (!hasDeviceTypeValues(deviceTypes))
2611  return;
2612 
2613  p << "([";
2614  llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
2615  auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2616  p << dTypeAttr;
2617  });
2618  p << "])";
2619 }
2620 
2621 bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2622 
2623 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
2624  return hasDeviceType(getWorker(), deviceType);
2625 }
2626 
2627 bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
2628 
2629 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
2630  return hasDeviceType(getVector(), deviceType);
2631 }
2632 
2633 bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
2634 
2635 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
2636  return hasDeviceType(getSeq(), deviceType);
2637 }
2638 
2639 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
2640  return getBindNameValue(mlir::acc::DeviceType::None);
2641 }
2642 
2643 std::optional<llvm::StringRef>
2644 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
2645  if (!hasDeviceTypeValues(getBindNameDeviceType()))
2646  return std::nullopt;
2647  if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
2648  auto attr = (*getBindName())[*pos];
2649  auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
2650  return stringAttr.getValue();
2651  }
2652  return std::nullopt;
2653 }
2654 
2655 bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2656 
2657 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
2658  return hasDeviceType(getGang(), deviceType);
2659 }
2660 
2661 std::optional<int64_t> RoutineOp::getGangDimValue() {
2662  return getGangDimValue(mlir::acc::DeviceType::None);
2663 }
2664 
2665 std::optional<int64_t>
2666 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
2667  if (!hasDeviceTypeValues(getGangDimDeviceType()))
2668  return std::nullopt;
2669  if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
2670  auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
2671  return intAttr.getInt();
2672  }
2673  return std::nullopt;
2674 }
2675 
2676 //===----------------------------------------------------------------------===//
2677 // InitOp
2678 //===----------------------------------------------------------------------===//
2679 
2681  Operation *currOp = *this;
2682  while ((currOp = currOp->getParentOp()))
2683  if (isComputeOperation(currOp))
2684  return emitOpError("cannot be nested in a compute operation");
2685  return success();
2686 }
2687 
2688 //===----------------------------------------------------------------------===//
2689 // ShutdownOp
2690 //===----------------------------------------------------------------------===//
2691 
2693  Operation *currOp = *this;
2694  while ((currOp = currOp->getParentOp()))
2695  if (isComputeOperation(currOp))
2696  return emitOpError("cannot be nested in a compute operation");
2697  return success();
2698 }
2699 
2700 //===----------------------------------------------------------------------===//
2701 // SetOp
2702 //===----------------------------------------------------------------------===//
2703 
2705  Operation *currOp = *this;
2706  while ((currOp = currOp->getParentOp()))
2707  if (isComputeOperation(currOp))
2708  return emitOpError("cannot be nested in a compute operation");
2709  if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
2710  return emitOpError("at least one default_async, device_num, or device_type "
2711  "operand must appear");
2712  return success();
2713 }
2714 
2715 //===----------------------------------------------------------------------===//
2716 // UpdateOp
2717 //===----------------------------------------------------------------------===//
2718 
2720  // At least one of host or device should have a value.
2721  if (getDataClauseOperands().empty())
2722  return emitError("at least one value must be present in dataOperands");
2723 
2724  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
2725  getAsyncOperandsDeviceTypeAttr(),
2726  "async")))
2727  return failure();
2728 
2730  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2731  getWaitOperandsDeviceTypeAttr(), "wait")))
2732  return failure();
2733 
2734  if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this)))
2735  return failure();
2736 
2737  for (mlir::Value operand : getDataClauseOperands())
2738  if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
2739  operand.getDefiningOp()))
2740  return emitError("expect data entry/exit operation or acc.getdeviceptr "
2741  "as defining op");
2742 
2743  return success();
2744 }
2745 
2746 unsigned UpdateOp::getNumDataOperands() {
2747  return getDataClauseOperands().size();
2748 }
2749 
2750 Value UpdateOp::getDataOperand(unsigned i) {
2751  unsigned numOptional = getAsyncOperands().size();
2752  numOptional += getIfCond() ? 1 : 0;
2753  return getOperand(getWaitOperands().size() + numOptional + i);
2754 }
2755 
2756 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
2757  MLIRContext *context) {
2758  results.add<RemoveConstantIfCondition<UpdateOp>>(context);
2759 }
2760 
2761 bool UpdateOp::hasAsyncOnly() {
2762  return hasAsyncOnly(mlir::acc::DeviceType::None);
2763 }
2764 
2765 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2766  return hasDeviceType(getAsync(), deviceType);
2767 }
2768 
2769 mlir::Value UpdateOp::getAsyncValue() {
2770  return getAsyncValue(mlir::acc::DeviceType::None);
2771 }
2772 
2773 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2774  if (!hasDeviceTypeValues(getAsyncOperandsDeviceType()))
2775  return {};
2776 
2777  if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
2778  return getAsyncOperands()[*pos];
2779 
2780  return {};
2781 }
2782 
2783 bool UpdateOp::hasWaitOnly() {
2784  return hasWaitOnly(mlir::acc::DeviceType::None);
2785 }
2786 
2787 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2788  return hasDeviceType(getWaitOnly(), deviceType);
2789 }
2790 
2791 mlir::Operation::operand_range UpdateOp::getWaitValues() {
2792  return getWaitValues(mlir::acc::DeviceType::None);
2793 }
2794 
2796 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2798  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2799  getHasWaitDevnum(), deviceType);
2800 }
2801 
2802 mlir::Value UpdateOp::getWaitDevnum() {
2803  return getWaitDevnum(mlir::acc::DeviceType::None);
2804 }
2805 
2806 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2807  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2808  getWaitOperandsSegments(), getHasWaitDevnum(),
2809  deviceType);
2810 }
2811 
2812 //===----------------------------------------------------------------------===//
2813 // WaitOp
2814 //===----------------------------------------------------------------------===//
2815 
2817  // The async attribute represent the async clause without value. Therefore the
2818  // attribute and operand cannot appear at the same time.
2819  if (getAsyncOperand() && getAsync())
2820  return emitError("async attribute cannot appear with asyncOperand");
2821 
2822  if (getWaitDevnum() && getWaitOperands().empty())
2823  return emitError("wait_devnum cannot appear without waitOperands");
2824 
2825  return success();
2826 }
2827 
2828 #define GET_OP_CLASSES
2829 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
2830 
2831 #define GET_ATTRDEF_CLASSES
2832 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
2833 
2834 #define GET_TYPEDEF_CLASSES
2835 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
2836 
2837 //===----------------------------------------------------------------------===//
2838 // acc dialect utilities
2839 //===----------------------------------------------------------------------===//
2840 
2842  auto varPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
2843  .Case<ACC_DATA_ENTRY_OPS>(
2844  [&](auto entry) { return entry.getVarPtr(); })
2845  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
2846  [&](auto exit) { return exit.getVarPtr(); })
2847  .Default([&](mlir::Operation *) { return mlir::Value(); })};
2848  return varPtr;
2849 }
2850 
2852  auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
2854  [&](auto dataClause) { return dataClause.getAccPtr(); })
2855  .Default([&](mlir::Operation *) { return mlir::Value(); })};
2856  return accPtr;
2857 }
2858 
2860  auto varPtrPtr{
2862  .Case<ACC_DATA_ENTRY_OPS>(
2863  [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
2864  .Default([&](mlir::Operation *) { return mlir::Value(); })};
2865  return varPtrPtr;
2866 }
2867 
2872  accDataClauseOp)
2873  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
2875  dataClause.getBounds().begin(), dataClause.getBounds().end());
2876  })
2877  .Default([&](mlir::Operation *) {
2879  })};
2880  return bounds;
2881 }
2882 
2883 std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
2884  auto name{
2886  .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
2887  .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
2888  return {};
2889  })};
2890  return name;
2891 }
2892 
2893 std::optional<mlir::acc::DataClause>
2895  auto dataClause{
2897  accDataEntryOp)
2898  .Case<ACC_DATA_ENTRY_OPS>(
2899  [&](auto entry) { return entry.getDataClause(); })
2900  .Default([&](mlir::Operation *) { return std::nullopt; })};
2901  return dataClause;
2902 }
2903 
2905  auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
2906  .Case<ACC_DATA_ENTRY_OPS>(
2907  [&](auto entry) { return entry.getImplicit(); })
2908  .Default([&](mlir::Operation *) { return false; })};
2909  return implicit;
2910 }
2911 
2913  auto dataOperands{
2916  [&](auto entry) { return entry.getDataClauseOperands(); })
2917  .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
2918  return dataOperands;
2919 }
2920 
2923  auto dataOperands{
2926  [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
2927  .Default([&](mlir::Operation *) { return nullptr; })};
2928  return dataOperands;
2929 }
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:113
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:1972
@ None
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition: OpenACC.cpp:2543
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition: OpenACC.cpp:445
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition: OpenACC.cpp:1760
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition: OpenACC.cpp:747
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition: OpenACC.cpp:1775
static bool isComputeOperation(Operation *op)
Definition: OpenACC.cpp:459
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition: OpenACC.cpp:1122
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:2446
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition: OpenACC.cpp:1133
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition: OpenACC.cpp:1038
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition: OpenACC.cpp:76
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:2600
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition: OpenACC.cpp:1569
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition: OpenACC.cpp:1287
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition: OpenACC.cpp:2331
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:96
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition: OpenACC.cpp:2047
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition: OpenACC.cpp:676
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:1167
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:822
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:120
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:908
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition: OpenACC.cpp:2078
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:2573
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition: OpenACC.cpp:2482
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1021
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:1194
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindName, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:2472
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:975
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
Definition: OpenACC.cpp:660
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:152
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition: OpenACC.cpp:1588
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition: OpenACC.cpp:535
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition: OpenACC.cpp:952
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:107
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
Definition: OpenACC.cpp:691
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition: OpenACC.cpp:1267
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:82
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition: OpenACC.cpp:1715
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:136
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition: OpenACC.cpp:1205
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition: OpenACC.cpp:172
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition: OpenACC.cpp:757
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition: OpenACC.cpp:2413
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:958
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition: OpenACC.cpp:1313
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
Definition: OpenACC.cpp:640
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition: OpenACC.h:66
#define ACC_DATA_ENTRY_OPS
Definition: OpenACC.h:42
#define ACC_DATA_EXIT_OPS
Definition: OpenACC.h:50
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
unsigned getNumArguments()
Definition: Block.h:125
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgListType getArguments()
Definition: Block.h:84
Operation & front()
Definition: Block.h:150
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
mlir::Value getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtr from a data clause operation.
Definition: OpenACC.cpp:2841
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition: OpenACC.cpp:2894
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition: OpenACC.cpp:2922
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition: OpenACC.cpp:2869
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition: OpenACC.cpp:2912
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition: OpenACC.cpp:2883
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition: OpenACC.cpp:2904
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition: OpenACC.cpp:2859
mlir::Value getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accPtr from a data clause operation.
Definition: OpenACC.cpp:2851
static constexpr StringLiteral getDeclareAttrName()
Used to obtain the attribute name for declare.
Definition: OpenACC.h:124
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.