MLIR  20.0.0git
OpenACC.cpp
Go to the documentation of this file.
1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
13 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
19 #include "mlir/Support/LLVM.h"
21 #include "llvm/ADT/SmallSet.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/LogicalResult.h"
24 
25 using namespace mlir;
26 using namespace acc;
27 
28 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
31 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
32 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
33 
34 namespace {
35 struct MemRefPointerLikeModel
36  : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
37  MemRefType> {
38  Type getElementType(Type pointer) const {
39  return llvm::cast<MemRefType>(pointer).getElementType();
40  }
41 };
42 
43 struct LLVMPointerPointerLikeModel
44  : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
45  LLVM::LLVMPointerType> {
46  Type getElementType(Type pointer) const { return Type(); }
47 };
48 } // namespace
49 
50 //===----------------------------------------------------------------------===//
51 // OpenACC operations
52 //===----------------------------------------------------------------------===//
53 
54 void OpenACCDialect::initialize() {
55  addOperations<
56 #define GET_OP_LIST
57 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
58  >();
59  addAttributes<
60 #define GET_ATTRDEF_LIST
61 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
62  >();
63  addTypes<
64 #define GET_TYPEDEF_LIST
65 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
66  >();
67 
68  // By attaching interfaces here, we make the OpenACC dialect dependent on
69  // the other dialects. This is probably better than having dialects like LLVM
70  // and memref be dependent on OpenACC.
71  MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
72  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
73  *getContext());
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // device_type support helpers
78 //===----------------------------------------------------------------------===//
79 
80 static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
81  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
82  return true;
83  return false;
84 }
85 
86 static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
87  mlir::acc::DeviceType deviceType) {
88  if (!hasDeviceTypeValues(arrayAttr))
89  return false;
90 
91  for (auto attr : *arrayAttr) {
92  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
93  if (deviceTypeAttr.getValue() == deviceType)
94  return true;
95  }
96 
97  return false;
98 }
99 
101  std::optional<mlir::ArrayAttr> deviceTypes) {
102  if (!hasDeviceTypeValues(deviceTypes))
103  return;
104 
105  p << "[";
106  llvm::interleaveComma(*deviceTypes, p,
107  [&](mlir::Attribute attr) { p << attr; });
108  p << "]";
109 }
110 
111 static std::optional<unsigned> findSegment(ArrayAttr segments,
112  mlir::acc::DeviceType deviceType) {
113  unsigned segmentIdx = 0;
114  for (auto attr : segments) {
115  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
116  if (deviceTypeAttr.getValue() == deviceType)
117  return std::make_optional(segmentIdx);
118  ++segmentIdx;
119  }
120  return std::nullopt;
121 }
122 
124 getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
126  std::optional<llvm::ArrayRef<int32_t>> segments,
127  mlir::acc::DeviceType deviceType) {
128  if (!arrayAttr)
129  return range.take_front(0);
130  if (auto pos = findSegment(*arrayAttr, deviceType)) {
131  int32_t nbOperandsBefore = 0;
132  for (unsigned i = 0; i < *pos; ++i)
133  nbOperandsBefore += (*segments)[i];
134  return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
135  }
136  return range.take_front(0);
137 }
138 
139 static mlir::Value
140 getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
142  std::optional<llvm::ArrayRef<int32_t>> segments,
143  std::optional<mlir::ArrayAttr> hasWaitDevnum,
144  mlir::acc::DeviceType deviceType) {
145  if (!hasDeviceTypeValues(deviceTypeAttr))
146  return {};
147  if (auto pos = findSegment(*deviceTypeAttr, deviceType))
148  if (hasWaitDevnum->getValue()[*pos])
149  return getValuesFromSegments(deviceTypeAttr, operands, segments,
150  deviceType)
151  .front();
152  return {};
153 }
154 
156 getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
158  std::optional<llvm::ArrayRef<int32_t>> segments,
159  std::optional<mlir::ArrayAttr> hasWaitDevnum,
160  mlir::acc::DeviceType deviceType) {
161  auto range =
162  getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
163  if (range.empty())
164  return range;
165  if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
166  if (hasWaitDevnum && *hasWaitDevnum) {
167  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
168  if (boolAttr.getValue())
169  return range.drop_front(1); // first value is devnum
170  }
171  }
172  return range;
173 }
174 
175 template <typename Op>
176 static LogicalResult checkWaitAndAsyncConflict(Op op) {
177  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
178  ++dtypeInt) {
179  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
180 
181  // The async attribute represent the async clause without value. Therefore
182  // the attribute and operand cannot appear at the same time.
183  if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
184  op.hasAsyncOnly(dtype))
185  return op.emitError("async attribute cannot appear with asyncOperand");
186 
187  // The wait attribute represent the wait clause without values. Therefore
188  // the attribute and operands cannot appear at the same time.
189  if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
190  op.hasWaitOnly(dtype))
191  return op.emitError("wait attribute cannot appear with waitOperands");
192  }
193  return success();
194 }
195 
196 template <typename Op>
197 static LogicalResult checkVarAndVarType(Op op) {
198  if (!op.getVar())
199  return op.emitError("must have var operand");
200 
201  if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
202  mlir::isa<mlir::acc::MappableType>(op.getVar().getType())) {
203  // TODO: If a type implements both interfaces (mappable and pointer-like),
204  // it is unclear which semantics to apply without additional info which
205  // would need captured in the data operation. For now restrict this case
206  // unless a compelling reason to support disambiguating between the two.
207  return op.emitError("var must be mappable or pointer-like (not both)");
208  }
209 
210  if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
211  !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
212  return op.emitError("var must be mappable or pointer-like");
213 
214  if (mlir::isa<mlir::acc::MappableType>(op.getVar().getType()) &&
215  op.getVarType() != op.getVar().getType())
216  return op.emitError("varType must match when var is mappable");
217 
218  return success();
219 }
220 
221 template <typename Op>
222 static LogicalResult checkVarAndAccVar(Op op) {
223  if (op.getVar().getType() != op.getAccVar().getType())
224  return op.emitError("input and output types must match");
225 
226  return success();
227 }
228 
229 static ParseResult parseVar(mlir::OpAsmParser &parser,
231  // Either `var` or `varPtr` keyword is required.
232  if (failed(parser.parseOptionalKeyword("varPtr"))) {
233  if (failed(parser.parseKeyword("var")))
234  return failure();
235  }
236  if (failed(parser.parseLParen()))
237  return failure();
238  if (failed(parser.parseOperand(var)))
239  return failure();
240 
241  return success();
242 }
243 
245  mlir::Value var) {
246  if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
247  p << "varPtr(";
248  else
249  p << "var(";
250  p.printOperand(var);
251 }
252 
253 static ParseResult parseAccVar(mlir::OpAsmParser &parser,
255  mlir::Type &accVarType) {
256  // Either `accVar` or `accPtr` keyword is required.
257  if (failed(parser.parseOptionalKeyword("accPtr"))) {
258  if (failed(parser.parseKeyword("accVar")))
259  return failure();
260  }
261  if (failed(parser.parseLParen()))
262  return failure();
263  if (failed(parser.parseOperand(var)))
264  return failure();
265  if (failed(parser.parseColon()))
266  return failure();
267  if (failed(parser.parseType(accVarType)))
268  return failure();
269  if (failed(parser.parseRParen()))
270  return failure();
271 
272  return success();
273 }
274 
276  mlir::Value accVar, mlir::Type accVarType) {
277  if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
278  p << "accPtr(";
279  else
280  p << "accVar(";
281  p.printOperand(accVar);
282  p << " : ";
283  p.printType(accVarType);
284  p << ")";
285 }
286 
287 static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
288  mlir::Type &varPtrType,
289  mlir::TypeAttr &varTypeAttr) {
290  if (failed(parser.parseType(varPtrType)))
291  return failure();
292  if (failed(parser.parseRParen()))
293  return failure();
294 
295  if (succeeded(parser.parseOptionalKeyword("varType"))) {
296  if (failed(parser.parseLParen()))
297  return failure();
298  mlir::Type varType;
299  if (failed(parser.parseType(varType)))
300  return failure();
301  varTypeAttr = mlir::TypeAttr::get(varType);
302  if (failed(parser.parseRParen()))
303  return failure();
304  } else {
305  // Set `varType` from the element type of the type of `varPtr`.
306  if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
307  varTypeAttr = mlir::TypeAttr::get(
308  mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
309  else
310  varTypeAttr = mlir::TypeAttr::get(varPtrType);
311  }
312 
313  return success();
314 }
315 
317  mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
318  p.printType(varPtrType);
319  p << ")";
320 
321  // Print the `varType` only if it differs from the element type of
322  // `varPtr`'s type.
323  mlir::Type varType = varTypeAttr.getValue();
324  mlir::Type typeToCheckAgainst =
325  mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
326  ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
327  : varPtrType;
328  if (typeToCheckAgainst != varType) {
329  p << " varType(";
330  p.printType(varType);
331  p << ")";
332  }
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // DataBoundsOp
337 //===----------------------------------------------------------------------===//
338 LogicalResult acc::DataBoundsOp::verify() {
339  auto extent = getExtent();
340  auto upperbound = getUpperbound();
341  if (!extent && !upperbound)
342  return emitError("expected extent or upperbound.");
343  return success();
344 }
345 
346 //===----------------------------------------------------------------------===//
347 // PrivateOp
348 //===----------------------------------------------------------------------===//
349 LogicalResult acc::PrivateOp::verify() {
350  if (getDataClause() != acc::DataClause::acc_private)
351  return emitError(
352  "data clause associated with private operation must match its intent");
353  if (failed(checkVarAndVarType(*this)))
354  return failure();
355  return success();
356 }
357 
358 //===----------------------------------------------------------------------===//
359 // FirstprivateOp
360 //===----------------------------------------------------------------------===//
361 LogicalResult acc::FirstprivateOp::verify() {
362  if (getDataClause() != acc::DataClause::acc_firstprivate)
363  return emitError("data clause associated with firstprivate operation must "
364  "match its intent");
365  if (failed(checkVarAndVarType(*this)))
366  return failure();
367  return success();
368 }
369 
370 //===----------------------------------------------------------------------===//
371 // ReductionOp
372 //===----------------------------------------------------------------------===//
373 LogicalResult acc::ReductionOp::verify() {
374  if (getDataClause() != acc::DataClause::acc_reduction)
375  return emitError("data clause associated with reduction operation must "
376  "match its intent");
377  if (failed(checkVarAndVarType(*this)))
378  return failure();
379  return success();
380 }
381 
382 //===----------------------------------------------------------------------===//
383 // DevicePtrOp
384 //===----------------------------------------------------------------------===//
385 LogicalResult acc::DevicePtrOp::verify() {
386  if (getDataClause() != acc::DataClause::acc_deviceptr)
387  return emitError("data clause associated with deviceptr operation must "
388  "match its intent");
389  if (failed(checkVarAndVarType(*this)))
390  return failure();
391  if (failed(checkVarAndAccVar(*this)))
392  return failure();
393  return success();
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // PresentOp
398 //===----------------------------------------------------------------------===//
399 LogicalResult acc::PresentOp::verify() {
400  if (getDataClause() != acc::DataClause::acc_present)
401  return emitError(
402  "data clause associated with present operation must match its intent");
403  if (failed(checkVarAndVarType(*this)))
404  return failure();
405  if (failed(checkVarAndAccVar(*this)))
406  return failure();
407  return success();
408 }
409 
410 //===----------------------------------------------------------------------===//
411 // CopyinOp
412 //===----------------------------------------------------------------------===//
413 LogicalResult acc::CopyinOp::verify() {
414  // Test for all clauses this operation can be decomposed from:
415  if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
416  getDataClause() != acc::DataClause::acc_copyin_readonly &&
417  getDataClause() != acc::DataClause::acc_copy &&
418  getDataClause() != acc::DataClause::acc_reduction)
419  return emitError(
420  "data clause associated with copyin operation must match its intent"
421  " or specify original clause this operation was decomposed from");
422  if (failed(checkVarAndVarType(*this)))
423  return failure();
424  if (failed(checkVarAndAccVar(*this)))
425  return failure();
426  return success();
427 }
428 
429 bool acc::CopyinOp::isCopyinReadonly() {
430  return getDataClause() == acc::DataClause::acc_copyin_readonly;
431 }
432 
433 //===----------------------------------------------------------------------===//
434 // CreateOp
435 //===----------------------------------------------------------------------===//
436 LogicalResult acc::CreateOp::verify() {
437  // Test for all clauses this operation can be decomposed from:
438  if (getDataClause() != acc::DataClause::acc_create &&
439  getDataClause() != acc::DataClause::acc_create_zero &&
440  getDataClause() != acc::DataClause::acc_copyout &&
441  getDataClause() != acc::DataClause::acc_copyout_zero)
442  return emitError(
443  "data clause associated with create operation must match its intent"
444  " or specify original clause this operation was decomposed from");
445  if (failed(checkVarAndVarType(*this)))
446  return failure();
447  if (failed(checkVarAndAccVar(*this)))
448  return failure();
449  return success();
450 }
451 
452 bool acc::CreateOp::isCreateZero() {
453  // The zero modifier is encoded in the data clause.
454  return getDataClause() == acc::DataClause::acc_create_zero ||
455  getDataClause() == acc::DataClause::acc_copyout_zero;
456 }
457 
458 //===----------------------------------------------------------------------===//
459 // NoCreateOp
460 //===----------------------------------------------------------------------===//
461 LogicalResult acc::NoCreateOp::verify() {
462  if (getDataClause() != acc::DataClause::acc_no_create)
463  return emitError("data clause associated with no_create operation must "
464  "match its intent");
465  if (failed(checkVarAndVarType(*this)))
466  return failure();
467  if (failed(checkVarAndAccVar(*this)))
468  return failure();
469  return success();
470 }
471 
472 //===----------------------------------------------------------------------===//
473 // AttachOp
474 //===----------------------------------------------------------------------===//
475 LogicalResult acc::AttachOp::verify() {
476  if (getDataClause() != acc::DataClause::acc_attach)
477  return emitError(
478  "data clause associated with attach operation must match its intent");
479  if (failed(checkVarAndVarType(*this)))
480  return failure();
481  if (failed(checkVarAndAccVar(*this)))
482  return failure();
483  return success();
484 }
485 
486 //===----------------------------------------------------------------------===//
487 // DeclareDeviceResidentOp
488 //===----------------------------------------------------------------------===//
489 
490 LogicalResult acc::DeclareDeviceResidentOp::verify() {
491  if (getDataClause() != acc::DataClause::acc_declare_device_resident)
492  return emitError("data clause associated with device_resident operation "
493  "must match its intent");
494  if (failed(checkVarAndVarType(*this)))
495  return failure();
496  if (failed(checkVarAndAccVar(*this)))
497  return failure();
498  return success();
499 }
500 
501 //===----------------------------------------------------------------------===//
502 // DeclareLinkOp
503 //===----------------------------------------------------------------------===//
504 
505 LogicalResult acc::DeclareLinkOp::verify() {
506  if (getDataClause() != acc::DataClause::acc_declare_link)
507  return emitError(
508  "data clause associated with link operation must match its intent");
509  if (failed(checkVarAndVarType(*this)))
510  return failure();
511  if (failed(checkVarAndAccVar(*this)))
512  return failure();
513  return success();
514 }
515 
516 //===----------------------------------------------------------------------===//
517 // CopyoutOp
518 //===----------------------------------------------------------------------===//
519 LogicalResult acc::CopyoutOp::verify() {
520  // Test for all clauses this operation can be decomposed from:
521  if (getDataClause() != acc::DataClause::acc_copyout &&
522  getDataClause() != acc::DataClause::acc_copyout_zero &&
523  getDataClause() != acc::DataClause::acc_copy &&
524  getDataClause() != acc::DataClause::acc_reduction)
525  return emitError(
526  "data clause associated with copyout operation must match its intent"
527  " or specify original clause this operation was decomposed from");
528  if (!getVar() || !getAccVar())
529  return emitError("must have both host and device pointers");
530  if (failed(checkVarAndVarType(*this)))
531  return failure();
532  if (failed(checkVarAndAccVar(*this)))
533  return failure();
534  return success();
535 }
536 
537 bool acc::CopyoutOp::isCopyoutZero() {
538  return getDataClause() == acc::DataClause::acc_copyout_zero;
539 }
540 
541 //===----------------------------------------------------------------------===//
542 // DeleteOp
543 //===----------------------------------------------------------------------===//
544 LogicalResult acc::DeleteOp::verify() {
545  // Test for all clauses this operation can be decomposed from:
546  if (getDataClause() != acc::DataClause::acc_delete &&
547  getDataClause() != acc::DataClause::acc_create &&
548  getDataClause() != acc::DataClause::acc_create_zero &&
549  getDataClause() != acc::DataClause::acc_copyin &&
550  getDataClause() != acc::DataClause::acc_copyin_readonly &&
551  getDataClause() != acc::DataClause::acc_present &&
552  getDataClause() != acc::DataClause::acc_declare_device_resident &&
553  getDataClause() != acc::DataClause::acc_declare_link)
554  return emitError(
555  "data clause associated with delete operation must match its intent"
556  " or specify original clause this operation was decomposed from");
557  if (!getAccVar())
558  return emitError("must have device pointer");
559  return success();
560 }
561 
562 //===----------------------------------------------------------------------===//
563 // DetachOp
564 //===----------------------------------------------------------------------===//
565 LogicalResult acc::DetachOp::verify() {
566  // Test for all clauses this operation can be decomposed from:
567  if (getDataClause() != acc::DataClause::acc_detach &&
568  getDataClause() != acc::DataClause::acc_attach)
569  return emitError(
570  "data clause associated with detach operation must match its intent"
571  " or specify original clause this operation was decomposed from");
572  if (!getAccVar())
573  return emitError("must have device pointer");
574  return success();
575 }
576 
577 //===----------------------------------------------------------------------===//
578 // UpdateHostOp
579 //===----------------------------------------------------------------------===//
580 LogicalResult acc::UpdateHostOp::verify() {
581  // Test for all clauses this operation can be decomposed from:
582  if (getDataClause() != acc::DataClause::acc_update_host &&
583  getDataClause() != acc::DataClause::acc_update_self)
584  return emitError(
585  "data clause associated with host operation must match its intent"
586  " or specify original clause this operation was decomposed from");
587  if (!getVar() || !getAccVar())
588  return emitError("must have both host and device pointers");
589  if (failed(checkVarAndVarType(*this)))
590  return failure();
591  if (failed(checkVarAndAccVar(*this)))
592  return failure();
593  return success();
594 }
595 
596 //===----------------------------------------------------------------------===//
597 // UpdateDeviceOp
598 //===----------------------------------------------------------------------===//
599 LogicalResult acc::UpdateDeviceOp::verify() {
600  // Test for all clauses this operation can be decomposed from:
601  if (getDataClause() != acc::DataClause::acc_update_device)
602  return emitError(
603  "data clause associated with device operation must match its intent"
604  " or specify original clause this operation was decomposed from");
605  if (failed(checkVarAndVarType(*this)))
606  return failure();
607  if (failed(checkVarAndAccVar(*this)))
608  return failure();
609  return success();
610 }
611 
612 //===----------------------------------------------------------------------===//
613 // UseDeviceOp
614 //===----------------------------------------------------------------------===//
615 LogicalResult acc::UseDeviceOp::verify() {
616  // Test for all clauses this operation can be decomposed from:
617  if (getDataClause() != acc::DataClause::acc_use_device)
618  return emitError(
619  "data clause associated with use_device operation must match its intent"
620  " or specify original clause this operation was decomposed from");
621  if (failed(checkVarAndVarType(*this)))
622  return failure();
623  if (failed(checkVarAndAccVar(*this)))
624  return failure();
625  return success();
626 }
627 
628 //===----------------------------------------------------------------------===//
629 // CacheOp
630 //===----------------------------------------------------------------------===//
631 LogicalResult acc::CacheOp::verify() {
632  // Test for all clauses this operation can be decomposed from:
633  if (getDataClause() != acc::DataClause::acc_cache &&
634  getDataClause() != acc::DataClause::acc_cache_readonly)
635  return emitError(
636  "data clause associated with cache operation must match its intent"
637  " or specify original clause this operation was decomposed from");
638  if (failed(checkVarAndVarType(*this)))
639  return failure();
640  if (failed(checkVarAndAccVar(*this)))
641  return failure();
642  return success();
643 }
644 
645 template <typename StructureOp>
646 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
647  unsigned nRegions = 1) {
648 
649  SmallVector<Region *, 2> regions;
650  for (unsigned i = 0; i < nRegions; ++i)
651  regions.push_back(state.addRegion());
652 
653  for (Region *region : regions)
654  if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
655  return failure();
656 
657  return success();
658 }
659 
660 static bool isComputeOperation(Operation *op) {
661  return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
662 }
663 
664 namespace {
665 /// Pattern to remove an operation without a region when its `ifCond` is a
666 /// constant false, and to remove the condition from the operation when the
667 /// `ifCond` is a constant true.
668 template <typename OpTy>
669 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
671 
672  LogicalResult matchAndRewrite(OpTy op,
673  PatternRewriter &rewriter) const override {
674  // Early return if there is no condition.
675  Value ifCond = op.getIfCond();
676  if (!ifCond)
677  return failure();
678 
679  IntegerAttr constAttr;
680  if (!matchPattern(ifCond, m_Constant(&constAttr)))
681  return failure();
682  if (constAttr.getInt())
683  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
684  else
685  rewriter.eraseOp(op);
686 
687  return success();
688  }
689 };
690 
691 /// Replaces the given op with the contents of the given single-block region,
692 /// using the operands of the block terminator to replace operation results.
693 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
694  Region &region, ValueRange blockArgs = {}) {
695  assert(llvm::hasSingleElement(region) && "expected single-region block");
696  Block *block = &region.front();
697  Operation *terminator = block->getTerminator();
698  ValueRange results = terminator->getOperands();
699  rewriter.inlineBlockBefore(block, op, blockArgs);
700  rewriter.replaceOp(op, results);
701  rewriter.eraseOp(terminator);
702 }
703 
704 /// Pattern to replace an operation with a region by the contents of that
705 /// region when its `ifCond` is a constant false, and to remove the condition
706 /// from the operation when the `ifCond` is a constant true.
707 template <typename OpTy>
708 struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
710 
711  LogicalResult matchAndRewrite(OpTy op,
712  PatternRewriter &rewriter) const override {
713  // Early return if there is no condition.
714  Value ifCond = op.getIfCond();
715  if (!ifCond)
716  return failure();
717 
718  IntegerAttr constAttr;
719  if (!matchPattern(ifCond, m_Constant(&constAttr)))
720  return failure();
721  if (constAttr.getInt())
722  rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
723  else
724  replaceOpWithRegion(rewriter, op, op.getRegion());
725 
726  return success();
727  }
728 };
729 
730 } // namespace
731 
732 //===----------------------------------------------------------------------===//
733 // PrivateRecipeOp
734 //===----------------------------------------------------------------------===//
735 
736 static LogicalResult verifyInitLikeSingleArgRegion(
737  Operation *op, Region &region, StringRef regionType, StringRef regionName,
738  Type type, bool verifyYield, bool optional = false) {
739  if (optional && region.empty())
740  return success();
741 
742  if (region.empty())
743  return op->emitOpError() << "expects non-empty " << regionName << " region";
744  Block &firstBlock = region.front();
745  if (firstBlock.getNumArguments() < 1 ||
746  firstBlock.getArgument(0).getType() != type)
747  return op->emitOpError() << "expects " << regionName
748  << " region first "
749  "argument of the "
750  << regionType << " type";
751 
752  if (verifyYield) {
753  for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
754  if (yieldOp.getOperands().size() != 1 ||
755  yieldOp.getOperands().getTypes()[0] != type)
756  return op->emitOpError() << "expects " << regionName
757  << " region to "
758  "yield a value of the "
759  << regionType << " type";
760  }
761  }
762  return success();
763 }
764 
765 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
766  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
767  "privatization", "init", getType(),
768  /*verifyYield=*/false)))
769  return failure();
771  *this, getDestroyRegion(), "privatization", "destroy", getType(),
772  /*verifyYield=*/false, /*optional=*/true)))
773  return failure();
774  return success();
775 }
776 
777 //===----------------------------------------------------------------------===//
778 // FirstprivateRecipeOp
779 //===----------------------------------------------------------------------===//
780 
781 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
782  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
783  "privatization", "init", getType(),
784  /*verifyYield=*/false)))
785  return failure();
786 
787  if (getCopyRegion().empty())
788  return emitOpError() << "expects non-empty copy region";
789 
790  Block &firstBlock = getCopyRegion().front();
791  if (firstBlock.getNumArguments() < 2 ||
792  firstBlock.getArgument(0).getType() != getType())
793  return emitOpError() << "expects copy region with two arguments of the "
794  "privatization type";
795 
796  if (getDestroyRegion().empty())
797  return success();
798 
799  if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
800  "privatization", "destroy",
801  getType(), /*verifyYield=*/false)))
802  return failure();
803 
804  return success();
805 }
806 
807 //===----------------------------------------------------------------------===//
808 // ReductionRecipeOp
809 //===----------------------------------------------------------------------===//
810 
811 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
812  if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
813  "init", getType(),
814  /*verifyYield=*/false)))
815  return failure();
816 
817  if (getCombinerRegion().empty())
818  return emitOpError() << "expects non-empty combiner region";
819 
820  Block &reductionBlock = getCombinerRegion().front();
821  if (reductionBlock.getNumArguments() < 2 ||
822  reductionBlock.getArgument(0).getType() != getType() ||
823  reductionBlock.getArgument(1).getType() != getType())
824  return emitOpError() << "expects combiner region with the first two "
825  << "arguments of the reduction type";
826 
827  for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
828  if (yieldOp.getOperands().size() != 1 ||
829  yieldOp.getOperands().getTypes()[0] != getType())
830  return emitOpError() << "expects combiner region to yield a value "
831  "of the reduction type";
832  }
833 
834  return success();
835 }
836 
837 //===----------------------------------------------------------------------===//
838 // Custom parser and printer verifier for private clause
839 //===----------------------------------------------------------------------===//
840 
/// Parses a comma-separated list of `attribute -> operand : type` entries.
///
/// For every entry the attribute, the operand and the type are appended to
/// `attributes`, `operands` and `types` respectively. On success the collected
/// attributes are stored into `symbols` as an ArrayAttr. Fails as soon as any
/// part of an entry does not parse.
///
/// NOTE(review): the declarations of `operands` and `attributes` are not
/// visible in this listing - confirm their types against the full source.
841 static ParseResult parseSymOperandList(
842  mlir::OpAsmParser &parser,
844  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
846  if (failed(parser.parseCommaSeparatedList([&]() {
// One element: <attribute> `->` <operand> `:` <type>.
847  if (parser.parseAttribute(attributes.emplace_back()) ||
848  parser.parseArrow() ||
849  parser.parseOperand(operands.emplace_back()) ||
850  parser.parseColonType(types.emplace_back()))
851  return failure();
852  return success();
853  })))
854  return failure();
// Copy the parsed attributes into a vector of generic Attribute so they can
// be wrapped in an ArrayAttr.
855  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
856  attributes.end());
857  symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
858  return success();
859 }
860 
862  mlir::OperandRange operands,
863  mlir::TypeRange types,
864  std::optional<mlir::ArrayAttr> attributes) {
// Prints every (attribute, operand) pair as `attr -> operand : operand-type`,
// separated by commas. The type is taken from the operand itself; `types` is
// not read here.
// NOTE(review): `attributes` is dereferenced without checking that it holds a
// value - confirm callers only print when the attribute is present.
865  llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
866  p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
867  << std::get<1>(it).getType();
868  });
869 }
870 
871 //===----------------------------------------------------------------------===//
872 // ParallelOp
873 //===----------------------------------------------------------------------===//
874 
875 /// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
876 template <typename Op>
877 static LogicalResult checkDataOperands(Op op,
878  const mlir::ValueRange &operands) {
879  for (mlir::Value operand : operands)
880  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
881  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
882  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
883  operand.getDefiningOp()))
884  return op.emitError(
885  "expect data entry/exit operation or acc.getdeviceptr "
886  "as defining op");
887  return success();
888 }
889 
/// Verifies that the symbol references in `attributes` line up with
/// `operands`.
///
/// - With operands present, `attributes` must be set and contain exactly one
///   entry per operand.
/// - Without operands, `attributes` must be absent.
/// - No operand may be listed twice, and every symbol must resolve (through
///   SymbolTable::lookupNearestSymbolFrom starting at `op`) to an operation of
///   type `Op`.
/// - When `checkOperandType` is true and the declaration reports a non-null
///   type, that type must equal the operand's type.
///
/// `operandName` and `symbolName` are only used to build the diagnostics.
890 template <typename Op>
891 static LogicalResult
892 checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
893  mlir::OperandRange operands, llvm::StringRef operandName,
894  llvm::StringRef symbolName, bool checkOperandType = true) {
895  if (!operands.empty()) {
896  if (!attributes || attributes->size() != operands.size())
897  return op->emitOpError()
898  << "expected as many " << symbolName << " symbol reference as "
899  << operandName << " operands";
900  } else {
// No operands: a symbol list on its own is an error, otherwise nothing is
// left to check.
901  if (attributes)
902  return op->emitOpError()
903  << "unexpected " << symbolName << " symbol reference";
904  return success();
905  }
906 
// NOTE(review): the declaration of `set` is not visible in this listing; it
// is used below to reject operands that appear more than once.
908  for (auto args : llvm::zip(operands, *attributes)) {
909  mlir::Value operand = std::get<0>(args);
910 
911  if (!set.insert(operand).second)
912  return op->emitOpError()
913  << operandName << " operand appears more than once";
914 
915  mlir::Type varType = operand.getType();
916  auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
917  auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
918  if (!decl)
919  return op->emitOpError()
920  << "expected symbol reference " << symbolRef << " to point to a "
921  << operandName << " declaration";
922 
// A declaration with a null type is accepted without comparison.
923  if (checkOperandType && decl.getType() && decl.getType() != varType)
924  return op->emitOpError() << "expected " << operandName << " (" << varType
925  << ") to be the same type as " << operandName
926  << " declaration (" << decl.getType() << ")";
927  }
928 
929  return success();
930 }
931 
932 unsigned ParallelOp::getNumDataOperands() {
933  return getReductionOperands().size() + getPrivateOperands().size() +
934  getFirstprivateOperands().size() + getDataClauseOperands().size();
935 }
936 
937 Value ParallelOp::getDataOperand(unsigned i) {
938  unsigned numOptional = getAsyncOperands().size();
939  numOptional += getNumGangs().size();
940  numOptional += getNumWorkers().size();
941  numOptional += getVectorLength().size();
942  numOptional += getIfCond() ? 1 : 0;
943  numOptional += getSelfCond() ? 1 : 0;
944  return getOperand(getWaitOperands().size() + numOptional + i);
945 }
946 
947 template <typename Op>
948 static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
949  ArrayAttr deviceTypes,
950  llvm::StringRef keyword) {
951  if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
952  return op.emitOpError() << keyword << " operands count must match "
953  << keyword << " device_type count";
954  return success();
955 }
956 
/// Checks a segmented operand list against its segment and device_type
/// attributes.
///
/// The entries of `segments` (when present) are summed and counted, and an
/// error is emitted on `op` if:
/// - `maxInSegment` is non-zero and any segment is larger than it,
/// - the summed segment sizes differ from `operands.size()`, or operands are
///   present while `deviceTypes` is null,
/// - `deviceTypes` is present and its size differs from the segment count.
/// `keyword` is only used in the diagnostics.
957 template <typename Op>
959  Op op, OperandRange operands, DenseI32ArrayAttr segments,
960  ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
961  std::size_t numOperandsInSegments = 0;
962  std::size_t nbOfSegments = 0;
963 
964  if (segments) {
965  for (auto segCount : segments.asArrayRef()) {
// maxInSegment == 0 means "no per-segment limit".
966  if (maxInSegment != 0 && segCount > maxInSegment)
967  return op.emitOpError() << keyword << " expects a maximum of "
968  << maxInSegment << " values per segment";
969  numOperandsInSegments += segCount;
970  ++nbOfSegments;
971  }
972  }
973 
974  if ((numOperandsInSegments != operands.size()) ||
975  (!deviceTypes && !operands.empty()))
976  return op.emitOpError()
977  << keyword << " operand count does not match count in segments";
978  if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
979  return op.emitOpError()
980  << keyword << " segment count does not match device_type count";
981  return success();
982 }
983 
/// Verifier for acc.parallel. Runs the checks below in order and fails on the
/// first one that reports an error.
984 LogicalResult acc::ParallelOp::verify() {
// Private, firstprivate and reduction operands must each be paired with a
// recipe symbol of the matching kind; operand types are not compared
// (checkOperandType is passed as false).
985  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
986  *this, getPrivatizations(), getPrivateOperands(), "private",
987  "privatizations", /*checkOperandType=*/false)))
988  return failure();
989  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
990  *this, getFirstprivatizations(), getFirstprivateOperands(),
991  "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
992  return failure();
993  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
994  *this, getReductionRecipes(), getReductionOperands(), "reduction",
995  "reductions", false)))
996  return failure();
997 
// NOTE(review): the opening line of the next two checks is not visible in
// this listing. The visible arguments are the num_gangs operands with their
// segment/device_type attributes (and a trailing 3), then the wait operands
// with theirs.
999  *this, getNumGangs(), getNumGangsSegmentsAttr(),
1000  getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1001  return failure();
1002 
1004  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1005  getWaitOperandsDeviceTypeAttr(), "wait")))
1006  return failure();
1007 
// num_workers, vector_length and async are checked against their device_type
// attributes.
1008  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1009  getNumWorkersDeviceTypeAttr(),
1010  "num_workers")))
1011  return failure();
1012 
1013  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1014  getVectorLengthDeviceTypeAttr(),
1015  "vector_length")))
1016  return failure();
1017 
1018  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1019  getAsyncOperandsDeviceTypeAttr(),
1020  "async")))
1021  return failure();
1022 
1023  if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
1024  return failure();
1025 
// The result of the data clause operand check is the verifier's result.
1026  return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
1027 }
1028 
/// Returns the element of `range` at the position findSegment reports for
/// `deviceType` in `arrayAttr`. Returns a null Value when `arrayAttr` is
/// absent or findSegment yields no position.
/// NOTE(review): the parameter line declaring `range` is not visible in this
/// listing.
1029 static mlir::Value
1030 getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
1032  mlir::acc::DeviceType deviceType) {
1033  if (!arrayAttr)
1034  return {};
1035  if (auto pos = findSegment(*arrayAttr, deviceType))
1036  return range[*pos];
1037  return {};
1038 }
1039 
1040 bool acc::ParallelOp::hasAsyncOnly() {
1041  return hasAsyncOnly(mlir::acc::DeviceType::None);
1042 }
1043 
1044 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1045  return hasDeviceType(getAsyncOnly(), deviceType);
1046 }
1047 
// Convenience overload: forwards to the device-type overload with
// DeviceType::None.
1048 mlir::Value acc::ParallelOp::getAsyncValue() {
1049  return getAsyncValue(mlir::acc::DeviceType::None);
1050 }
1051 
// Device-type overload.
// NOTE(review): the start of the return statement is not visible in this
// listing; the visible tail passes getAsyncOperands() and `deviceType`.
1052 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1054  getAsyncOperands(), deviceType);
1055 }
1056 
// Convenience overload: forwards to the device-type overload with
// DeviceType::None.
1057 mlir::Value acc::ParallelOp::getNumWorkersValue() {
1058  return getNumWorkersValue(mlir::acc::DeviceType::None);
1059 }
1060 
// Looks up the num_workers operand for `deviceType` through
// getValueInDeviceTypeSegment (null Value when there is none).
// NOTE(review): the return-type line of this definition is not visible here.
1062 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1063  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1064  deviceType);
1065 }
1066 
// Convenience overload: forwards to the device-type overload with
// DeviceType::None.
1067 mlir::Value acc::ParallelOp::getVectorLengthValue() {
1068  return getVectorLengthValue(mlir::acc::DeviceType::None);
1069 }
1070 
// Looks up the vector_length operand for `deviceType` through
// getValueInDeviceTypeSegment (null Value when there is none).
// NOTE(review): the return-type line of this definition is not visible here.
1072 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1073  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1074  getVectorLength(), deviceType);
1075 }
1076 
// Convenience overload: forwards to the device-type overload with
// DeviceType::None.
1077 mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
1078  return getNumGangsValues(mlir::acc::DeviceType::None);
1079 }
1080 
// Returns the num_gangs operands selected by getValuesFromSegments for
// `deviceType`, using the num_gangs device_type and segment attributes.
// NOTE(review): the return-type line of this definition is not visible here.
1082 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1083  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1084  getNumGangsSegments(), deviceType);
1085 }
1086 
1087 bool acc::ParallelOp::hasWaitOnly() {
1088  return hasWaitOnly(mlir::acc::DeviceType::None);
1089 }
1090 
1091 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1092  return hasDeviceType(getWaitOnly(), deviceType);
1093 }
1094 
// Convenience overload: forwards to the device-type overload with
// DeviceType::None.
1095 mlir::Operation::operand_range ParallelOp::getWaitValues() {
1096  return getWaitValues(mlir::acc::DeviceType::None);
1097 }
1098 
// Device-type overload.
// NOTE(review): the return-type line and the head of the return statement
// are not visible in this listing; the visible arguments are the wait
// device_type, operand, segment and devnum attributes plus `deviceType`.
1100 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1102  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1103  getHasWaitDevnum(), deviceType);
1104 }
1105 
1106 mlir::Value ParallelOp::getWaitDevnum() {
1107  return getWaitDevnum(mlir::acc::DeviceType::None);
1108 }
1109 
1110 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1111  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1112  getWaitOperandsSegments(), getHasWaitDevnum(),
1113  deviceType);
1114 }
1115 
/// Convenience builder taking only operand ranges and the two optional
/// conditions. It forwards to the builder overload that also takes the
/// device_type, segment, "only", recipe and attribute arguments, passing
/// nullptr for every one of them (see the inline argument-name comments).
1116 void ParallelOp::build(mlir::OpBuilder &odsBuilder,
1117  mlir::OperationState &odsState,
1118  mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
1119  mlir::ValueRange vectorLength,
1120  mlir::ValueRange asyncOperands,
1121  mlir::ValueRange waitOperands, mlir::Value ifCond,
1122  mlir::Value selfCond, mlir::ValueRange reductionOperands,
1123  mlir::ValueRange gangPrivateOperands,
1124  mlir::ValueRange gangFirstPrivateOperands,
1125  mlir::ValueRange dataClauseOperands) {
1126 
// The argument order below follows the forwarded-to overload, not this
// function's parameter order (async and wait come first there).
1127  ParallelOp::build(
1128  odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
1129  /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
1130  /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
1131  /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
1132  /*numGangsDeviceType=*/nullptr, numWorkers,
1133  /*numWorkersDeviceType=*/nullptr, vectorLength,
1134  /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
1135  /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
1136  gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
1137  /*firstprivatizations=*/nullptr, dataClauseOperands,
1138  /*defaultAttr=*/nullptr, /*combined=*/nullptr);
1139 }
1140 
/// Parses one or more comma-separated groups of the form
/// `{ operand : type, ... }` each optionally followed by `[ attribute ]`.
///
/// Operands and types are appended to `operands` / `types`, the number of
/// operands of each group is recorded in `seg`, and one attribute per group is
/// collected in `attributes`. On success `deviceTypes` and `segments` are set
/// from the collected attributes and segment sizes.
///
/// NOTE(review): several lines of this definition (the `operands` parameter,
/// the declarations of `attributes` and `seg`, the lambda header and the
/// arguments of the default attribute) are not visible in this listing.
1141 static ParseResult parseNumGangs(
1142  mlir::OpAsmParser &parser,
1144  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1145  mlir::DenseI32ArrayAttr &segments) {
1148 
1149  do {
1150  if (failed(parser.parseLBrace()))
1151  return failure();
1152 
// Remember the operand count before this group so its size can be derived
// afterwards.
1153  int32_t crtOperandsSize = operands.size();
1154  if (failed(parser.parseCommaSeparatedList(
1156  if (parser.parseOperand(operands.emplace_back()) ||
1157  parser.parseColonType(types.emplace_back()))
1158  return failure();
1159  return success();
1160  })))
1161  return failure();
1162  seg.push_back(operands.size() - crtOperandsSize);
1163 
1164  if (failed(parser.parseRBrace()))
1165  return failure();
1166 
// Optional `[attr]` after the group; without it a DeviceTypeAttr is
// synthesized so every group has an entry.
1167  if (succeeded(parser.parseOptionalLSquare())) {
1168  if (parser.parseAttribute(attributes.emplace_back()) ||
1169  parser.parseRSquare())
1170  return failure();
1171  } else {
1172  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1174  }
1175  } while (succeeded(parser.parseOptionalComma()));
1176 
1177  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1178  attributes.end());
1179  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1180  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1181 
1182  return success();
1183 }
1184 
// Prints ` [attr]` unless the attribute's device type is DeviceType::None.
// NOTE(review): the dyn_cast result is used without a null check - confirm
// `attr` is always a DeviceTypeAttr here.
1186  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1187  if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
1188  p << " [" << attr << "]";
1189 }
1190 
1192  mlir::OperandRange operands, mlir::TypeRange types,
1193  std::optional<mlir::ArrayAttr> deviceTypes,
1194  std::optional<mlir::DenseI32ArrayAttr> segments) {
// Prints one `{operand : type, ...}` group per device_type entry, followed by
// that entry's device type (via printSingleDeviceType). `opIdx` walks
// `operands` across groups; each group consumes as many operands as its entry
// in `segments` says.
// NOTE(review): `deviceTypes` and `segments` are dereferenced without
// checking that they hold a value.
1195  unsigned opIdx = 0;
1196  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1197  p << "{";
1198  llvm::interleaveComma(
1199  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1200  p << operands[opIdx] << " : " << operands[opIdx].getType();
1201  ++opIdx;
1202  });
1203  p << "}";
1204  printSingleDeviceType(p, it.value());
1205  });
1206 }
1207 
1209  mlir::OpAsmParser &parser,
1211  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1212  mlir::DenseI32ArrayAttr &segments) {
// Parses one or more comma-separated `{ operand : type, ... }` groups, each
// optionally followed by `[ attribute ]`. The size of every group is recorded
// in `seg`; on success `deviceTypes` and `segments` are set from the
// collected attributes and sizes.
// NOTE(review): the function name line, the `operands` parameter, the local
// declarations of `attributes`/`seg`, the lambda header and the arguments of
// the default attribute are not visible in this listing.
1215 
1216  do {
1217  if (failed(parser.parseLBrace()))
1218  return failure();
1219 
// Operand count before this group; the difference afterwards is the group
// size.
1220  int32_t crtOperandsSize = operands.size();
1221 
1222  if (failed(parser.parseCommaSeparatedList(
1224  if (parser.parseOperand(operands.emplace_back()) ||
1225  parser.parseColonType(types.emplace_back()))
1226  return failure();
1227  return success();
1228  })))
1229  return failure();
1230 
1231  seg.push_back(operands.size() - crtOperandsSize);
1232 
1233  if (failed(parser.parseRBrace()))
1234  return failure();
1235 
// Optional `[attr]`; otherwise a DeviceTypeAttr is synthesized.
1236  if (succeeded(parser.parseOptionalLSquare())) {
1237  if (parser.parseAttribute(attributes.emplace_back()) ||
1238  parser.parseRSquare())
1239  return failure();
1240  } else {
1241  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1243  }
1244  } while (succeeded(parser.parseOptionalComma()));
1245 
1246  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1247  attributes.end());
1248  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1249  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1250 
1251  return success();
1252 }
1253 
1256  mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1257  std::optional<mlir::DenseI32ArrayAttr> segments) {
// Prints one `{operand : type, ...}` group per device_type entry, followed by
// that entry's device type. `opIdx` advances through `operands` across
// groups, each group consuming the count given by its `segments` entry.
// NOTE(review): the leading lines of this signature are not visible in this
// listing, and `deviceTypes` / `segments` are dereferenced without checking
// that they hold a value.
1258  unsigned opIdx = 0;
1259  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1260  p << "{";
1261  llvm::interleaveComma(
1262  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1263  p << operands[opIdx] << " : " << operands[opIdx].getType();
1264  ++opIdx;
1265  });
1266  p << "}";
1267  printSingleDeviceType(p, it.value());
1268  });
1269 }
1270 
/// Parses the wait clause.
///
/// Accepted forms, as far as the visible code shows:
/// - nothing (no `(` follows): keyword-only form; a single DeviceTypeAttr is
///   stored into `keywordOnly` and parsing succeeds;
/// - `(` [ `[` attr, ... `]` `,` ] group { `,` group } `)` where each group is
///   `{` [ `devnum` `:` ] operand `:` type, ... `}` optionally followed by
///   `[` attr `]`.
///
/// On the parenthesized form `deviceTypes`, `keywordOnly`, `segments` and
/// `hasDevNum` are all set; there is one `hasDevNum` / `segments` /
/// `deviceTypes` entry per group.
///
/// NOTE(review): the `operands` parameter, the declaration of `seg`, the
/// lambda header and the arguments of the default DeviceTypeAttr are on lines
/// that are not visible in this listing.
1271 static ParseResult parseWaitClause(
1272  mlir::OpAsmParser &parser,
1274  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1275  mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
1276  mlir::ArrayAttr &keywordOnly) {
1277  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
1279 
1280  bool needCommaBeforeOperands = false;
1281 
1282  // Keyword only
1283  if (failed(parser.parseOptionalLParen())) {
1284  keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1286  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1287  return success();
1288  }
1289 
1290  // Parse keyword only attributes
1291  if (succeeded(parser.parseOptionalLSquare())) {
1292  if (failed(parser.parseCommaSeparatedList([&]() {
1293  if (parser.parseAttribute(keywordAttrs.emplace_back()))
1294  return failure();
1295  return success();
1296  })))
1297  return failure();
1298  if (parser.parseRSquare())
1299  return failure();
1300  needCommaBeforeOperands = true;
1301  }
1302 
// A keyword-only attribute list must be followed by a comma before the
// operand groups.
1303  if (needCommaBeforeOperands && failed(parser.parseComma()))
1304  return failure();
1305 
1306  do {
1307  if (failed(parser.parseLBrace()))
1308  return failure();
1309 
1310  int32_t crtOperandsSize = operands.size();
1311 
// Record per group whether it starts with `devnum:`.
1312  if (succeeded(parser.parseOptionalKeyword("devnum"))) {
1313  if (failed(parser.parseColon()))
1314  return failure();
1315  devnum.push_back(BoolAttr::get(parser.getContext(), true));
1316  } else {
1317  devnum.push_back(BoolAttr::get(parser.getContext(), false));
1318  }
1319 
1320  if (failed(parser.parseCommaSeparatedList(
1322  if (parser.parseOperand(operands.emplace_back()) ||
1323  parser.parseColonType(types.emplace_back()))
1324  return failure();
1325  return success();
1326  })))
1327  return failure();
1328 
// Number of operands parsed for this group.
1329  seg.push_back(operands.size() - crtOperandsSize);
1330 
1331  if (failed(parser.parseRBrace()))
1332  return failure();
1333 
1334  if (succeeded(parser.parseOptionalLSquare())) {
1335  if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
1336  parser.parseRSquare())
1337  return failure();
1338  } else {
1339  deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1341  }
1342  } while (succeeded(parser.parseOptionalComma()));
1343 
1344  if (failed(parser.parseRParen()))
1345  return failure();
1346 
1347  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
1348  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1349  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1350  hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
1351 
1352  return success();
1353 }
1354 
1355 static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
1356  if (!hasDeviceTypeValues(attrs))
1357  return false;
1358  if (attrs->size() != 1)
1359  return false;
1360  if (auto deviceTypeAttr =
1361  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1362  return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
1363  return false;
1364 }
1365 
1367  mlir::OperandRange operands, mlir::TypeRange types,
1368  std::optional<mlir::ArrayAttr> deviceTypes,
1369  std::optional<mlir::DenseI32ArrayAttr> segments,
1370  std::optional<mlir::ArrayAttr> hasDevNum,
1371  std::optional<mlir::ArrayAttr> keywordOnly) {
1372 
// Nothing is printed when there are no operands and the keyword-only list
// holds just DeviceType::None.
1373  if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
1374  return;
1375 
1376  p << "(";
1377 
// Keyword-only device types first, separated from the operand groups by a
// comma when both are present.
1378  printDeviceTypes(p, keywordOnly);
1379  if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
1380  p << ", ";
1381 
// One `{[devnum: ]operand : type, ...}` group per device_type entry; `opIdx`
// advances through `operands` across groups.
// NOTE(review): `deviceTypes`, `hasDevNum` and `segments` are dereferenced
// without checking that they hold a value.
1382  unsigned opIdx = 0;
1383  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1384  p << "{";
1385  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1386  if (boolAttr && boolAttr.getValue())
1387  p << "devnum: ";
1388  llvm::interleaveComma(
1389  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1390  p << operands[opIdx] << " : " << operands[opIdx].getType();
1391  ++opIdx;
1392  });
1393  p << "}";
1394  printSingleDeviceType(p, it.value());
1395  });
1396 
1397  p << ")";
1398 }
1399 
/// Parses a comma-separated list of `operand : type` entries, each optionally
/// followed by `[ attribute ]`. An entry without the bracketed attribute gets
/// a DeviceTypeAttr with DeviceType::None. On success `deviceTypes` holds one
/// attribute per operand.
///
/// NOTE(review): the `operands` parameter and the declaration of
/// `attributes` are on lines that are not visible in this listing.
1400 static ParseResult parseDeviceTypeOperands(
1401  mlir::OpAsmParser &parser,
1403  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
1405  if (failed(parser.parseCommaSeparatedList([&]() {
1406  if (parser.parseOperand(operands.emplace_back()) ||
1407  parser.parseColonType(types.emplace_back()))
1408  return failure();
1409  if (succeeded(parser.parseOptionalLSquare())) {
1410  if (parser.parseAttribute(attributes.emplace_back()) ||
1411  parser.parseRSquare())
1412  return failure();
1413  } else {
// No explicit device type: default to DeviceType::None.
1414  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1415  parser.getContext(), mlir::acc::DeviceType::None));
1416  }
1417  return success();
1418  })))
1419  return failure();
1420  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1421  attributes.end());
1422  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1423  return success();
1424 }
1425 
// Prints each operand as `operand : type` followed by its device type (via
// printSingleDeviceType), comma separated. Prints nothing when `deviceTypes`
// has no values. `types` is not read; the type comes from the operand.
// NOTE(review): the line carrying the function name and first parameters is
// not visible in this listing.
1426 static void
1428  mlir::OperandRange operands, mlir::TypeRange types,
1429  std::optional<mlir::ArrayAttr> deviceTypes) {
1430  if (!hasDeviceTypeValues(deviceTypes))
1431  return;
1432  llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
1433  p << std::get<1>(it) << " : " << std::get<1>(it).getType();
1434  printSingleDeviceType(p, std::get<0>(it));
1435  });
1436 }
1437 
1439  mlir::OpAsmParser &parser,
1441  llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1442  mlir::ArrayAttr &keywordOnlyDeviceType) {
// Parses either the bare keyword form (no `(` follows) or
// `(` [ `[` attr, ... `]` `,` ] operand `:` type [ `[` attr `]` ], ... `)`.
// In the bare form only `keywordOnlyDeviceType` is set. In the parenthesized
// form `deviceTypes` receives one attribute per operand (DeviceType::None
// when no `[attr]` follows the operand).
// NOTE(review): the function name line, the `operands` parameter, the
// declaration of `attributes` and the arguments of the keyword-only default
// attribute are not visible in this listing. In the parenthesized form the
// visible code does not assign `keywordOnlyDeviceType` - confirm against the
// full source.
1443 
1444  llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
1445  bool needCommaBeforeOperands = false;
1446 
1447  if (failed(parser.parseOptionalLParen())) {
1448  // Keyword only
1449  keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1451  keywordOnlyDeviceType =
1452  ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
1453  return success();
1454  }
1455 
1456  // Parse keyword only attributes
1457  if (succeeded(parser.parseOptionalLSquare())) {
1458  // Parse keyword only attributes
1459  if (failed(parser.parseCommaSeparatedList([&]() {
1460  if (parser.parseAttribute(
1461  keywordOnlyDeviceTypeAttributes.emplace_back()))
1462  return failure();
1463  return success();
1464  })))
1465  return failure();
1466  if (parser.parseRSquare())
1467  return failure();
1468  needCommaBeforeOperands = true;
1469  }
1470 
// A keyword-only attribute list must be followed by a comma before the
// operands.
1471  if (needCommaBeforeOperands && failed(parser.parseComma()))
1472  return failure();
1473 
1475  if (failed(parser.parseCommaSeparatedList([&]() {
1476  if (parser.parseOperand(operands.emplace_back()) ||
1477  parser.parseColonType(types.emplace_back()))
1478  return failure();
1479  if (succeeded(parser.parseOptionalLSquare())) {
1480  if (parser.parseAttribute(attributes.emplace_back()) ||
1481  parser.parseRSquare())
1482  return failure();
1483  } else {
1484  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1485  parser.getContext(), mlir::acc::DeviceType::None));
1486  }
1487  return success();
1488  })))
1489  return failure();
1490 
1491  if (failed(parser.parseRParen()))
1492  return failure();
1493 
1494  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1495  attributes.end());
1496  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1497  return success();
1498 }
1499 
1502  mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1503  std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1504 
// Nothing is printed when there are no operands and the keyword-only list
// holds just DeviceType::None.
1505  if (operands.begin() == operands.end() &&
1506  hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
1507  return;
1508  }
1509 
// Otherwise: `(` keyword-only device types [`, `] operands-with-device-types
// `)`. The separator is printed only when both lists have values.
1510  p << "(";
1511  printDeviceTypes(p, keywordOnlyDeviceTypes);
1512  if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
1513  hasDeviceTypeValues(deviceTypes))
1514  p << ", ";
1515  printDeviceTypeOperands(p, op, operands, types, deviceTypes);
1516  p << ")";
1517 }
1518 
// Parses the optional `combined(` kernels | parallel | serial `)` syntax.
// When the `combined` keyword is absent, nothing is consumed and parsing
// succeeds. Any other name inside the parentheses is reported as
// "expected compute construct name".
// NOTE(review): the function name line and the statement heads that assign
// `attr` in each branch are not visible in this listing; only their
// arguments (context and CombinedConstructsType value) are.
1519 static ParseResult
1521  mlir::acc::CombinedConstructsTypeAttr &attr) {
1522  if (succeeded(parser.parseOptionalKeyword("combined"))) {
1523  if (parser.parseLParen())
1524  return failure();
1525  if (succeeded(parser.parseOptionalKeyword("kernels"))) {
1527  parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1528  } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
1530  parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1531  } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
1533  parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1534  } else {
1535  parser.emitError(parser.getCurrentLocation(),
1536  "expected compute construct name");
1537  return failure();
1538  }
1539  if (parser.parseRParen())
1540  return failure();
1541  }
1542  return success();
1543 }
1544 
// Prints `combined(kernels)`, `combined(parallel)` or `combined(serial)`
// according to the attribute value; prints nothing when `attr` is null.
// NOTE(review): the line carrying the function name and first parameters is
// not visible in this listing.
1545 static void
1547  mlir::acc::CombinedConstructsTypeAttr attr) {
1548  if (attr) {
1549  switch (attr.getValue()) {
1550  case mlir::acc::CombinedConstructsType::KernelsLoop:
1551  p << "combined(kernels)";
1552  break;
1553  case mlir::acc::CombinedConstructsType::ParallelLoop:
1554  p << "combined(parallel)";
1555  break;
1556  case mlir::acc::CombinedConstructsType::SerialLoop:
1557  p << "combined(serial)";
1558  break;
1559  };
1560  }
1561 }
1562 
1563 //===----------------------------------------------------------------------===//
1564 // SerialOp
1565 //===----------------------------------------------------------------------===//
1566 
1567 unsigned SerialOp::getNumDataOperands() {
1568  return getReductionOperands().size() + getPrivateOperands().size() +
1569  getFirstprivateOperands().size() + getDataClauseOperands().size();
1570 }
1571 
1572 Value SerialOp::getDataOperand(unsigned i) {
1573  unsigned numOptional = getAsyncOperands().size();
1574  numOptional += getIfCond() ? 1 : 0;
1575  numOptional += getSelfCond() ? 1 : 0;
1576  return getOperand(getWaitOperands().size() + numOptional + i);
1577 }
1578 
1579 bool acc::SerialOp::hasAsyncOnly() {
1580  return hasAsyncOnly(mlir::acc::DeviceType::None);
1581 }
1582 
1583 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1584  return hasDeviceType(getAsyncOnly(), deviceType);
1585 }
1586 
// Convenience overload: forwards to the device-type overload with
// DeviceType::None.
1587 mlir::Value acc::SerialOp::getAsyncValue() {
1588  return getAsyncValue(mlir::acc::DeviceType::None);
1589 }
1590 
// Device-type overload.
// NOTE(review): the start of the return statement is not visible in this
// listing; the visible tail passes getAsyncOperands() and `deviceType`.
1591 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1593  getAsyncOperands(), deviceType);
1594 }
1595 
1596 bool acc::SerialOp::hasWaitOnly() {
1597  return hasWaitOnly(mlir::acc::DeviceType::None);
1598 }
1599 
1600 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1601  return hasDeviceType(getWaitOnly(), deviceType);
1602 }
1603 
// Convenience overload: forwards to the device-type overload with
// DeviceType::None.
1604 mlir::Operation::operand_range SerialOp::getWaitValues() {
1605  return getWaitValues(mlir::acc::DeviceType::None);
1606 }
1607 
// Device-type overload.
// NOTE(review): the return-type line and the head of the return statement
// are not visible in this listing; the visible arguments are the wait
// device_type, operand, segment and devnum attributes plus `deviceType`.
1609 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1611  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1612  getHasWaitDevnum(), deviceType);
1613 }
1614 
1615 mlir::Value SerialOp::getWaitDevnum() {
1616  return getWaitDevnum(mlir::acc::DeviceType::None);
1617 }
1618 
1619 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1620  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1621  getWaitOperandsSegments(), getHasWaitDevnum(),
1622  deviceType);
1623 }
1624 
/// Verifier for acc.serial. Runs the checks below in order and fails on the
/// first one that reports an error.
1625 LogicalResult acc::SerialOp::verify() {
// Private, firstprivate and reduction operands must each be paired with a
// recipe symbol of the matching kind; operand types are not compared
// (checkOperandType is passed as false).
1626  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1627  *this, getPrivatizations(), getPrivateOperands(), "private",
1628  "privatizations", /*checkOperandType=*/false)))
1629  return failure();
1630  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1631  *this, getFirstprivatizations(), getFirstprivateOperands(),
1632  "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
1633  return failure();
1634  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1635  *this, getReductionRecipes(), getReductionOperands(), "reduction",
1636  "reductions", false)))
1637  return failure();
1638 
// NOTE(review): the opening line of the next check is not visible in this
// listing; its visible arguments are the wait operands with their segment
// and device_type attributes.
1640  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1641  getWaitOperandsDeviceTypeAttr(), "wait")))
1642  return failure();
1643 
1644  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1645  getAsyncOperandsDeviceTypeAttr(),
1646  "async")))
1647  return failure();
1648 
1649  if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
1650  return failure();
1651 
// The result of the data clause operand check is the verifier's result.
1652  return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
1653 }
1654 
1655 //===----------------------------------------------------------------------===//
1656 // KernelsOp
1657 //===----------------------------------------------------------------------===//
1658 
1659 unsigned KernelsOp::getNumDataOperands() {
1660  return getDataClauseOperands().size();
1661 }
1662 
1663 Value KernelsOp::getDataOperand(unsigned i) {
1664  unsigned numOptional = getAsyncOperands().size();
1665  numOptional += getWaitOperands().size();
1666  numOptional += getNumGangs().size();
1667  numOptional += getNumWorkers().size();
1668  numOptional += getVectorLength().size();
1669  numOptional += getIfCond() ? 1 : 0;
1670  numOptional += getSelfCond() ? 1 : 0;
1671  return getOperand(numOptional + i);
1672 }
1673 
1674 bool acc::KernelsOp::hasAsyncOnly() {
1675  return hasAsyncOnly(mlir::acc::DeviceType::None);
1676 }
1677 
1678 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1679  return hasDeviceType(getAsyncOnly(), deviceType);
1680 }
1681 
// Convenience overload: forwards to the device-type overload with
// DeviceType::None.
1682 mlir::Value acc::KernelsOp::getAsyncValue() {
1683  return getAsyncValue(mlir::acc::DeviceType::None);
1684 }
1685 
// Device-type overload.
// NOTE(review): the start of the return statement is not visible in this
// listing; the visible tail passes getAsyncOperands() and `deviceType`.
1686 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1688  getAsyncOperands(), deviceType);
1689 }
1690 
// Convenience overload: forwards to the device-type overload with
// DeviceType::None.
1691 mlir::Value acc::KernelsOp::getNumWorkersValue() {
1692  return getNumWorkersValue(mlir::acc::DeviceType::None);
1693 }
1694 
// Looks up the num_workers operand for `deviceType` through
// getValueInDeviceTypeSegment (null Value when there is none).
// NOTE(review): the return-type line of this definition is not visible here.
1696 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1697  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1698  deviceType);
1699 }
1700 
// Convenience overload: forwards to the device-type overload with
// DeviceType::None.
1701 mlir::Value acc::KernelsOp::getVectorLengthValue() {
1702  return getVectorLengthValue(mlir::acc::DeviceType::None);
1703 }
1704 
// Looks up the vector_length operand for `deviceType` through
// getValueInDeviceTypeSegment (null Value when there is none).
// NOTE(review): the return-type line of this definition is not visible here.
1706 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1707  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1708  getVectorLength(), deviceType);
1709 }
1710 
// Convenience overload: forwards to the device-type overload with
// DeviceType::None.
1711 mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
1712  return getNumGangsValues(mlir::acc::DeviceType::None);
1713 }
1714 
// Returns the num_gangs operands selected by getValuesFromSegments for
// `deviceType`, using the num_gangs device_type and segment attributes.
// NOTE(review): the return-type line of this definition is not visible here.
1716 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1717  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1718  getNumGangsSegments(), deviceType);
1719 }
1720 
1721 bool acc::KernelsOp::hasWaitOnly() {
1722  return hasWaitOnly(mlir::acc::DeviceType::None);
1723 }
1724 
1725 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1726  return hasDeviceType(getWaitOnly(), deviceType);
1727 }
1728 
// Convenience overload: forwards to the device-type overload with
// DeviceType::None.
1729 mlir::Operation::operand_range KernelsOp::getWaitValues() {
1730  return getWaitValues(mlir::acc::DeviceType::None);
1731 }
1732 
// Device-type overload.
// NOTE(review): the return-type line and the head of the return statement
// are not visible in this listing; the visible arguments are the wait
// device_type, operand, segment and devnum attributes plus `deviceType`.
1734 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1736  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1737  getHasWaitDevnum(), deviceType);
1738 }
1739 
1740 mlir::Value KernelsOp::getWaitDevnum() {
1741  return getWaitDevnum(mlir::acc::DeviceType::None);
1742 }
1743 
1744 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1745  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1746  getWaitOperandsSegments(), getHasWaitDevnum(),
1747  deviceType);
1748 }
1749 
/// Verifier for acc.kernels. Runs the checks below in order and fails on the
/// first one that reports an error.
1750 LogicalResult acc::KernelsOp::verify() {
// NOTE(review): the opening line of the first two checks is not visible in
// this listing. The visible arguments are the num_gangs operands with their
// segment/device_type attributes (and a trailing 3), then the wait operands
// with theirs.
1752  *this, getNumGangs(), getNumGangsSegmentsAttr(),
1753  getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1754  return failure();
1755 
1757  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1758  getWaitOperandsDeviceTypeAttr(), "wait")))
1759  return failure();
1760 
// num_workers, vector_length and async are checked against their device_type
// attributes.
1761  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1762  getNumWorkersDeviceTypeAttr(),
1763  "num_workers")))
1764  return failure();
1765 
1766  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1767  getVectorLengthDeviceTypeAttr(),
1768  "vector_length")))
1769  return failure();
1770 
1771  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1772  getAsyncOperandsDeviceTypeAttr(),
1773  "async")))
1774  return failure();
1775 
1776  if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
1777  return failure();
1778 
// The result of the data clause operand check is the verifier's result.
1779  return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
1780 }
1781 
1782 //===----------------------------------------------------------------------===//
1783 // HostDataOp
1784 //===----------------------------------------------------------------------===//
1785 
1786 LogicalResult acc::HostDataOp::verify() {
1787  if (getDataClauseOperands().empty())
1788  return emitError("at least one operand must appear on the host_data "
1789  "operation");
1790 
1791  for (mlir::Value operand : getDataClauseOperands())
1792  if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
1793  return emitError("expect data entry operation as defining op");
1794  return success();
1795 }
1796 
1797 void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
1798  MLIRContext *context) {
1799  results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
1800 }
1801 
1802 //===----------------------------------------------------------------------===//
1803 // LoopOp
1804 //===----------------------------------------------------------------------===//
1805 
/// Parses an optional `keyword = operand : type` gang value.
///
/// When `keyword` is present, the operand and type are appended to `operands`
/// and `types`, `gangArgType` is appended to `attributes`, and both
/// `needCommaBetweenValues` and `newValue` are set to true. When the keyword
/// is absent nothing is consumed, the out-flags are left untouched, and the
/// function succeeds.
///
/// NOTE(review): the parameter lines declaring `operands` and `types` are not
/// visible in this listing.
1806 static ParseResult parseGangValue(
1807  OpAsmParser &parser, llvm::StringRef keyword,
1810  llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
1811  bool &needCommaBetweenValues, bool &newValue) {
1812  if (succeeded(parser.parseOptionalKeyword(keyword))) {
1813  if (parser.parseEqual())
1814  return failure();
1815  if (parser.parseOperand(operands.emplace_back()) ||
1816  parser.parseColonType(types.emplace_back()))
1817  return failure();
1818  attributes.push_back(gangArgType);
1819  needCommaBetweenValues = true;
1820  newValue = true;
1821  }
1822  return success();
1823 }
1824 
/// Parses the custom `gang` clause of `acc.loop`.
///
/// Accepted shapes, as implemented below:
///   * no `(` at all: the clause is a bare gang keyword; a single device-type
///     attribute is recorded in `gangOnlyDeviceType` and parsing stops;
///   * `(` followed by an optional `[attr, ...]` list of gang-only device
///     types, followed by comma-separated groups
///     `{ num=..., dim=..., static=... }` each optionally followed by
///     `[device_type]`, and a closing `)`.
///
/// On success `gangArgType` holds one GangArgTypeAttr per parsed operand,
/// `deviceType` one attribute per `{...}` group, `segments` the number of
/// operands in each group, and `gangOnlyDeviceType` the gang-only list.
///
/// NOTE(review): several lines are absent from this listing (its embedded
/// numbering skips 1827, 1834, 1841, 1929 and 1944): the declaration of
/// `gangOperands`, the declaration of `seg`, the argument lists of two
/// `DeviceTypeAttr::get` calls and the declaration head of `gangOnlyAttr`.
/// Confirm them against the original file.
1825 static ParseResult parseGangClause(
1826  OpAsmParser &parser,
1828  llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
1829  mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
1830  mlir::ArrayAttr &gangOnlyDeviceType) {
1831  llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
1832  llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
1833  llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
1835  bool needCommaBetweenValues = false;
1836  bool needCommaBeforeOperands = false;
1837 
1838  if (failed(parser.parseOptionalLParen())) {
1839  // Gang only keyword
1840  gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1842  gangOnlyDeviceType =
1843  ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
1844  return success();
1845  }
1846 
1847  // Parse gang only attributes
1848  if (succeeded(parser.parseOptionalLSquare())) {
1849  // Parse gang only attributes
1850  if (failed(parser.parseCommaSeparatedList([&]() {
1851  if (parser.parseAttribute(
1852  gangOnlyDeviceTypeAttributes.emplace_back()))
1853  return failure();
1854  return success();
1855  })))
1856  return failure();
1857  if (parser.parseRSquare())
1858  return failure();
1859  needCommaBeforeOperands = true;
1860  }
1861 
1862  auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
1863  mlir::acc::GangArgType::Num);
1864  auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
1865  mlir::acc::GangArgType::Dim);
1866  auto argStatic = mlir::acc::GangArgTypeAttr::get(
1867  parser.getContext(), mlir::acc::GangArgType::Static);
1868 
 // Each iteration parses one `{...}` group. After a gang-only `[...]` list
 // the first iteration only consumes the separating comma: `continue` jumps
 // to the loop condition, which parses the optional comma.
