MLIR  18.0.0git
OpenACC.cpp
Go to the documentation of this file.
1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 using namespace mlir;
22 using namespace acc;
23 
24 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
25 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
26 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
27 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
28 
29 namespace {
30 struct MemRefPointerLikeModel
31  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
32  MemRefType> {
33  Type getElementType(Type pointer) const {
34  return llvm::cast<MemRefType>(pointer).getElementType();
35  }
36 };
37 
38 struct LLVMPointerPointerLikeModel
39  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
40  LLVM::LLVMPointerType> {
41  Type getElementType(Type pointer) const { return Type(); }
42 };
43 } // namespace
44 
45 //===----------------------------------------------------------------------===//
46 // OpenACC operations
47 //===----------------------------------------------------------------------===//
48 
49 void OpenACCDialect::initialize() {
50  addOperations<
51 #define GET_OP_LIST
52 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
53  >();
54  addAttributes<
55 #define GET_ATTRDEF_LIST
56 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
57  >();
58  addTypes<
59 #define GET_TYPEDEF_LIST
60 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
61  >();
62 
63  // By attaching interfaces here, we make the OpenACC dialect dependent on
64  // the other dialects. This is probably better than having dialects like LLVM
65  // and memref be dependent on OpenACC.
66  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
67  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
68  *getContext());
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // DataBoundsOp
73 //===----------------------------------------------------------------------===//
75  auto extent = getExtent();
76  auto upperbound = getUpperbound();
77  if (!extent && !upperbound)
78  return emitError("expected extent or upperbound.");
79  return success();
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // PrivateOp
84 //===----------------------------------------------------------------------===//
86  if (getDataClause() != acc::DataClause::acc_private)
87  return emitError(
88  "data clause associated with private operation must match its intent");
89  return success();
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // FirstprivateOp
94 //===----------------------------------------------------------------------===//
96  if (getDataClause() != acc::DataClause::acc_firstprivate)
97  return emitError("data clause associated with firstprivate operation must "
98  "match its intent");
99  return success();
100 }
101 
102 //===----------------------------------------------------------------------===//
103 // ReductionOp
104 //===----------------------------------------------------------------------===//
106  if (getDataClause() != acc::DataClause::acc_reduction)
107  return emitError("data clause associated with reduction operation must "
108  "match its intent");
109  return success();
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // DevicePtrOp
114 //===----------------------------------------------------------------------===//
116  if (getDataClause() != acc::DataClause::acc_deviceptr)
117  return emitError("data clause associated with deviceptr operation must "
118  "match its intent");
119  return success();
120 }
121 
122 //===----------------------------------------------------------------------===//
123 // PresentOp
124 //===----------------------------------------------------------------------===//
126  if (getDataClause() != acc::DataClause::acc_present)
127  return emitError(
128  "data clause associated with present operation must match its intent");
129  return success();
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // CopyinOp
134 //===----------------------------------------------------------------------===//
136  // Test for all clauses this operation can be decomposed from:
137  if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
138  getDataClause() != acc::DataClause::acc_copyin_readonly &&
139  getDataClause() != acc::DataClause::acc_copy &&
140  getDataClause() != acc::DataClause::acc_reduction)
141  return emitError(
142  "data clause associated with copyin operation must match its intent"
143  " or specify original clause this operation was decomposed from");
144  return success();
145 }
146 
147 bool acc::CopyinOp::isCopyinReadonly() {
148  return getDataClause() == acc::DataClause::acc_copyin_readonly;
149 }
150 
151 //===----------------------------------------------------------------------===//
152 // CreateOp
153 //===----------------------------------------------------------------------===//
155  // Test for all clauses this operation can be decomposed from:
156  if (getDataClause() != acc::DataClause::acc_create &&
157  getDataClause() != acc::DataClause::acc_create_zero &&
158  getDataClause() != acc::DataClause::acc_copyout &&
159  getDataClause() != acc::DataClause::acc_copyout_zero)
160  return emitError(
161  "data clause associated with create operation must match its intent"
162  " or specify original clause this operation was decomposed from");
163  return success();
164 }
165 
166 bool acc::CreateOp::isCreateZero() {
167  // The zero modifier is encoded in the data clause.
168  return getDataClause() == acc::DataClause::acc_create_zero ||
169  getDataClause() == acc::DataClause::acc_copyout_zero;
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // NoCreateOp
174 //===----------------------------------------------------------------------===//
176  if (getDataClause() != acc::DataClause::acc_no_create)
177  return emitError("data clause associated with no_create operation must "
178  "match its intent");
179  return success();
180 }
181 
182 //===----------------------------------------------------------------------===//
183 // AttachOp
184 //===----------------------------------------------------------------------===//
186  if (getDataClause() != acc::DataClause::acc_attach)
187  return emitError(
188  "data clause associated with attach operation must match its intent");
189  return success();
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // DeclareDeviceResidentOp
194 //===----------------------------------------------------------------------===//
195 
197  if (getDataClause() != acc::DataClause::acc_declare_device_resident)
198  return emitError("data clause associated with device_resident operation "
199  "must match its intent");
200  return success();
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // DeclareLinkOp
205 //===----------------------------------------------------------------------===//
206 
208  if (getDataClause() != acc::DataClause::acc_declare_link)
209  return emitError(
210  "data clause associated with link operation must match its intent");
211  return success();
212 }
213 
214 //===----------------------------------------------------------------------===//
215 // CopyoutOp
216 //===----------------------------------------------------------------------===//
218  // Test for all clauses this operation can be decomposed from:
219  if (getDataClause() != acc::DataClause::acc_copyout &&
220  getDataClause() != acc::DataClause::acc_copyout_zero &&
221  getDataClause() != acc::DataClause::acc_copy &&
222  getDataClause() != acc::DataClause::acc_reduction)
223  return emitError(
224  "data clause associated with copyout operation must match its intent"
225  " or specify original clause this operation was decomposed from");
226  if (!getVarPtr() || !getAccPtr())
227  return emitError("must have both host and device pointers");
228  return success();
229 }
230 
231 bool acc::CopyoutOp::isCopyoutZero() {
232  return getDataClause() == acc::DataClause::acc_copyout_zero;
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // DeleteOp
237 //===----------------------------------------------------------------------===//
239  // Test for all clauses this operation can be decomposed from:
240  if (getDataClause() != acc::DataClause::acc_delete &&
241  getDataClause() != acc::DataClause::acc_create &&
242  getDataClause() != acc::DataClause::acc_create_zero &&
243  getDataClause() != acc::DataClause::acc_copyin &&
244  getDataClause() != acc::DataClause::acc_copyin_readonly &&
245  getDataClause() != acc::DataClause::acc_present &&
246  getDataClause() != acc::DataClause::acc_declare_device_resident &&
247  getDataClause() != acc::DataClause::acc_declare_link)
248  return emitError(
249  "data clause associated with delete operation must match its intent"
250  " or specify original clause this operation was decomposed from");
251  if (!getVarPtr() && !getAccPtr())
252  return emitError("must have either host or device pointer");
253  return success();
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // DetachOp
258 //===----------------------------------------------------------------------===//
260  // Test for all clauses this operation can be decomposed from:
261  if (getDataClause() != acc::DataClause::acc_detach &&
262  getDataClause() != acc::DataClause::acc_attach)
263  return emitError(
264  "data clause associated with detach operation must match its intent"
265  " or specify original clause this operation was decomposed from");
266  if (!getVarPtr() && !getAccPtr())
267  return emitError("must have either host or device pointer");
268  return success();
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // HostOp
273 //===----------------------------------------------------------------------===//
275  // Test for all clauses this operation can be decomposed from:
276  if (getDataClause() != acc::DataClause::acc_update_host &&
277  getDataClause() != acc::DataClause::acc_update_self)
278  return emitError(
279  "data clause associated with host operation must match its intent"
280  " or specify original clause this operation was decomposed from");
281  if (!getVarPtr() || !getAccPtr())
282  return emitError("must have both host and device pointers");
283  return success();
284 }
285 
286 //===----------------------------------------------------------------------===//
287 // DeviceOp
288 //===----------------------------------------------------------------------===//
290  // Test for all clauses this operation can be decomposed from:
291  if (getDataClause() != acc::DataClause::acc_update_device)
292  return emitError(
293  "data clause associated with device operation must match its intent"
294  " or specify original clause this operation was decomposed from");
295  return success();
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // UseDeviceOp
300 //===----------------------------------------------------------------------===//
302  // Test for all clauses this operation can be decomposed from:
303  if (getDataClause() != acc::DataClause::acc_use_device)
304  return emitError(
305  "data clause associated with use_device operation must match its intent"
306  " or specify original clause this operation was decomposed from");
307  return success();
308 }
309 
310 //===----------------------------------------------------------------------===//
311 // CacheOp
312 //===----------------------------------------------------------------------===//
314  // Test for all clauses this operation can be decomposed from:
315  if (getDataClause() != acc::DataClause::acc_cache &&
316  getDataClause() != acc::DataClause::acc_cache_readonly)
317  return emitError(
318  "data clause associated with cache operation must match its intent"
319  " or specify original clause this operation was decomposed from");
320  return success();
321 }
322 
323 template <typename StructureOp>
325  unsigned nRegions = 1) {
326 
327  SmallVector<Region *, 2> regions;
328  for (unsigned i = 0; i < nRegions; ++i)
329  regions.push_back(state.addRegion());
330 
331  for (Region *region : regions)
332  if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
333  return failure();
334 
335  return success();
336 }
337 
338 static bool isComputeOperation(Operation *op) {
339  return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
340 }
341 
342 namespace {
343 /// Pattern to remove an operation without a region when its `ifCond` is a
344 /// constant false, and to drop the condition from the operation when the
345 /// `ifCond` is a constant true.
346 template <typename OpTy>
347 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
349 
350  LogicalResult matchAndRewrite(OpTy op,
351  PatternRewriter &rewriter) const override {
352  // Early return if there is no condition.
353  Value ifCond = op.getIfCond();
354  if (!ifCond)
355  return failure();
356 
357  IntegerAttr constAttr;
358  if (!matchPattern(ifCond, m_Constant(&constAttr)))
359  return failure();
360  if (constAttr.getInt())
361  rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
362  else
363  rewriter.eraseOp(op);
364 
365  return success();
366  }
367 };
368 
369 /// Replaces the given op with the contents of the given single-block region,
370 /// using the operands of the block terminator to replace operation results.
371 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
372  Region &region, ValueRange blockArgs = {}) {
373  assert(llvm::hasSingleElement(region) && "expected single-region block");
374  Block *block = &region.front();
375  Operation *terminator = block->getTerminator();
376  ValueRange results = terminator->getOperands();
377  rewriter.inlineBlockBefore(block, op, blockArgs);
378  rewriter.replaceOp(op, results);
379  rewriter.eraseOp(terminator);
380 }
381 
382 /// Pattern to replace an operation with a region by the contents of that
383 /// region when its `ifCond` is a constant false, and to drop the condition
384 /// from the operation when the `ifCond` is a constant true.
385 template <typename OpTy>
386 struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
388 
389  LogicalResult matchAndRewrite(OpTy op,
390  PatternRewriter &rewriter) const override {
391  // Early return if there is no condition.
392  Value ifCond = op.getIfCond();
393  if (!ifCond)
394  return failure();
395 
396  IntegerAttr constAttr;
397  if (!matchPattern(ifCond, m_Constant(&constAttr)))
398  return failure();
399  if (constAttr.getInt())
400  rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
401  else
402  replaceOpWithRegion(rewriter, op, op.getRegion());
403 
404  return success();
405  }
406 };
407 
408 } // namespace
409 
410 //===----------------------------------------------------------------------===//
411 // PrivateRecipeOp
412 //===----------------------------------------------------------------------===//
413 
415  Operation *op, Region &region, StringRef regionType, StringRef regionName,
416  Type type, bool verifyYield, bool optional = false) {
417  if (optional && region.empty())
418  return success();
419 
420  if (region.empty())
421  return op->emitOpError() << "expects non-empty " << regionName << " region";
422  Block &firstBlock = region.front();
423  if (firstBlock.getNumArguments() < 1 ||
424  firstBlock.getArgument(0).getType() != type)
425  return op->emitOpError() << "expects " << regionName
426  << " region first "
427  "argument of the "
428  << regionType << " type";
429 
430  if (verifyYield) {
431  for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
432  if (yieldOp.getOperands().size() != 1 ||
433  yieldOp.getOperands().getTypes()[0] != type)
434  return op->emitOpError() << "expects " << regionName
435  << " region to "
436  "yield a value of the "
437  << regionType << " type";
438  }
439  }
440  return success();
441 }
442 
443 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
444  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
445  "privatization", "init", getType(),
446  /*verifyYield=*/false)))
447  return failure();
449  *this, getDestroyRegion(), "privatization", "destroy", getType(),
450  /*verifyYield=*/false, /*optional=*/true)))
451  return failure();
452  return success();
453 }
454 
455 //===----------------------------------------------------------------------===//
456 // FirstprivateRecipeOp
457 //===----------------------------------------------------------------------===//
458 
459 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
460  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
461  "privatization", "init", getType(),
462  /*verifyYield=*/false)))
463  return failure();
464 
465  if (getCopyRegion().empty())
466  return emitOpError() << "expects non-empty copy region";
467 
468  Block &firstBlock = getCopyRegion().front();
469  if (firstBlock.getNumArguments() < 2 ||
470  firstBlock.getArgument(0).getType() != getType())
471  return emitOpError() << "expects copy region with two arguments of the "
472  "privatization type";
473 
474  if (getDestroyRegion().empty())
475  return success();
476 
477  if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
478  "privatization", "destroy",
479  getType(), /*verifyYield=*/false)))
480  return failure();
481 
482  return success();
483 }
484 
485 //===----------------------------------------------------------------------===//
486 // ReductionRecipeOp
487 //===----------------------------------------------------------------------===//
488 
489 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
490  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
491  "init", getType(),
492  /*verifyYield=*/false)))
493  return failure();
494 
495  if (getCombinerRegion().empty())
496  return emitOpError() << "expects non-empty combiner region";
497 
498  Block &reductionBlock = getCombinerRegion().front();
499  if (reductionBlock.getNumArguments() < 2 ||
500  reductionBlock.getArgument(0).getType() != getType() ||
501  reductionBlock.getArgument(1).getType() != getType())
502  return emitOpError() << "expects combiner region with the first two "
503  << "arguments of the reduction type";
504 
505  for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
506  if (yieldOp.getOperands().size() != 1 ||
507  yieldOp.getOperands().getTypes()[0] != getType())
508  return emitOpError() << "expects combiner region to yield a value "
509  "of the reduction type";
510  }
511 
512  return success();
513 }
514 
515 //===----------------------------------------------------------------------===//
516 // Custom parser and printer verifier for private clause
517 //===----------------------------------------------------------------------===//
518 
520  mlir::OpAsmParser &parser,
522  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
524  if (failed(parser.parseCommaSeparatedList([&]() {
525  if (parser.parseAttribute(attributes.emplace_back()) ||
526  parser.parseArrow() ||
527  parser.parseOperand(operands.emplace_back()) ||
528  parser.parseColonType(types.emplace_back()))
529  return failure();
530  return success();
531  })))
532  return failure();
533  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
534  attributes.end());
535  symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
536  return success();
537 }
538 
540  mlir::OperandRange operands,
541  mlir::TypeRange types,
542  std::optional<mlir::ArrayAttr> attributes) {
543  for (unsigned i = 0, e = attributes->size(); i < e; ++i) {
544  if (i != 0)
545  p << ", ";
546  p << (*attributes)[i] << " -> " << operands[i] << " : "
547  << operands[i].getType();
548  }
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // ParallelOp
553 //===----------------------------------------------------------------------===//
554 
555 /// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
556 template <typename Op>
558  const mlir::ValueRange &operands) {
559  for (mlir::Value operand : operands)
560  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
561  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
562  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
563  operand.getDefiningOp()))
564  return op.emitError(
565  "expect data entry/exit operation or acc.getdeviceptr "
566  "as defining op");
567  return success();
568 }
569 
570 template <typename Op>
571 static LogicalResult
572 checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
573  mlir::OperandRange operands, llvm::StringRef operandName,
574  llvm::StringRef symbolName, bool checkOperandType = true) {
575  if (!operands.empty()) {
576  if (!attributes || attributes->size() != operands.size())
577  return op->emitOpError()
578  << "expected as many " << symbolName << " symbol reference as "
579  << operandName << " operands";
580  } else {
581  if (attributes)
582  return op->emitOpError()
583  << "unexpected " << symbolName << " symbol reference";
584  return success();
585  }
586 
588  for (auto args : llvm::zip(operands, *attributes)) {
589  mlir::Value operand = std::get<0>(args);
590 
591  if (!set.insert(operand).second)
592  return op->emitOpError()
593  << operandName << " operand appears more than once";
594 
595  mlir::Type varType = operand.getType();
596  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
597  auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
598  if (!decl)
599  return op->emitOpError()
600  << "expected symbol reference " << symbolRef << " to point to a "
601  << operandName << " declaration";
602 
603  if (checkOperandType && decl.getType() && decl.getType() != varType)
604  return op->emitOpError() << "expected " << operandName << " (" << varType
605  << ") to be the same type as " << operandName
606  << " declaration (" << decl.getType() << ")";
607  }
608 
609  return success();
610 }
611 
612 unsigned ParallelOp::getNumDataOperands() {
613  return getReductionOperands().size() + getGangPrivateOperands().size() +
614  getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
615 }
616 
617 Value ParallelOp::getDataOperand(unsigned i) {
618  unsigned numOptional = getAsync() ? 1 : 0;
619  numOptional += getNumGangs().size();
620  numOptional += getNumWorkers() ? 1 : 0;
621  numOptional += getVectorLength() ? 1 : 0;
622  numOptional += getIfCond() ? 1 : 0;
623  numOptional += getSelfCond() ? 1 : 0;
624  return getOperand(getWaitOperands().size() + numOptional + i);
625 }
626 
628  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
629  *this, getPrivatizations(), getGangPrivateOperands(), "private",
630  "privatizations", /*checkOperandType=*/false)))
631  return failure();
632  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
633  *this, getReductionRecipes(), getReductionOperands(), "reduction",
634  "reductions", false)))
635  return failure();
636  if (getNumGangs().size() > 3)
637  return emitOpError() << "num_gangs expects a maximum of 3 values";
638  return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
639 }
640 
641 //===----------------------------------------------------------------------===//
642 // SerialOp
643 //===----------------------------------------------------------------------===//
644 
645 unsigned SerialOp::getNumDataOperands() {
646  return getReductionOperands().size() + getGangPrivateOperands().size() +
647  getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
648 }
649 
650 Value SerialOp::getDataOperand(unsigned i) {
651  unsigned numOptional = getAsync() ? 1 : 0;
652  numOptional += getIfCond() ? 1 : 0;
653  numOptional += getSelfCond() ? 1 : 0;
654  return getOperand(getWaitOperands().size() + numOptional + i);
655 }
656 
658  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
659  *this, getPrivatizations(), getGangPrivateOperands(), "private",
660  "privatizations", /*checkOperandType=*/false)))
661  return failure();
662  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
663  *this, getReductionRecipes(), getReductionOperands(), "reduction",
664  "reductions", false)))
665  return failure();
666  return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
667 }
668 
669 //===----------------------------------------------------------------------===//
670 // KernelsOp
671 //===----------------------------------------------------------------------===//
672 
673 unsigned KernelsOp::getNumDataOperands() {
674  return getDataClauseOperands().size();
675 }
676 
677 Value KernelsOp::getDataOperand(unsigned i) {
678  unsigned numOptional = getAsync() ? 1 : 0;
679  numOptional += getWaitOperands().size();
680  numOptional += getNumGangs().size();
681  numOptional += getNumWorkers() ? 1 : 0;
682  numOptional += getVectorLength() ? 1 : 0;
683  numOptional += getIfCond() ? 1 : 0;
684  numOptional += getSelfCond() ? 1 : 0;
685  return getOperand(numOptional + i);
686 }
687 
689  if (getNumGangs().size() > 3)
690  return emitOpError() << "num_gangs expects a maximum of 3 values";
691  return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
692 }
693 
694 //===----------------------------------------------------------------------===//
695 // HostDataOp
696 //===----------------------------------------------------------------------===//
697 
699  if (getDataClauseOperands().empty())
700  return emitError("at least one operand must appear on the host_data "
701  "operation");
702 
703  for (mlir::Value operand : getDataClauseOperands())
704  if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
705  return emitError("expect data entry operation as defining op");
706  return success();
707 }
708 
709 void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
710  MLIRContext *context) {
711  results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
712 }
713 
714 //===----------------------------------------------------------------------===//
715 // LoopOp
716 //===----------------------------------------------------------------------===//
717 
718 static ParseResult
719 parseGangValue(OpAsmParser &parser, llvm::StringRef keyword,
720  std::optional<OpAsmParser::UnresolvedOperand> &value,
721  Type &valueType, bool &needComa, bool &newValue) {
722  if (succeeded(parser.parseOptionalKeyword(keyword))) {
723  if (parser.parseEqual())
724  return failure();
726  if (parser.parseOperand(*value) || parser.parseColonType(valueType))
727  return failure();
728  needComa = true;
729  newValue = true;
730  }
731  return success();
732 }
733 
735  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &gangNum,
736  Type &gangNumType, std::optional<OpAsmParser::UnresolvedOperand> &gangDim,
737  Type &gangDimType,
738  std::optional<OpAsmParser::UnresolvedOperand> &gangStatic,
739  Type &gangStaticType, UnitAttr &hasGang) {
740  hasGang = UnitAttr::get(parser.getBuilder().getContext());
741  gangNum = std::nullopt;
742  gangDim = std::nullopt;
743  gangStatic = std::nullopt;
744  bool needComa = false;
745 
746  // optional gang operands
747  if (succeeded(parser.parseOptionalLParen())) {
748  while (true) {
749  bool newValue = false;
750  bool needValue = false;
751  if (needComa) {
752  if (succeeded(parser.parseOptionalComma()))
753  needValue = true; // expect a new value after comma.
754  else
755  break;
756  }
757 
758  if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(), gangNum,
759  gangNumType, needComa, newValue)))
760  return failure();
761  if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(), gangDim,
762  gangDimType, needComa, newValue)))
763  return failure();
764  if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
765  gangStatic, gangStaticType, needComa,
766  newValue)))
767  return failure();
768 
769  if (!newValue && needValue) {
770  parser.emitError(parser.getCurrentLocation(),
771  "new value expected after comma");
772  return failure();
773  }
774 
775  if (!newValue)
776  break;
777  }
778 
779  if (!gangNum && !gangDim && !gangStatic) {
780  parser.emitError(parser.getCurrentLocation(),
781  "expect at least one of num, dim or static values");
782  return failure();
783  }
784 
785  if (failed(parser.parseRParen()))
786  return failure();
787  }
788  return success();
789 }
790 
792  Type gangNumType, Value gangDim, Type gangDimType,
793  Value gangStatic, Type gangStaticType, UnitAttr hasGang) {
794  if (gangNum || gangStatic || gangDim) {
795  p << "(";
796  if (gangNum) {
797  p << LoopOp::getGangNumKeyword() << "=" << gangNum << " : "
798  << gangNumType;
799  if (gangStatic || gangDim)
800  p << ", ";
801  }
802  if (gangDim) {
803  p << LoopOp::getGangDimKeyword() << "=" << gangDim << " : "
804  << gangDimType;
805  if (gangStatic)
806  p << ", ";
807  }
808  if (gangStatic)
809  p << LoopOp::getGangStaticKeyword() << "=" << gangStatic << " : "
810  << gangStaticType;
811  p << ")";
812  }
813 }
814 
815 static ParseResult
817  std::optional<OpAsmParser::UnresolvedOperand> &workerNum,
818  Type &workerNumType, UnitAttr &hasWorker) {
819  hasWorker = UnitAttr::get(parser.getBuilder().getContext());
820  if (succeeded(parser.parseOptionalLParen())) {
821  workerNum = OpAsmParser::UnresolvedOperand{};
822  if (parser.parseOperand(*workerNum) ||
823  parser.parseColonType(workerNumType) || parser.parseRParen())
824  return failure();
825  }
826  return success();
827 }
828 
830  Type workerNumType, UnitAttr hasWorker) {
831  if (workerNum)
832  p << "(" << workerNum << " : " << workerNumType << ")";
833 }
834 
835 static ParseResult
837  std::optional<OpAsmParser::UnresolvedOperand> &vectorLength,
838  Type &vectorLengthType, UnitAttr &hasVector) {
839  hasVector = UnitAttr::get(parser.getBuilder().getContext());
840  if (succeeded(parser.parseOptionalLParen())) {
841  vectorLength = OpAsmParser::UnresolvedOperand{};
842  if (parser.parseOperand(*vectorLength) ||
843  parser.parseColonType(vectorLengthType) || parser.parseRParen())
844  return failure();
845  }
846  return success();
847 }
848 
849 void printVectorClause(OpAsmPrinter &p, Operation *op, Value vectorLength,
850  Type vectorLengthType, UnitAttr hasVector) {
851  if (vectorLength)
852  p << "(" << vectorLength << " : " << vectorLengthType << ")";
853 }
854 
856  // auto, independent and seq attribute are mutually exclusive.
857  if ((getAuto_() && (getIndependent() || getSeq())) ||
858  (getIndependent() && getSeq())) {
859  return emitError() << "only one of \"" << acc::LoopOp::getAutoAttrStrName()
860  << "\", " << getIndependentAttrName() << ", "
861  << getSeqAttrName()
862  << " can be present at the same time";
863  }
864 
865  // Gang, worker and vector are incompatible with seq.
866  if (getSeq() && (getHasGang() || getHasWorker() || getHasVector()))
867  return emitError("gang, worker or vector cannot appear with the seq attr");
868 
869  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
870  *this, getPrivatizations(), getPrivateOperands(), "private",
871  "privatizations", false)))
872  return failure();
873 
874  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
875  *this, getReductionRecipes(), getReductionOperands(), "reduction",
876  "reductions", false)))
877  return failure();
878 
879  // Check non-empty body().
880  if (getRegion().empty())
881  return emitError("expected non-empty body.");
882 
883  return success();
884 }
885 
886 unsigned LoopOp::getNumDataOperands() {
887  return getReductionOperands().size() + getPrivateOperands().size();
888 }
889 
890 Value LoopOp::getDataOperand(unsigned i) {
891  unsigned numOptional = getGangNum() ? 1 : 0;
892  numOptional += getGangDim() ? 1 : 0;
893  numOptional += getGangStatic() ? 1 : 0;
894  numOptional += getVectorLength() ? 1 : 0;
895  numOptional += getWorkerNum() ? 1 : 0;
896  numOptional += getTileOperands().size();
897  numOptional += getCacheOperands().size();
898  return getOperand(numOptional + i);
899 }
900 
901 //===----------------------------------------------------------------------===//
902 // DataOp
903 //===----------------------------------------------------------------------===//
904 
906  // 2.6.5. Data Construct restriction
907  // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
908  // attach, or default clause must appear on a data construct.
909  if (getOperands().empty() && !getDefaultAttr())
910  return emitError("at least one operand or the default attribute "
911  "must appear on the data operation");
912 
913  for (mlir::Value operand : getDataClauseOperands())
914  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
915  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
916  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
917  operand.getDefiningOp()))
918  return emitError("expect data entry/exit operation or acc.getdeviceptr "
919  "as defining op");
920 
921  return success();
922 }
923 
924 unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
925 
926 Value DataOp::getDataOperand(unsigned i) {
927  unsigned numOptional = getIfCond() ? 1 : 0;
928  numOptional += getAsync() ? 1 : 0;
929  numOptional += getWaitOperands().size();
930  return getOperand(numOptional + i);
931 }
932 
933 //===----------------------------------------------------------------------===//
934 // ExitDataOp
935 //===----------------------------------------------------------------------===//
936 
938  // 2.6.6. Data Exit Directive restriction
939  // At least one copyout, delete, or detach clause must appear on an exit data
940  // directive.
941  if (getDataClauseOperands().empty())
942  return emitError("at least one operand must be present in dataOperands on "
943  "the exit data operation");
944 
945  // The async attribute represent the async clause without value. Therefore the
946  // attribute and operand cannot appear at the same time.
947  if (getAsyncOperand() && getAsync())
948  return emitError("async attribute cannot appear with asyncOperand");
949 
950  // The wait attribute represent the wait clause without values. Therefore the
951  // attribute and operands cannot appear at the same time.
952  if (!getWaitOperands().empty() && getWait())
953  return emitError("wait attribute cannot appear with waitOperands");
954 
955  if (getWaitDevnum() && getWaitOperands().empty())
956  return emitError("wait_devnum cannot appear without waitOperands");
957 
958  return success();
959 }
960 
961 unsigned ExitDataOp::getNumDataOperands() {
962  return getDataClauseOperands().size();
963 }
964 
965 Value ExitDataOp::getDataOperand(unsigned i) {
966  unsigned numOptional = getIfCond() ? 1 : 0;
967  numOptional += getAsyncOperand() ? 1 : 0;
968  numOptional += getWaitDevnum() ? 1 : 0;
969  return getOperand(getWaitOperands().size() + numOptional + i);
970 }
971 
972 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
973  MLIRContext *context) {
974  results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
975 }
976 
977 //===----------------------------------------------------------------------===//
978 // EnterDataOp
979 //===----------------------------------------------------------------------===//
980 
982  // 2.6.6. Data Enter Directive restriction
983  // At least one copyin, create, or attach clause must appear on an enter data
984  // directive.
985  if (getDataClauseOperands().empty())
986  return emitError("at least one operand must be present in dataOperands on "
987  "the enter data operation");
988 
989  // The async attribute represent the async clause without value. Therefore the
990  // attribute and operand cannot appear at the same time.
991  if (getAsyncOperand() && getAsync())
992  return emitError("async attribute cannot appear with asyncOperand");
993 
994  // The wait attribute represent the wait clause without values. Therefore the
995  // attribute and operands cannot appear at the same time.
996  if (!getWaitOperands().empty() && getWait())
997  return emitError("wait attribute cannot appear with waitOperands");
998 
999  if (getWaitDevnum() && getWaitOperands().empty())
1000  return emitError("wait_devnum cannot appear without waitOperands");
1001 
1002  for (mlir::Value operand : getDataClauseOperands())
1003  if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
1004  operand.getDefiningOp()))
1005  return emitError("expect data entry operation as defining op");
1006 
1007  return success();
1008 }
1009 
1010 unsigned EnterDataOp::getNumDataOperands() {
1011  return getDataClauseOperands().size();
1012 }
1013 
1014 Value EnterDataOp::getDataOperand(unsigned i) {
1015  unsigned numOptional = getIfCond() ? 1 : 0;
1016  numOptional += getAsyncOperand() ? 1 : 0;
1017  numOptional += getWaitDevnum() ? 1 : 0;
1018  return getOperand(getWaitOperands().size() + numOptional + i);
1019 }
1020 
1021 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
1022  MLIRContext *context) {
1023  results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
1024 }
1025 
1026 //===----------------------------------------------------------------------===//
1027 // AtomicReadOp
1028 //===----------------------------------------------------------------------===//
1029 
1030 LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
1031 
1032 //===----------------------------------------------------------------------===//
1033 // AtomicWriteOp
1034 //===----------------------------------------------------------------------===//
1035 
1036 LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
1037 
1038 //===----------------------------------------------------------------------===//
1039 // AtomicUpdateOp
1040 //===----------------------------------------------------------------------===//
1041 
1042 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
1043  PatternRewriter &rewriter) {
1044  if (op.isNoOp()) {
1045  rewriter.eraseOp(op);
1046  return success();
1047  }
1048 
1049  if (Value writeVal = op.getWriteOpVal()) {
1050  rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
1051  return success();
1052  }
1053 
1054  return failure();
1055 }
1056 
1057 LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
1058 
1059 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
1060 
1061 //===----------------------------------------------------------------------===//
1062 // AtomicCaptureOp
1063 //===----------------------------------------------------------------------===//
1064 
1065 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
1066  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
1067  return op;
1068  return dyn_cast<AtomicReadOp>(getSecondOp());
1069 }
1070 
1071 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
1072  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
1073  return op;
1074  return dyn_cast<AtomicWriteOp>(getSecondOp());
1075 }
1076 
1077 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
1078  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
1079  return op;
1080  return dyn_cast<AtomicUpdateOp>(getSecondOp());
1081 }
1082 
1083 LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
1084 
1085 //===----------------------------------------------------------------------===//
1086 // DeclareEnterOp
1087 //===----------------------------------------------------------------------===//
1088 
1089 template <typename Op>
1090 static LogicalResult
1092  bool requireAtLeastOneOperand = true) {
1093  if (operands.empty() && requireAtLeastOneOperand)
1094  return emitError(
1095  op->getLoc(),
1096  "at least one operand must appear on the declare operation");
1097 
1098  for (mlir::Value operand : operands) {
1099  if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1100  acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
1101  acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
1102  operand.getDefiningOp()))
1103  return op.emitError(
1104  "expect valid declare data entry operation or acc.getdeviceptr "
1105  "as defining op");
1106 
1107  mlir::Value varPtr{getVarPtr(operand.getDefiningOp())};
1108  assert(varPtr && "declare operands can only be data entry operations which "
1109  "must have varPtr");
1110  std::optional<mlir::acc::DataClause> dataClauseOptional{
1111  getDataClause(operand.getDefiningOp())};
1112  assert(dataClauseOptional.has_value() &&
1113  "declare operands can only be data entry operations which must have "
1114  "dataClause");
1115 
1116  // If varPtr has no defining op - there is nothing to check further.
1117  if (!varPtr.getDefiningOp())
1118  continue;
1119 
1120  // Check that the varPtr has a declare attribute.
1121  auto declareAttribute{
1122  varPtr.getDefiningOp()->getAttr(mlir::acc::getDeclareAttrName())};
1123  if (!declareAttribute)
1124  return op.emitError(
1125  "expect declare attribute on variable in declare operation");
1126 
1127  auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute);
1128  if (declAttr.getDataClause().getValue() != dataClauseOptional.value())
1129  return op.emitError(
1130  "expect matching declare attribute on variable in declare operation");
1131 
1132  // If the variable is marked with implicit attribute, the matching declare
1133  // data action must also be marked implicit. The reverse is not checked
1134  // since implicit data action may be inserted to do actions like updating
1135  // device copy, in which case the variable is not necessarily implicitly
1136  // declare'd.
1137  if (declAttr.getImplicit() &&
1138  declAttr.getImplicit() != acc::getImplicitFlag(operand.getDefiningOp()))
1139  return op.emitError(
1140  "implicitness must match between declare op and flag on variable");
1141  }
1142 
1143  return success();
1144 }
1145 
1147  return checkDeclareOperands(*this, this->getDataClauseOperands());
1148 }
1149 
1150 //===----------------------------------------------------------------------===//
1151 // DeclareExitOp
1152 //===----------------------------------------------------------------------===//
1153 
1155  if (getToken())
1156  return checkDeclareOperands(*this, this->getDataClauseOperands(),
1157  /*requireAtLeastOneOperand=*/false);
1158  return checkDeclareOperands(*this, this->getDataClauseOperands());
1159 }
1160 
1161 //===----------------------------------------------------------------------===//
1162 // DeclareOp
1163 //===----------------------------------------------------------------------===//
1164 
1166  return checkDeclareOperands(*this, this->getDataClauseOperands());
1167 }
1168 
1169 //===----------------------------------------------------------------------===//
1170 // RoutineOp
1171 //===----------------------------------------------------------------------===//
1172 
1174  int parallelism = 0;
1175  parallelism += getGang() ? 1 : 0;
1176  parallelism += getWorker() ? 1 : 0;
1177  parallelism += getVector() ? 1 : 0;
1178  parallelism += getSeq() ? 1 : 0;
1179 
1180  if (parallelism > 1)
1181  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
1182  "be present at the same time";
1183 
1184  return success();
1185 }
1186 
1187 static ParseResult parseRoutineGangClause(OpAsmParser &parser, UnitAttr &gang,
1188  IntegerAttr &gangDim) {
1189  // Since gang clause exists, ensure that unit attribute is set.
1190  gang = UnitAttr::get(parser.getBuilder().getContext());
1191 
1192  // Next, look for dim on gang. Don't initialize `gangDim` yet since
1193  // we leave it without attribute if there is no `dim` specifier.
1194  if (succeeded(parser.parseOptionalLParen())) {
1195  // Look for syntax that looks like `dim = 1 : i32`.
1196  // Thus first look for `dim =`
1197  if (failed(parser.parseKeyword(RoutineOp::getGangDimKeyword())) ||
1198  failed(parser.parseEqual()))
1199  return failure();
1200 
1201  int64_t dimValue;
1202  Type valueType;
1203  // Now look for `1 : i32`
1204  if (failed(parser.parseInteger(dimValue)) ||
1205  failed(parser.parseColonType(valueType)))
1206  return failure();
1207 
1208  gangDim = IntegerAttr::get(valueType, dimValue);
1209 
1210  if (failed(parser.parseRParen()))
1211  return failure();
1212  }
1213 
1214  return success();
1215 }
1216 
1217 void printRoutineGangClause(OpAsmPrinter &p, Operation *op, UnitAttr gang,
1218  IntegerAttr gangDim) {
1219  if (gangDim)
1220  p << "(" << RoutineOp::getGangDimKeyword() << " = " << gangDim.getValue()
1221  << " : " << gangDim.getType() << ")";
1222 }
1223 
1224 //===----------------------------------------------------------------------===//
1225 // InitOp
1226 //===----------------------------------------------------------------------===//
1227 
1229  Operation *currOp = *this;
1230  while ((currOp = currOp->getParentOp()))
1231  if (isComputeOperation(currOp))
1232  return emitOpError("cannot be nested in a compute operation");
1233  return success();
1234 }
1235 
1236 //===----------------------------------------------------------------------===//
1237 // ShutdownOp
1238 //===----------------------------------------------------------------------===//
1239 
1241  Operation *currOp = *this;
1242  while ((currOp = currOp->getParentOp()))
1243  if (isComputeOperation(currOp))
1244  return emitOpError("cannot be nested in a compute operation");
1245  return success();
1246 }
1247 
1248 //===----------------------------------------------------------------------===//
1249 // SetOp
1250 //===----------------------------------------------------------------------===//
1251 
1253  Operation *currOp = *this;
1254  while ((currOp = currOp->getParentOp()))
1255  if (isComputeOperation(currOp))
1256  return emitOpError("cannot be nested in a compute operation");
1257  if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
1258  return emitOpError("at least one default_async, device_num, or device_type "
1259  "operand must appear");
1260  return success();
1261 }
1262 
1263 //===----------------------------------------------------------------------===//
1264 // UpdateOp
1265 //===----------------------------------------------------------------------===//
1266 
1268  // At least one of host or device should have a value.
1269  if (getDataClauseOperands().empty())
1270  return emitError("at least one value must be present in dataOperands");
1271 
1272  // The async attribute represent the async clause without value. Therefore the
1273  // attribute and operand cannot appear at the same time.
1274  if (getAsyncOperand() && getAsync())
1275  return emitError("async attribute cannot appear with asyncOperand");
1276 
1277  // The wait attribute represent the wait clause without values. Therefore the
1278  // attribute and operands cannot appear at the same time.
1279  if (!getWaitOperands().empty() && getWait())
1280  return emitError("wait attribute cannot appear with waitOperands");
1281 
1282  if (getWaitDevnum() && getWaitOperands().empty())
1283  return emitError("wait_devnum cannot appear without waitOperands");
1284 
1285  for (mlir::Value operand : getDataClauseOperands())
1286  if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
1287  operand.getDefiningOp()))
1288  return emitError("expect data entry/exit operation or acc.getdeviceptr "
1289  "as defining op");
1290 
1291  return success();
1292 }
1293 
1294 unsigned UpdateOp::getNumDataOperands() {
1295  return getDataClauseOperands().size();
1296 }
1297 
1298 Value UpdateOp::getDataOperand(unsigned i) {
1299  unsigned numOptional = getAsyncOperand() ? 1 : 0;
1300  numOptional += getWaitDevnum() ? 1 : 0;
1301  numOptional += getIfCond() ? 1 : 0;
1302  return getOperand(getWaitOperands().size() + numOptional + i);
1303 }
1304 
1305 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1306  MLIRContext *context) {
1307  results.add<RemoveConstantIfCondition<UpdateOp>>(context);
1308 }
1309 
1310 //===----------------------------------------------------------------------===//
1311 // WaitOp
1312 //===----------------------------------------------------------------------===//
1313 
1315  // The async attribute represent the async clause without value. Therefore the
1316  // attribute and operand cannot appear at the same time.
1317  if (getAsyncOperand() && getAsync())
1318  return emitError("async attribute cannot appear with asyncOperand");
1319 
1320  if (getWaitDevnum() && getWaitOperands().empty())
1321  return emitError("wait_devnum cannot appear without waitOperands");
1322 
1323  return success();
1324 }
1325 
1326 #define GET_OP_CLASSES
1327 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
1328 
1329 #define GET_ATTRDEF_CLASSES
1330 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
1331 
1332 #define GET_TYPEDEF_CLASSES
1333 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
1334 
1335 //===----------------------------------------------------------------------===//
1336 // acc dialect utilities
1337 //===----------------------------------------------------------------------===//
1338 
1340  auto varPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataEntryOp)
1341  .Case<ACC_DATA_ENTRY_OPS>(
1342  [&](auto entry) { return entry.getVarPtr(); })
1343  .Default([&](mlir::Operation *) { return mlir::Value(); })};
1344  return varPtr;
1345 }
1346 
1347 std::optional<mlir::acc::DataClause>
1349  auto dataClause{
1351  accDataEntryOp)
1352  .Case<ACC_DATA_ENTRY_OPS>(
1353  [&](auto entry) { return entry.getDataClause(); })
1354  .Default([&](mlir::Operation *) { return std::nullopt; })};
1355  return dataClause;
1356 }
1357 
1359  auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
1360  .Case<ACC_DATA_ENTRY_OPS>(
1361  [&](auto entry) { return entry.getImplicit(); })
1362  .Default([&](mlir::Operation *) { return false; })};
1363  return implicit;
1364 }
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the block terminator to replace operation results.
Definition: SCF.cpp:105
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:1943
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition: OpenACC.cpp:324
static bool isComputeOperation(Operation *op)
Definition: OpenACC.cpp:338
static ParseResult parseWorkerClause(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &workerNum, Type &workerNumType, UnitAttr &hasWorker)
Definition: OpenACC.cpp:816
void printVectorClause(OpAsmPrinter &p, Operation *op, Value vectorLength, Type vectorLengthType, UnitAttr hasVector)
Definition: OpenACC.cpp:849
static ParseResult parseVectorClause(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &vectorLength, Type &vectorLengthType, UnitAttr &hasVector)
Definition: OpenACC.cpp:836
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, std::optional< OpAsmParser::UnresolvedOperand > &value, Type &valueType, bool &needComa, bool &newValue)
Definition: OpenACC.cpp:719
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition: OpenACC.cpp:1091
static ParseResult parseRoutineGangClause(OpAsmParser &parser, UnitAttr &gang, IntegerAttr &gangDim)
Definition: OpenACC.cpp:1187
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition: OpenACC.cpp:557
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
Definition: OpenACC.cpp:539
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition: OpenACC.cpp:414
static ParseResult parseGangClause(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &gangNum, Type &gangNumType, std::optional< OpAsmParser::UnresolvedOperand > &gangDim, Type &gangDimType, std::optional< OpAsmParser::UnresolvedOperand > &gangStatic, Type &gangStaticType, UnitAttr &hasGang)
Definition: OpenACC.cpp:734
void printGangClause(OpAsmPrinter &p, Operation *op, Value gangNum, Type gangNumType, Value gangDim, Type gangDimType, Value gangStatic, Type gangStaticType, UnitAttr hasGang)
Definition: OpenACC.cpp:791
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
Definition: OpenACC.cpp:572
void printWorkerClause(OpAsmPrinter &p, Operation *op, Value workerNum, Type workerNumType, UnitAttr hasWorker)
Definition: OpenACC.cpp:829
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
Definition: OpenACC.cpp:519
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, UnitAttr gang, IntegerAttr gangDim)
Definition: OpenACC.cpp:1217
#define ACC_DATA_ENTRY_OPS
Definition: OpenACC.h:41
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attributes.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:68
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
unsigned getNumArguments()
Definition: Block.h:121
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:238
Operation & front()
Definition: Block.h:146
MLIRContext * getContext() const
Definition: Builders.h:55
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getType() const
Definition: ValueRange.cpp:30
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:267
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:665
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:606
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition: OpenACC.cpp:1348
mlir::Value getVarPtr(mlir::Operation *accDataEntryOp)
Used to obtain the varPtr from a data entry operation.
Definition: OpenACC.cpp:1339
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition: OpenACC.cpp:1358
static constexpr StringLiteral getDeclareAttrName()
Used to obtain the attribute name for declare.
Definition: OpenACC.h:91
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
This represents an operation in an abstracted form, suitable for use with the builder APIs.