1869  do {
1870  if (needCommaBeforeOperands) {
1871  needCommaBeforeOperands = false;
1872  continue;
1873  }
1874 
1875  if (failed(parser.parseLBrace()))
1876  return failure();
1877 
 // Remember how many operands existed before this group so that the group
 // size can be recorded as a segment below.
1878  int32_t crtOperandsSize = gangOperands.size();
1879  while (true) {
1880  bool newValue = false;
1881  bool needValue = false;
1882  if (needCommaBetweenValues) {
1883  if (succeeded(parser.parseOptionalComma()))
1884  needValue = true; // expect a new value after comma.
1885  else
1886  break;
1887  }
1888 
1889  if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
1890  gangOperands, gangOperandsType,
1891  gangArgTypeAttributes, argNum,
1892  needCommaBetweenValues, newValue)))
1893  return failure();
1894  if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
1895  gangOperands, gangOperandsType,
1896  gangArgTypeAttributes, argDim,
1897  needCommaBetweenValues, newValue)))
1898  return failure();
1899  if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
1900  gangOperands, gangOperandsType,
1901  gangArgTypeAttributes, argStatic,
1902  needCommaBetweenValues, newValue)))
1903  return failure();
1904 
1905  if (!newValue && needValue) {
1906  parser.emitError(parser.getCurrentLocation(),
1907  "new value expected after comma");
1908  return failure();
1909  }
1910 
1911  if (!newValue)
1912  break;
1913  }
1914 
1915  if (gangOperands.empty())
1916  return parser.emitError(
1917  parser.getCurrentLocation(),
1918  "expect at least one of num, dim or static values");
1919 
1920  if (failed(parser.parseRBrace()))
1921  return failure();
1922 
 // An explicit `[device_type]` may follow the group; otherwise a default
 // device-type attribute is recorded for it.
1923  if (succeeded(parser.parseOptionalLSquare())) {
1924  if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
1925  parser.parseRSquare())
1926  return failure();
1927  } else {
1928  deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1930  }
1931 
1932  seg.push_back(gangOperands.size() - crtOperandsSize);
1933 
1934  } while (succeeded(parser.parseOptionalComma()));
1935 
1936  if (failed(parser.parseRParen()))
1937  return failure();
1938 
1939  llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
1940  gangArgTypeAttributes.end());
1941  gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
1942  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
1943 
1945  gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
1946  gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
1947 
1948  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1949  return success();
1950 }
1951 
/// Printer for the custom `gang` clause.
///
/// Prints nothing when there are no operands and
/// `hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)` holds. Otherwise prints `(`,
/// the gang-only device types, and then one `{keyword=operand : type, ...}`
/// group per entry of `deviceTypes` followed by that entry's device type,
/// and finally `)`. The number of operands in each group is read from
/// `segments`, and the keyword of each operand from `gangArgTypes`.
///
/// NOTE(review): the first line of this definition (return type, name and
/// leading parameters, including `p`) is not visible in this listing (its
/// embedded numbering skips 1952) -- confirm against the original file.
1953  mlir::OperandRange operands, mlir::TypeRange types,
1954  std::optional<mlir::ArrayAttr> gangArgTypes,
1955  std::optional<mlir::ArrayAttr> deviceTypes,
1956  std::optional<mlir::DenseI32ArrayAttr> segments,
1957  std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
1958 
1959  if (operands.begin() == operands.end() &&
1960  hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
1961  return;
1962  }
1963 
1964  p << "(";
1965 
1966  printDeviceTypes(p, gangOnlyDeviceTypes);
1967 
1968  if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
1969  hasDeviceTypeValues(deviceTypes))
1970  p << ", ";
1971 
1972  if (hasDeviceTypeValues(deviceTypes)) {
 // `opIdx` runs over all operands across every group; it is advanced inside
 // the inner lambda and captured by reference.
1973  unsigned opIdx = 0;
1974  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1975  p << "{";
1976  llvm::interleaveComma(
1977  llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
 // NOTE(review): the dyn_cast result is used without a null check; this
 // relies on every entry of `gangArgTypes` being a GangArgTypeAttr.
1978  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
1979  (*gangArgTypes)[opIdx]);
1980  if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
1981  p << LoopOp::getGangNumKeyword();
1982  else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
1983  p << LoopOp::getGangDimKeyword();
1984  else if (gangArgTypeAttr.getValue() ==
1985  mlir::acc::GangArgType::Static)
1986  p << LoopOp::getGangStaticKeyword();
1987  p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
1988  ++opIdx;
1989  });
1990  p << "}";
1991  printSingleDeviceType(p, it.value());
1992  });
1993  }
1994  p << ")";
1995 }
1996 
/// Inserts the device type of every attribute in `segments` into
/// `deviceTypes` and returns true as soon as one of them was already present
/// in the set. Returns false when `segments` is absent or no duplicate is
/// found. The set is updated in place, so it accumulates across calls.
///
/// NOTE(review): the line carrying the return type and function name is not
/// visible in this listing (its embedded numbering skips 1997). The dyn_cast
/// result below is used without a null check, which relies on every element
/// being a DeviceTypeAttr.
1998  std::optional<mlir::ArrayAttr> segments,
1999  llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2000  if (!segments)
2001  return false;
2002  for (auto attr : *segments) {
2003  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2004  if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2005  return true;
2006  }
2007  return false;
2008 }
2009 
2010 /// Check for duplicates in the DeviceType array attribute.
2011 LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
2012  llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2013  if (!deviceTypes)
2014  return success();
2015  for (auto attr : deviceTypes) {
2016  auto deviceTypeAttr =
2017  mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
2018  if (!deviceTypeAttr)
2019  return failure();
2020  if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
2021  return failure();
2022  }
2023  return success();
2024 }
2025 
/// Verifies an `acc.loop` operation.
///
/// Checks, in order: the size of inclusiveUpperbound against upperbound; the
/// collapse / collapse device_type attributes; the gang operands and their
/// argument-type attribute; duplicate device types in the gang, worker and
/// vector attributes; device-type/operand consistency for gang, worker, vector
/// and tile operands; mutual exclusion of auto / independent / seq per device
/// type; incompatibility of seq with gang, worker and vector; the private and
/// reduction recipe lists; the combined-construct attribute; and that the
/// region is not empty.
///
/// NOTE(review): the heads of the gang and tile segment checks are not visible
/// in this listing (its embedded numbering skips 2060 and 2087) -- confirm
/// against the original file.
2026 LogicalResult acc::LoopOp::verify() {
2027  if (!getUpperbound().empty() && getInclusiveUpperbound() &&
2028  (getUpperbound().size() != getInclusiveUpperbound()->size()))
2029  return emitError() << "inclusiveUpperbound size is expected to be the same"
2030  << " as upperbound size";
2031 
2032  // Check collapse
2033  if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
2034  return emitOpError() << "collapse device_type attr must be define when"
2035  << " collapse attr is present";
2036 
2037  if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
2038  getCollapseAttr().getValue().size() !=
2039  getCollapseDeviceTypeAttr().getValue().size())
2040  return emitOpError() << "collapse attribute count must match collapse"
2041  << " device_type count";
2042  if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
2043  return emitOpError()
2044  << "duplicate device_type found in collapseDeviceType attribute";
2045 
2046  // Check gang
2047  if (!getGangOperands().empty()) {
2048  if (!getGangOperandsArgType())
2049  return emitOpError() << "gangOperandsArgType attribute must be defined"
2050  << " when gang operands are present";
2051 
2052  if (getGangOperands().size() !=
2053  getGangOperandsArgTypeAttr().getValue().size())
2054  return emitOpError() << "gangOperandsArgType attribute count must match"
2055  << " gangOperands count";
2056  }
2057  if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
2058  return emitOpError() << "duplicate device_type found in gang attribute";
2059 
2061  *this, getGangOperands(), getGangOperandsSegmentsAttr(),
2062  getGangOperandsDeviceTypeAttr(), "gang")))
2063  return failure();
2064 
2065  // Check worker
2066  if (failed(checkDeviceTypes(getWorkerAttr())))
2067  return emitOpError() << "duplicate device_type found in worker attribute";
2068  if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
2069  return emitOpError() << "duplicate device_type found in "
2070  "workerNumOperandsDeviceType attribute";
2071  if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
2072  getWorkerNumOperandsDeviceTypeAttr(),
2073  "worker")))
2074  return failure();
2075 
2076  // Check vector
2077  if (failed(checkDeviceTypes(getVectorAttr())))
2078  return emitOpError() << "duplicate device_type found in vector attribute";
2079  if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
2080  return emitOpError() << "duplicate device_type found in "
2081  "vectorOperandsDeviceType attribute";
2082  if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
2083  getVectorOperandsDeviceTypeAttr(),
2084  "vector")))
2085  return failure();
2086 
2088  *this, getTileOperands(), getTileOperandsSegmentsAttr(),
2089  getTileOperandsDeviceTypeAttr(), "tile")))
2090  return failure();
2091 
2092  // auto, independent and seq attribute are mutually exclusive.
 // The same set is passed to all three calls, so a device type that appears
 // in more than one of the three attributes is reported as a duplicate.
2093  llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
2094  if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
2095  hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
2096  hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
2097  return emitError() << "only one of \"" << acc::LoopOp::getAutoAttrStrName()
2098  << "\", " << getIndependentAttrName() << ", "
2099  << getSeqAttrName()
2100  << " can be present at the same time";
2101  }
2102 
2103  // Gang, worker and vector are incompatible with seq.
2104  if (getSeqAttr()) {
2105  for (auto attr : getSeqAttr()) {
 // NOTE(review): the dyn_cast result is used without a null check; this
 // relies on every element of the seq attribute being a DeviceTypeAttr.
2106  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2107  if (hasVector(deviceTypeAttr.getValue()) ||
2108  getVectorValue(deviceTypeAttr.getValue()) ||
2109  hasWorker(deviceTypeAttr.getValue()) ||
2110  getWorkerValue(deviceTypeAttr.getValue()) ||
2111  hasGang(deviceTypeAttr.getValue()) ||
2112  getGangValue(mlir::acc::GangArgType::Num,
2113  deviceTypeAttr.getValue()) ||
2114  getGangValue(mlir::acc::GangArgType::Dim,
2115  deviceTypeAttr.getValue()) ||
2116  getGangValue(mlir::acc::GangArgType::Static,
2117  deviceTypeAttr.getValue()))
2118  return emitError()
2119  << "gang, worker or vector cannot appear with the seq attr";
2120  }
2121  }
2122 
2123  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
2124  *this, getPrivatizations(), getPrivateOperands(), "private",
2125  "privatizations", false)))
2126  return failure();
2127 
2128  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
2129  *this, getReductionRecipes(), getReductionOperands(), "reduction",
2130  "reductions", false)))
2131  return failure();
2132 
2133  if (getCombined().has_value() &&
2134  (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
2135  getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
2136  getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
2137  return emitError("unexpected combined constructs attribute");
2138  }
2139 
2140  // Check non-empty body().
2141  if (getRegion().empty())
2142  return emitError("expected non-empty body.");
2143 
2144  return success();
2145 }
2146 
/// Returns the number of data operands of the loop, counted as the reduction
/// operands plus the private operands.
2147 unsigned LoopOp::getNumDataOperands() {
2148  return getReductionOperands().size() + getPrivateOperands().size();
2149 }
2150 
2151 Value LoopOp::getDataOperand(unsigned i) {
2152  unsigned numOptional =
2153  getLowerbound().size() + getUpperbound().size() + getStep().size();
2154  numOptional += getGangOperands().size();
2155  numOptional += getVectorOperands().size();
2156  numOptional += getWorkerNumOperands().size();
2157  numOptional += getTileOperands().size();
2158  numOptional += getCacheOperands().size();
2159  return getOperand(numOptional + i);
2160 }
2161 
/// Queries the auto attribute for DeviceType::None.
2162 bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
2163 
/// Returns the result of `hasDeviceType` on the auto attribute for
/// `deviceType`.
2164 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
2165  return hasDeviceType(getAuto_(), deviceType);
2166 }
2167 
/// Queries the independent attribute for DeviceType::None.
2168 bool LoopOp::hasIndependent() {
2169  return hasIndependent(mlir::acc::DeviceType::None);
2170 }
2171 
/// Returns the result of `hasDeviceType` on the independent attribute for
/// `deviceType`.
2172 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
2173  return hasDeviceType(getIndependent(), deviceType);
2174 }
2175 
/// Queries the seq attribute for DeviceType::None.
2176 bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
2177 
/// Returns the result of `hasDeviceType` on the seq attribute for
/// `deviceType`.
2178 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
2179  return hasDeviceType(getSeq(), deviceType);
2180 }
2181 
/// Returns the vector operand looked up for DeviceType::None.
2182 mlir::Value LoopOp::getVectorValue() {
2183  return getVectorValue(mlir::acc::DeviceType::None);
2184 }
2185 
/// Looks up the vector operand for `deviceType` by passing the vector operands
/// and their device-type attribute to `getValueInDeviceTypeSegment`.
2186 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
2187  return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
2188  getVectorOperands(), deviceType);
2189 }
2190 
/// Queries the vector attribute for DeviceType::None.
2191 bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
2192 
/// Returns the result of `hasDeviceType` on the vector attribute for
/// `deviceType`.
2193 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
2194  return hasDeviceType(getVector(), deviceType);
2195 }
2196 
/// Returns the worker operand looked up for DeviceType::None.
2197 mlir::Value LoopOp::getWorkerValue() {
2198  return getWorkerValue(mlir::acc::DeviceType::None);
2199 }
2200 
/// Looks up the worker operand for `deviceType` by passing the worker-num
/// operands and their device-type attribute to `getValueInDeviceTypeSegment`.
2201 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
2202  return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
2203  getWorkerNumOperands(), deviceType);
2204 }
2205 
/// Queries the worker attribute for DeviceType::None.
2206 bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2207 
/// Returns the result of `hasDeviceType` on the worker attribute for
/// `deviceType`.
2208 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2209  return hasDeviceType(getWorker(), deviceType);
2210 }
2211 
/// Returns the tile operands looked up for DeviceType::None.
2212 mlir::Operation::operand_range LoopOp::getTileValues() {
2213  return getTileValues(mlir::acc::DeviceType::None);
2214 }
2215 
/// Looks up the tile operands for `deviceType` by passing the tile operands,
/// their device-type attribute and their segments to `getValuesFromSegments`.
/// NOTE(review): the return-type line of this overload is not visible in this
/// listing (its embedded numbering skips 2216); presumably it matches the
/// overload above -- confirm against the original file.
2217 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
2218  return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
2219  getTileOperandsSegments(), deviceType);
2220 }
2221 
/// Returns the collapse value looked up for DeviceType::None.
2222 std::optional<int64_t> LoopOp::getCollapseValue() {
2223  return getCollapseValue(mlir::acc::DeviceType::None);
2224 }
2225 
2226 std::optional<int64_t>
2227 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2228  if (!getCollapseAttr())
2229  return std::nullopt;
2230  if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2231  auto intAttr =
2232  mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2233  return intAttr.getValue().getZExtValue();
2234  }
2235  return std::nullopt;
2236 }
2237 
/// Returns the gang operand of kind `gangArgType` for DeviceType::None.
2238 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2239  return getGangValue(gangArgType, mlir::acc::DeviceType::None);
2240 }
2241 
/// Returns the gang operand of kind `gangArgType` that belongs to
/// `deviceType`, or a null Value when there are no gang operands, no segment
/// is found for `deviceType`, or that segment has no operand of the kind.
///
/// NOTE(review): the declaration head of `values` is not visible in this
/// listing (its embedded numbering skips 2250) -- confirm against the
/// original file.
2242 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2243  mlir::acc::DeviceType deviceType) {
2244  if (getGangOperands().empty())
2245  return {};
2246  if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
 // Sum the sizes of the segments that precede the one found, to locate
 // where this device type's operands start.
2247  int32_t nbOperandsBefore = 0;
2248  for (unsigned i = 0; i < *pos; ++i)
2249  nbOperandsBefore += (*getGangOperandsSegments())[i];
2251  getGangOperands()
2252  .drop_front(nbOperandsBefore)
2253  .take_front((*getGangOperandsSegments())[*pos]);
2254 
 // The argument-type attribute is indexed like the flat operand list, so it
 // starts at the same offset as the segment.
2255  int32_t argTypeIdx = nbOperandsBefore;
2256  for (auto value : values) {
2257  auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2258  (*getGangOperandsArgType())[argTypeIdx]);
2259  if (gangArgTypeAttr.getValue() == gangArgType)
2260  return value;
2261  ++argTypeIdx;
2262  }
2263  }
2264  return {};
2265 }
2266 
/// Queries the gang attribute for DeviceType::None.
2267 bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2268 
/// Returns the result of `hasDeviceType` on the gang attribute for
/// `deviceType`.
2269 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2270  return hasDeviceType(getGang(), deviceType);
2271 }
2272 
/// Returns a one-element vector holding a pointer to the operation's region.
2273 llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
2274  return {&getRegion()};
2275 }
2276 
2277 /// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
2278 /// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
2279 /// `(` ssa-id-and-type-list `)`
2280 /// region
///
/// The `control` part is optional. When present, the induction variables are
/// parsed as region arguments, and exactly as many lower bounds, upper bounds
/// and steps as there are induction variables are parsed, each list followed
/// by its `:` type list. The region is then parsed with the induction
/// variables as its entry arguments.
///
/// NOTE(review): several lines are absent from this listing (its embedded
/// numbering skips 2282-2283, 2285, 2287, 2298, 2302 and 2306): the function
/// name with its leading parameters, the operand-list parameters, and the
/// trailing argument of each `parseOperandList` call. Confirm against the
/// original file.
2281 ParseResult
2284  SmallVectorImpl<Type> &lowerboundType,
2286  SmallVectorImpl<Type> &upperboundType,
2288  SmallVectorImpl<Type> &stepType) {
2289 
2290  SmallVector<OpAsmParser::Argument> inductionVars;
2291  if (succeeded(
2292  parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
2293  if (parser.parseLParen() ||
2294  parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
2295  /*allowType=*/true) ||
2296  parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
2297  parser.parseOperandList(lowerbound, inductionVars.size(),
2299  parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
2300  parser.parseKeyword("to") || parser.parseLParen() ||
2301  parser.parseOperandList(upperbound, inductionVars.size(),
2303  parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
2304  parser.parseKeyword("step") || parser.parseLParen() ||
2305  parser.parseOperandList(step, inductionVars.size(),
2307  parser.parseColonTypeList(stepType) || parser.parseRParen())
2308  return failure();
2309  }
2310  return parser.parseRegion(region, inductionVars);
2311 }
2312 
/// Printer counterpart of the loop-control syntax.
///
/// When the region's first block has arguments, prints the control keyword
/// followed by the arguments with their types, then the lower bounds, upper
/// bounds and steps with their types. The region is always printed afterwards,
/// without its entry block arguments.
///
/// NOTE(review): the first line of this definition (return type, name and
/// leading parameters, including `p` and `region`) is not visible in this
/// listing (its embedded numbering skips 2313) -- confirm against the
/// original file.
2314  ValueRange lowerbound, TypeRange lowerboundType,
2315  ValueRange upperbound, TypeRange upperboundType,
2316  ValueRange steps, TypeRange stepType) {
2317  ValueRange regionArgs = region.front().getArguments();
2318  if (!regionArgs.empty()) {
2319  p << acc::LoopOp::getControlKeyword() << "(";
2320  llvm::interleaveComma(regionArgs, p,
2321  [&p](Value v) { p << v << " : " << v.getType(); });
2322  p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
2323  << upperbound << " : " << upperboundType << ") " << " step (" << steps
2324  << " : " << stepType << ") ";
2325  }
2326  p.printRegion(region, /*printEntryBlockArgs=*/false);
2327 }
2328 
2329 //===----------------------------------------------------------------------===//
2330 // DataOp
2331 //===----------------------------------------------------------------------===//
2332 
2333 LogicalResult acc::DataOp::verify() {
2334  // 2.6.5. Data Construct restriction
2335  // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
2336  // attach, or default clause must appear on a data construct.
2337  if (getOperands().empty() && !getDefaultAttr())
2338  return emitError("at least one operand or the default attribute "
2339  "must appear on the data operation");
2340 
2341  for (mlir::Value operand : getDataClauseOperands())
2342  if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2343  acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2344  acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2345  operand.getDefiningOp()))
2346  return emitError("expect data entry/exit operation or acc.getdeviceptr "
2347  "as defining op");
2348 
2349  if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this)))
2350  return failure();
2351 
2352  return success();
2353 }
2354 
/// Returns the number of data clause operands.
2355 unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
2356 
2357 Value DataOp::getDataOperand(unsigned i) {
2358  unsigned numOptional = getIfCond() ? 1 : 0;
2359  numOptional += getAsyncOperands().size() ? 1 : 0;
2360  numOptional += getWaitOperands().size();
2361  return getOperand(numOptional + i);
2362 }
2363 
/// Queries the async-only attribute for DeviceType::None.
2364 bool acc::DataOp::hasAsyncOnly() {
2365  return hasAsyncOnly(mlir::acc::DeviceType::None);
2366 }
2367 
/// Returns the result of `hasDeviceType` on the async-only attribute for
/// `deviceType`.
2368 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2369  return hasDeviceType(getAsyncOnly(), deviceType);
2370 }
2371 
/// Returns the async operand looked up for DeviceType::None.
2372 mlir::Value DataOp::getAsyncValue() {
2373  return getAsyncValue(mlir::acc::DeviceType::None);
2374 }
2375 
/// Looks up the async operand for `deviceType` among the async operands.
/// NOTE(review): the head of the return statement is not visible in this
/// listing (its embedded numbering skips 2377) -- confirm the helper being
/// called against the original file.
2376 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2378  getAsyncOperands(), deviceType);
2379 }
2380 
/// Queries the wait-only attribute for DeviceType::None.
2381 bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
2382 
/// Returns the result of `hasDeviceType` on the wait-only attribute for
/// `deviceType`.
2383 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2384  return hasDeviceType(getWaitOnly(), deviceType);
2385 }
2386 
/// Returns the wait operands looked up for DeviceType::None.
2387 mlir::Operation::operand_range DataOp::getWaitValues() {
2388  return getWaitValues(mlir::acc::DeviceType::None);
2389 }
2390 
/// Looks up the wait operands for `deviceType` from the wait operands, their
/// device-type attribute, their segments and the has-wait-devnum attribute.
/// NOTE(review): the return-type line and the head of the return statement
/// are not visible in this listing (its embedded numbering skips 2391 and
/// 2393) -- confirm against the original file.
2392 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2394  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2395  getHasWaitDevnum(), deviceType);
2396 }
2397 
/// Returns the wait devnum operand looked up for DeviceType::None.
2398 mlir::Value DataOp::getWaitDevnum() {
2399  return getWaitDevnum(mlir::acc::DeviceType::None);
2400 }
2401 
/// Looks up the wait devnum operand for `deviceType` by passing the wait
/// operands, their device-type attribute, their segments and the
/// has-wait-devnum attribute to `getWaitDevnumValue`.
2402 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2403  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2404  getWaitOperandsSegments(), getHasWaitDevnum(),
2405  deviceType);
2406 }
2407 
2408 //===----------------------------------------------------------------------===//
2409 // ExitDataOp
2410 //===----------------------------------------------------------------------===//
2411 
2412 LogicalResult acc::ExitDataOp::verify() {
2413  // 2.6.6. Data Exit Directive restriction
2414  // At least one copyout, delete, or detach clause must appear on an exit data
2415  // directive.
2416  if (getDataClauseOperands().empty())
2417  return emitError("at least one operand must be present in dataOperands on "
2418  "the exit data operation");
2419 
2420  // The async attribute represent the async clause without value. Therefore the
2421  // attribute and operand cannot appear at the same time.
2422  if (getAsyncOperand() && getAsync())
2423  return emitError("async attribute cannot appear with asyncOperand");
2424 
2425  // The wait attribute represent the wait clause without values. Therefore the
2426  // attribute and operands cannot appear at the same time.
2427  if (!getWaitOperands().empty() && getWait())
2428  return emitError("wait attribute cannot appear with waitOperands");
2429 
2430  if (getWaitDevnum() && getWaitOperands().empty())
2431  return emitError("wait_devnum cannot appear without waitOperands");
2432 
2433  return success();
2434 }
2435 
/// Returns the number of data clause operands.
2436 unsigned ExitDataOp::getNumDataOperands() {
2437  return getDataClauseOperands().size();
2438 }
2439 
2440 Value ExitDataOp::getDataOperand(unsigned i) {
2441  unsigned numOptional = getIfCond() ? 1 : 0;
2442  numOptional += getAsyncOperand() ? 1 : 0;
2443  numOptional += getWaitDevnum() ? 1 : 0;
2444  return getOperand(getWaitOperands().size() + numOptional + i);
2445 }
2446 
/// Registers the canonicalization patterns of `acc.exit_data`: adds
/// `RemoveConstantIfCondition<ExitDataOp>` to `results`.
2447 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2448  MLIRContext *context) {
2449  results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
2450 }
2451 
2452 //===----------------------------------------------------------------------===//
2453 // EnterDataOp
2454 //===----------------------------------------------------------------------===//
2455 
2456 LogicalResult acc::EnterDataOp::verify() {
2457  // 2.6.6. Data Enter Directive restriction
2458  // At least one copyin, create, or attach clause must appear on an enter data
2459  // directive.
2460  if (getDataClauseOperands().empty())
2461  return emitError("at least one operand must be present in dataOperands on "
2462  "the enter data operation");
2463 
2464  // The async attribute represent the async clause without value. Therefore the
2465  // attribute and operand cannot appear at the same time.
2466  if (getAsyncOperand() && getAsync())
2467  return emitError("async attribute cannot appear with asyncOperand");
2468 
2469  // The wait attribute represent the wait clause without values. Therefore the
2470  // attribute and operands cannot appear at the same time.
2471  if (!getWaitOperands().empty() && getWait())
2472  return emitError("wait attribute cannot appear with waitOperands");
2473 
2474  if (getWaitDevnum() && getWaitOperands().empty())
2475  return emitError("wait_devnum cannot appear without waitOperands");
2476 
2477  for (mlir::Value operand : getDataClauseOperands())
2478  if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
2479  operand.getDefiningOp()))
2480  return emitError("expect data entry operation as defining op");
2481 
2482  return success();
2483 }
2484 
/// Returns the number of data clause operands.
2485 unsigned EnterDataOp::getNumDataOperands() {
2486  return getDataClauseOperands().size();
2487 }
2488 
2489 Value EnterDataOp::getDataOperand(unsigned i) {
2490  unsigned numOptional = getIfCond() ? 1 : 0;
2491  numOptional += getAsyncOperand() ? 1 : 0;
2492  numOptional += getWaitDevnum() ? 1 : 0;
2493  return getOperand(getWaitOperands().size() + numOptional + i);
2494 }
2495 
/// Registers the canonicalization patterns of `acc.enter_data`: adds
/// `RemoveConstantIfCondition<EnterDataOp>` to `results`.
2496 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2497  MLIRContext *context) {
2498  results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
2499 }
2500 
2501 //===----------------------------------------------------------------------===//
2502 // AtomicReadOp
2503 //===----------------------------------------------------------------------===//
2504 
/// Verification is delegated entirely to `verifyCommon()`.
2505 LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
2506 
2507 //===----------------------------------------------------------------------===//
2508 // AtomicWriteOp
2509 //===----------------------------------------------------------------------===//
2510 
/// Verification is delegated entirely to `verifyCommon()`.
2511 LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
2512 
2513 //===----------------------------------------------------------------------===//
2514 // AtomicUpdateOp
2515 //===----------------------------------------------------------------------===//
2516 
2517 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2518  PatternRewriter &rewriter) {
2519  if (op.isNoOp()) {
2520  rewriter.eraseOp(op);
2521  return success();
2522  }
2523 
2524  if (Value writeVal = op.getWriteOpVal()) {
2525  rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
2526  return success();
2527  }
2528 
2529  return failure();
2530 }
2531 
/// Verification is delegated entirely to `verifyCommon()`.
2532 LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
2533 
/// Region verification is delegated entirely to `verifyRegionsCommon()`.
2534 LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
2535 
2536 //===----------------------------------------------------------------------===//
2537 // AtomicCaptureOp
2538 //===----------------------------------------------------------------------===//
2539 
2540 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2541  if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2542  return op;
2543  return dyn_cast<AtomicReadOp>(getSecondOp());
2544 }
2545 
2546 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2547  if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2548  return op;
2549  return dyn_cast<AtomicWriteOp>(getSecondOp());
2550 }
2551 
2552 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2553  if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2554  return op;
2555  return dyn_cast<AtomicUpdateOp>(getSecondOp());
2556 }
2557 
/// Region verification is delegated entirely to `verifyRegionsCommon()`.
2558 LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
2559 
2560 //===----------------------------------------------------------------------===//
2561 // DeclareEnterOp
2562 //===----------------------------------------------------------------------===//
2563 
/// Shared verification of the data operands of the declare operations.
///
/// Fails when `operands` is empty and `requireAtLeastOneOperand` is set, when
/// an operand is not produced by one of the accepted data entry operations,
/// or when the variable behind an operand (if it has a defining op) lacks a
/// declare attribute, carries a different data clause, or is marked implicit
/// while the data operation's implicit flag differs.
///
/// NOTE(review): the line carrying the function name and its leading
/// parameters (`op`, `operands`) is not visible in this listing (its embedded
/// numbering skips 2566) -- confirm against the original file.
/// NOTE(review): `getDefiningOp()` is passed to `isa` without a null check, so
/// an operand that is a block argument is not handled here.
2564 template <typename Op>
2565 static LogicalResult
2567  bool requireAtLeastOneOperand = true) {
2568  if (operands.empty() && requireAtLeastOneOperand)
2569  return emitError(
2570  op->getLoc(),
2571  "at least one operand must appear on the declare operation");
2572 
2573  for (mlir::Value operand : operands) {
2574  if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2575  acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
2576  acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
2577  operand.getDefiningOp()))
2578  return op.emitError(
2579  "expect valid declare data entry operation or acc.getdeviceptr "
2580  "as defining op");
2581 
2582  mlir::Value varPtr{getVarPtr(operand.getDefiningOp())};
2583  assert(varPtr && "declare operands can only be data entry operations which "
2584  "must have varPtr");
2585  std::optional<mlir::acc::DataClause> dataClauseOptional{
2586  getDataClause(operand.getDefiningOp())};
2587  assert(dataClauseOptional.has_value() &&
2588  "declare operands can only be data entry operations which must have "
2589  "dataClause");
2590 
2591  // If varPtr has no defining op - there is nothing to check further.
2592  if (!varPtr.getDefiningOp())
2593  continue;
2594 
2595  // Check that the varPtr has a declare attribute.
2596  auto declareAttribute{
2597  varPtr.getDefiningOp()->getAttr(mlir::acc::getDeclareAttrName())};
2598  if (!declareAttribute)
2599  return op.emitError(
2600  "expect declare attribute on variable in declare operation");
2601 
2602  auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute);
2603  if (declAttr.getDataClause().getValue() != dataClauseOptional.value())
2604  return op.emitError(
2605  "expect matching declare attribute on variable in declare operation");
2606 
2607  // If the variable is marked with implicit attribute, the matching declare
2608  // data action must also be marked implicit. The reverse is not checked
2609  // since implicit data action may be inserted to do actions like updating
2610  // device copy, in which case the variable is not necessarily implicitly
2611  // declare'd.
2612  if (declAttr.getImplicit() &&
2613  declAttr.getImplicit() != acc::getImplicitFlag(operand.getDefiningOp()))
2614  return op.emitError(
2615  "implicitness must match between declare op and flag on variable");
2616  }
2617 
2618  return success();
2619 }
2620 
/// Verifies the data clause operands through `checkDeclareOperands`, using its
/// default of requiring at least one operand.
2621 LogicalResult acc::DeclareEnterOp::verify() {
2622  return checkDeclareOperands(*this, this->getDataClauseOperands());
2623 }
2624 
2625 //===----------------------------------------------------------------------===//
2626 // DeclareExitOp
2627 //===----------------------------------------------------------------------===//
2628 
2629 LogicalResult acc::DeclareExitOp::verify() {
2630  if (getToken())
2631  return checkDeclareOperands(*this, this->getDataClauseOperands(),
2632  /*requireAtLeastOneOperand=*/false);
2633  return checkDeclareOperands(*this, this->getDataClauseOperands());
2634 }
2635 
2636 //===----------------------------------------------------------------------===//
2637 // DeclareOp
2638 //===----------------------------------------------------------------------===//
2639 
/// Verifies the data clause operands through `checkDeclareOperands`, using its
/// default of requiring at least one operand.
2640 LogicalResult acc::DeclareOp::verify() {
2641  return checkDeclareOperands(*this, this->getDataClauseOperands());
2642 }
2643 
2644 //===----------------------------------------------------------------------===//
2645 // RoutineOp
2646 //===----------------------------------------------------------------------===//
2647 
2648 static unsigned getParallelismForDeviceType(acc::RoutineOp op,
2649  acc::DeviceType dtype) {
2650  unsigned parallelism = 0;
2651  parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
2652  parallelism += op.hasWorker(dtype) ? 1 : 0;
2653  parallelism += op.hasVector(dtype) ? 1 : 0;
2654  parallelism += op.hasSeq(dtype) ? 1 : 0;
2655  return parallelism;
2656 }
2657 
2658 LogicalResult acc::RoutineOp::verify() {
2659  unsigned baseParallelism =
2661 
2662  if (baseParallelism > 1)
2663  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2664  "be present at the same time";
2665 
2666  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
2667  ++dtypeInt) {
2668  auto dtype = static_cast<acc::DeviceType>(dtypeInt);
2669  if (dtype == acc::DeviceType::None)
2670  continue;
2671  unsigned parallelism = getParallelismForDeviceType(*this, dtype);
2672 
2673  if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
2674  return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2675  "be present at the same time";
2676  }
2677 
2678  return success();
2679 }
2680 
2681 static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
2682  mlir::ArrayAttr &deviceTypes) {
2683  llvm::SmallVector<mlir::Attribute> bindNameAttrs;
2684  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
2685 
2686  if (failed(parser.parseCommaSeparatedList([&]() {
2687  if (parser.parseAttribute(bindNameAttrs.emplace_back()))
2688  return failure();
2689  if (failed(parser.parseOptionalLSquare())) {
2690  deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2691  parser.getContext(), mlir::acc::DeviceType::None));
2692  } else {
2693  if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2694  parser.parseRSquare())
2695  return failure();
2696  }
2697  return success();
2698  })))
2699  return failure();
2700 
2701  bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
2702  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2703 
2704  return success();
2705 }
2706 
2708  std::optional<mlir::ArrayAttr> bindName,
2709  std::optional<mlir::ArrayAttr> deviceTypes) {
2710  llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
2711  [&](const auto &pair) {
2712  p << std::get<0>(pair);
2713  printSingleDeviceType(p, std::get<1>(pair));
2714  });
2715 }
2716 
2717 static ParseResult parseRoutineGangClause(OpAsmParser &parser,
2718  mlir::ArrayAttr &gang,
2719  mlir::ArrayAttr &gangDim,
2720  mlir::ArrayAttr &gangDimDeviceTypes) {
2721 
2722  llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
2723  gangDimDeviceTypeAttrs;
2724  bool needCommaBeforeOperands = false;
2725 
2726  // Gang keyword only
2727  if (failed(parser.parseOptionalLParen())) {
2728  gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2730  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2731  return success();
2732  }
2733 
2734  // Parse keyword only attributes
2735  if (succeeded(parser.parseOptionalLSquare())) {
2736  if (failed(parser.parseCommaSeparatedList([&]() {
2737  if (parser.parseAttribute(gangAttrs.emplace_back()))
2738  return failure();
2739  return success();
2740  })))
2741  return failure();
2742  if (parser.parseRSquare())
2743  return failure();
2744  needCommaBeforeOperands = true;
2745  }
2746 
2747  if (needCommaBeforeOperands && failed(parser.parseComma()))
2748  return failure();
2749 
2750  if (failed(parser.parseCommaSeparatedList([&]() {
2751  if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
2752  parser.parseColon() ||
2753  parser.parseAttribute(gangDimAttrs.emplace_back()))
2754  return failure();
2755  if (succeeded(parser.parseOptionalLSquare())) {
2756  if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
2757  parser.parseRSquare())
2758  return failure();
2759  } else {
2760  gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2761  parser.getContext(), mlir::acc::DeviceType::None));
2762  }
2763  return success();
2764  })))
2765  return failure();
2766 
2767  if (failed(parser.parseRParen()))
2768  return failure();
2769 
2770  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2771  gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
2772  gangDimDeviceTypes =
2773  ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
2774 
2775  return success();
2776 }
2777 
2779  std::optional<mlir::ArrayAttr> gang,
2780  std::optional<mlir::ArrayAttr> gangDim,
2781  std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
2782 
2783  if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
2784  gang->size() == 1) {
2785  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
2786  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
2787  return;
2788  }
2789 
2790  p << "(";
2791 
2792  printDeviceTypes(p, gang);
2793 
2794  if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
2795  p << ", ";
2796 
2797  if (hasDeviceTypeValues(gangDimDeviceTypes))
2798  llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
2799  [&](const auto &pair) {
2800  p << acc::RoutineOp::getGangDimKeyword() << ": ";
2801  p << std::get<0>(pair);
2802  printSingleDeviceType(p, std::get<1>(pair));
2803  });
2804 
2805  p << ")";
2806 }
2807 
2808 static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
2809  mlir::ArrayAttr &deviceTypes) {
2811  // Keyword only
2812  if (failed(parser.parseOptionalLParen())) {
2813  attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2815  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
2816  return success();
2817  }
2818 
2819  // Parse device type attributes
2820  if (succeeded(parser.parseOptionalLSquare())) {
2821  if (failed(parser.parseCommaSeparatedList([&]() {
2822  if (parser.parseAttribute(attributes.emplace_back()))
2823  return failure();
2824  return success();
2825  })))
2826  return failure();
2827  if (parser.parseRSquare() || parser.parseRParen())
2828  return failure();
2829  }
2830  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
2831  return success();
2832 }
2833 
2834 static void
2836  std::optional<mlir::ArrayAttr> deviceTypes) {
2837 
2838  if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
2839  auto deviceTypeAttr =
2840  mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
2841  if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
2842  return;
2843  }
2844 
2845  if (!hasDeviceTypeValues(deviceTypes))
2846  return;
2847 
2848  p << "([";
2849  llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
2850  auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2851  p << dTypeAttr;
2852  });
2853  p << "])";
2854 }
2855 
2856 bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2857 
2858 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
2859  return hasDeviceType(getWorker(), deviceType);
2860 }
2861 
2862 bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
2863 
2864 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
2865  return hasDeviceType(getVector(), deviceType);
2866 }
2867 
2868 bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
2869 
2870 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
2871  return hasDeviceType(getSeq(), deviceType);
2872 }
2873 
2874 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
2875  return getBindNameValue(mlir::acc::DeviceType::None);
2876 }
2877 
2878 std::optional<llvm::StringRef>
2879 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
2880  if (!hasDeviceTypeValues(getBindNameDeviceType()))
2881  return std::nullopt;
2882  if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
2883  auto attr = (*getBindName())[*pos];
2884  auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
2885  return stringAttr.getValue();
2886  }
2887  return std::nullopt;
2888 }
2889 
2890 bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2891 
2892 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
2893  return hasDeviceType(getGang(), deviceType);
2894 }
2895 
2896 std::optional<int64_t> RoutineOp::getGangDimValue() {
2897  return getGangDimValue(mlir::acc::DeviceType::None);
2898 }
2899 
2900 std::optional<int64_t>
2901 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
2902  if (!hasDeviceTypeValues(getGangDimDeviceType()))
2903  return std::nullopt;
2904  if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
2905  auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
2906  return intAttr.getInt();
2907  }
2908  return std::nullopt;
2909 }
2910 
2911 //===----------------------------------------------------------------------===//
2912 // InitOp
2913 //===----------------------------------------------------------------------===//
2914 
2915 LogicalResult acc::InitOp::verify() {
2916  Operation *currOp = *this;
2917  while ((currOp = currOp->getParentOp()))
2918  if (isComputeOperation(currOp))
2919  return emitOpError("cannot be nested in a compute operation");
2920  return success();
2921 }
2922 
2923 //===----------------------------------------------------------------------===//
2924 // ShutdownOp
2925 //===----------------------------------------------------------------------===//
2926 
2927 LogicalResult acc::ShutdownOp::verify() {
2928  Operation *currOp = *this;
2929  while ((currOp = currOp->getParentOp()))
2930  if (isComputeOperation(currOp))
2931  return emitOpError("cannot be nested in a compute operation");
2932  return success();
2933 }
2934 
2935 //===----------------------------------------------------------------------===//
2936 // SetOp
2937 //===----------------------------------------------------------------------===//
2938 
2939 LogicalResult acc::SetOp::verify() {
2940  Operation *currOp = *this;
2941  while ((currOp = currOp->getParentOp()))
2942  if (isComputeOperation(currOp))
2943  return emitOpError("cannot be nested in a compute operation");
2944  if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
2945  return emitOpError("at least one default_async, device_num, or device_type "
2946  "operand must appear");
2947  return success();
2948 }
2949 
2950 //===----------------------------------------------------------------------===//
2951 // UpdateOp
2952 //===----------------------------------------------------------------------===//
2953 
2954 LogicalResult acc::UpdateOp::verify() {
2955  // At least one of host or device should have a value.
2956  if (getDataClauseOperands().empty())
2957  return emitError("at least one value must be present in dataOperands");
2958 
2959  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
2960  getAsyncOperandsDeviceTypeAttr(),
2961  "async")))
2962  return failure();
2963 
2965  *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2966  getWaitOperandsDeviceTypeAttr(), "wait")))
2967  return failure();
2968 
2969  if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this)))
2970  return failure();
2971 
2972  for (mlir::Value operand : getDataClauseOperands())
2973  if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
2974  operand.getDefiningOp()))
2975  return emitError("expect data entry/exit operation or acc.getdeviceptr "
2976  "as defining op");
2977 
2978  return success();
2979 }
2980 
2981 unsigned UpdateOp::getNumDataOperands() {
2982  return getDataClauseOperands().size();
2983 }
2984 
2985 Value UpdateOp::getDataOperand(unsigned i) {
2986  unsigned numOptional = getAsyncOperands().size();
2987  numOptional += getIfCond() ? 1 : 0;
2988  return getOperand(getWaitOperands().size() + numOptional + i);
2989 }
2990 
2991 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
2992  MLIRContext *context) {
2993  results.add<RemoveConstantIfCondition<UpdateOp>>(context);
2994 }
2995 
2996 bool UpdateOp::hasAsyncOnly() {
2997  return hasAsyncOnly(mlir::acc::DeviceType::None);
2998 }
2999 
3000 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3001  return hasDeviceType(getAsync(), deviceType);
3002 }
3003 
3004 mlir::Value UpdateOp::getAsyncValue() {
3005  return getAsyncValue(mlir::acc::DeviceType::None);
3006 }
3007 
3008 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3010  return {};
3011 
3012  if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
3013  return getAsyncOperands()[*pos];
3014 
3015  return {};
3016 }
3017 
3018 bool UpdateOp::hasWaitOnly() {
3019  return hasWaitOnly(mlir::acc::DeviceType::None);
3020 }
3021 
3022 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3023  return hasDeviceType(getWaitOnly(), deviceType);
3024 }
3025 
3026 mlir::Operation::operand_range UpdateOp::getWaitValues() {
3027  return getWaitValues(mlir::acc::DeviceType::None);
3028 }
3029 
3031 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3033  getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3034  getHasWaitDevnum(), deviceType);
3035 }
3036 
3037 mlir::Value UpdateOp::getWaitDevnum() {
3038  return getWaitDevnum(mlir::acc::DeviceType::None);
3039 }
3040 
3041 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3042  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3043  getWaitOperandsSegments(), getHasWaitDevnum(),
3044  deviceType);
3045 }
3046 
3047 //===----------------------------------------------------------------------===//
3048 // WaitOp
3049 //===----------------------------------------------------------------------===//
3050 
3051 LogicalResult acc::WaitOp::verify() {
3052  // The async attribute represent the async clause without value. Therefore the
3053  // attribute and operand cannot appear at the same time.
3054  if (getAsyncOperand() && getAsync())
3055  return emitError("async attribute cannot appear with asyncOperand");
3056 
3057  if (getWaitDevnum() && getWaitOperands().empty())
3058  return emitError("wait_devnum cannot appear without waitOperands");
3059 
3060  return success();
3061 }
3062 
3063 #define GET_OP_CLASSES
3064 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
3065 
3066 #define GET_ATTRDEF_CLASSES
3067 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
3068 
3069 #define GET_TYPEDEF_CLASSES
3070 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
3071 
3072 //===----------------------------------------------------------------------===//
3073 // acc dialect utilities
3074 //===----------------------------------------------------------------------===//
3075 
3078  auto varPtr{llvm::TypeSwitch<mlir::Operation *,
3080  accDataClauseOp)
3081  .Case<ACC_DATA_ENTRY_OPS>(
3082  [&](auto entry) { return entry.getVarPtr(); })
3083  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3084  [&](auto exit) { return exit.getVarPtr(); })
3085  .Default([&](mlir::Operation *) {
3087  })};
3088  return varPtr;
3089 }
3090 
3092  auto varPtr{
3094  .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
3095  .Default([&](mlir::Operation *) { return mlir::Value(); })};
3096  return varPtr;
3097 }
3098 
3100  auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
3101  .Case<ACC_DATA_ENTRY_OPS>(
3102  [&](auto entry) { return entry.getVarType(); })
3103  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3104  [&](auto exit) { return exit.getVarType(); })
3105  .Default([&](mlir::Operation *) { return mlir::Type(); })};
3106  return varType;
3107 }
3108 
3111  auto accPtr{llvm::TypeSwitch<mlir::Operation *,
3113  accDataClauseOp)
3114  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
3115  [&](auto dataClause) { return dataClause.getAccPtr(); })
3116  .Default([&](mlir::Operation *) {
3118  })};
3119  return accPtr;
3120 }
3121 
3123  auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
3125  [&](auto dataClause) { return dataClause.getAccVar(); })
3126  .Default([&](mlir::Operation *) { return mlir::Value(); })};
3127  return accPtr;
3128 }
3129 
3131  auto varPtrPtr{
3133  .Case<ACC_DATA_ENTRY_OPS>(
3134  [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
3135  .Default([&](mlir::Operation *) { return mlir::Value(); })};
3136  return varPtrPtr;
3137 }
3138 
3143  accDataClauseOp)
3144  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3146  dataClause.getBounds().begin(), dataClause.getBounds().end());
3147  })
3148  .Default([&](mlir::Operation *) {
3150  })};
3151  return bounds;
3152 }
3153 
3157  accDataClauseOp)
3158  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3160  dataClause.getAsyncOperands().begin(),
3161  dataClause.getAsyncOperands().end());
3162  })
3163  .Default([&](mlir::Operation *) {
3165  });
3166 }
3167 
3168 mlir::ArrayAttr
3171  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3172  return dataClause.getAsyncOperandsDeviceTypeAttr();
3173  })
3174  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
3175 }
3176 
3177 mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
3180  [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
3181  .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
3182 }
3183 
3184 std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
3185  auto name{
3187  .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
3188  .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
3189  return {};
3190  })};
3191  return name;
3192 }
3193 
3194 std::optional<mlir::acc::DataClause>
3196  auto dataClause{
3198  accDataEntryOp)
3199  .Case<ACC_DATA_ENTRY_OPS>(
3200  [&](auto entry) { return entry.getDataClause(); })
3201  .Default([&](mlir::Operation *) { return std::nullopt; })};
3202  return dataClause;
3203 }
3204 
3206  auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
3207  .Case<ACC_DATA_ENTRY_OPS>(
3208  [&](auto entry) { return entry.getImplicit(); })
3209  .Default([&](mlir::Operation *) { return false; })};
3210  return implicit;
3211 }
3212 
3214  auto dataOperands{
3217  [&](auto entry) { return entry.getDataClauseOperands(); })
3218  .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
3219  return dataOperands;
3220 }
3221 
3224  auto dataOperands{
3227  [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
3228  .Default([&](mlir::Operation *) { return nullptr; })};
3229  return dataOperands;
3230 }
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:112
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2220
@ None
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition: OpenACC.cpp:2778
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition: OpenACC.cpp:646
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition: OpenACC.cpp:1997
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition: OpenACC.cpp:948
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition: OpenACC.cpp:2011
static bool isComputeOperation(Operation *op)
Definition: OpenACC.cpp:660
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition: OpenACC.cpp:1355
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition: OpenACC.cpp:275
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:2681
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition: OpenACC.cpp:244
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition: OpenACC.cpp:1366
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition: OpenACC.cpp:1271
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition: OpenACC.cpp:80
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:2835
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition: OpenACC.cpp:1806
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition: OpenACC.cpp:1520
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition: OpenACC.cpp:2566
static LogicalResult checkVarAndAccVar(Op op)
Definition: OpenACC.cpp:222
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:100
static LogicalResult checkVarAndVarType(Op op)
Definition: OpenACC.cpp:197
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition: OpenACC.cpp:2282
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition: OpenACC.cpp:877
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:1400
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:1030
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition: OpenACC.cpp:253
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:124
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:1141
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition: OpenACC.cpp:229
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition: OpenACC.cpp:2313
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition: OpenACC.cpp:2808
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition: OpenACC.cpp:2717
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1254
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:1427
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindName, std::optional< mlir::ArrayAttr > deviceTypes)
Definition: OpenACC.cpp:2707
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition: OpenACC.cpp:1208
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
Definition: OpenACC.cpp:861
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:156
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition: OpenACC.cpp:316
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition: OpenACC.cpp:1825
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition: OpenACC.cpp:736
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition: OpenACC.cpp:1185
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:111
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
Definition: OpenACC.cpp:892
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition: OpenACC.cpp:1500
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:86
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition: OpenACC.cpp:1952
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition: OpenACC.cpp:140
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition: OpenACC.cpp:1438
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition: OpenACC.cpp:287
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition: OpenACC.cpp:176
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition: OpenACC.cpp:958
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition: OpenACC.cpp:2648
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition: OpenACC.cpp:1191
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition: OpenACC.cpp:1546
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
Definition: OpenACC.cpp:841
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition: OpenACC.h:68
#define ACC_DATA_ENTRY_OPS
Definition: OpenACC.h:44
#define ACC_DATA_EXIT_OPS
Definition: OpenACC.h:52
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
Definition: Builders.h:216
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:826
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:832
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:125
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition: OpenACC.cpp:3122
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition: OpenACC.cpp:3091
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition: OpenACC.cpp:3110
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition: OpenACC.cpp:3195
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition: OpenACC.cpp:3223
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition: OpenACC.cpp:3140
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition: OpenACC.cpp:3213
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition: OpenACC.cpp:3184
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether a data operation is implicit.
Definition: OpenACC.cpp:3205
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition: OpenACC.cpp:3155
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition: OpenACC.cpp:3130
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:3177
static constexpr StringLiteral getDeclareAttrName()
Used to obtain the attribute name for declare.
Definition: OpenACC.h:168
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
Definition: OpenACC.cpp:3099
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition: OpenACC.cpp:3077
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition: OpenACC.cpp:3169
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.