MLIR 22.0.0git
OpenACC.cpp
Go to the documentation of this file.
1//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// =============================================================================
8
14#include "mlir/IR/Builders.h"
18#include "mlir/IR/Matchers.h"
20#include "mlir/IR/SymbolTable.h"
21#include "mlir/Support/LLVM.h"
23#include "llvm/ADT/SmallSet.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/LogicalResult.h"
26#include <variant>
27
28using namespace mlir;
29using namespace acc;
30
31#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
32#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
33#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
35#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
36
37namespace {
38
39static bool isScalarLikeType(Type type) {
40 return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
41}
42
43/// Helper function to attach the `VarName` attribute to an operation
44/// if a variable name is provided.
45static void attachVarNameAttr(Operation *op, OpBuilder &builder,
46 StringRef varName) {
47 if (!varName.empty()) {
48 auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName);
49 op->setAttr(acc::getVarNameAttrName(), varNameAttr);
50 }
51}
52
53template <typename T>
54struct MemRefPointerLikeModel
55 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
56 Type getElementType(Type pointer) const {
57 return cast<T>(pointer).getElementType();
58 }
59
60 mlir::acc::VariableTypeCategory
61 getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
62 Type varType) const {
63 if (auto mappableTy = dyn_cast<MappableType>(varType)) {
64 return mappableTy.getTypeCategory(varPtr);
65 }
66 auto memrefTy = cast<T>(pointer);
67 if (!memrefTy.hasRank()) {
68 // This memref is unranked - aka it could have any rank, including a
69 // rank of 0 which could mean scalar. For now, return uncategorized.
70 return mlir::acc::VariableTypeCategory::uncategorized;
71 }
72
73 if (memrefTy.getRank() == 0) {
74 if (isScalarLikeType(memrefTy.getElementType())) {
75 return mlir::acc::VariableTypeCategory::scalar;
76 }
77 // Zero-rank non-scalar - need further analysis to determine the type
78 // category. For now, return uncategorized.
79 return mlir::acc::VariableTypeCategory::uncategorized;
80 }
81
82 // It has a rank - must be an array.
83 assert(memrefTy.getRank() > 0 && "rank expected to be positive");
84 return mlir::acc::VariableTypeCategory::array;
85 }
86
87 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
88 StringRef varName, Type varType, Value originalVar,
89 bool &needsFree) const {
90 auto memrefTy = cast<MemRefType>(pointer);
91
92 // Check if this is a static memref (all dimensions are known) - if yes
93 // then we can generate an alloca operation.
94 if (memrefTy.hasStaticShape()) {
95 needsFree = false; // alloca doesn't need deallocation
96 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
97 attachVarNameAttr(allocaOp, builder, varName);
98 return allocaOp.getResult();
99 }
100
101 // For dynamic memrefs, extract sizes from the original variable if
102 // provided. Otherwise they cannot be handled.
103 if (originalVar && originalVar.getType() == memrefTy &&
104 memrefTy.hasRank()) {
105 SmallVector<Value> dynamicSizes;
106 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
107 if (memrefTy.isDynamicDim(i)) {
108 // Extract the size of dimension i from the original variable
109 auto indexValue = arith::ConstantIndexOp::create(builder, loc, i);
110 auto dimSize =
111 memref::DimOp::create(builder, loc, originalVar, indexValue);
112 dynamicSizes.push_back(dimSize);
113 }
114 // Note: We only add dynamic sizes to the dynamicSizes array
115 // Static dimensions are handled automatically by AllocOp
116 }
117 needsFree = true; // alloc needs deallocation
118 auto allocOp =
119 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
120 attachVarNameAttr(allocOp, builder, varName);
121 return allocOp.getResult();
122 }
123
124 // TODO: Unranked not yet supported.
125 return {};
126 }
127
128 bool genFree(Type pointer, OpBuilder &builder, Location loc,
129 TypedValue<PointerLikeType> varToFree, Value allocRes,
130 Type varType) const {
131 if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
132 // Use allocRes if provided to determine the allocation type
133 Value valueToInspect = allocRes ? allocRes : memrefValue;
134
135 // Walk through casts to find the original allocation
136 Value currentValue = valueToInspect;
137 Operation *originalAlloc = nullptr;
138
139 // Follow the chain of operations to find the original allocation
140 // even if a casted result is provided.
141 while (currentValue) {
142 if (auto *definingOp = currentValue.getDefiningOp()) {
143 // Check if this is an allocation operation
144 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
145 originalAlloc = definingOp;
146 break;
147 }
148
149 // Check if this is a cast operation we can look through
150 if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
151 currentValue = castOp.getSource();
152 continue;
153 }
154
155 // Check for other cast-like operations
156 if (auto reinterpretCastOp =
157 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
158 currentValue = reinterpretCastOp.getSource();
159 continue;
160 }
161
162 // If we can't look through this operation, stop
163 break;
164 }
165 // This is a block argument or similar - can't trace further.
166 break;
167 }
168
169 if (originalAlloc) {
170 if (isa<memref::AllocaOp>(originalAlloc)) {
171 // This is an alloca - no dealloc needed, but return true (success)
172 return true;
173 }
174 if (isa<memref::AllocOp>(originalAlloc)) {
175 // This is an alloc - generate dealloc on varToFree
176 memref::DeallocOp::create(builder, loc, memrefValue);
177 return true;
178 }
179 }
180 }
181
182 return false;
183 }
184
185 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
186 TypedValue<PointerLikeType> destination,
187 TypedValue<PointerLikeType> source, Type varType) const {
188 // Generate a copy operation between two memrefs
189 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
190 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
191
192 // As per memref documentation, source and destination must have same
193 // element type and shape in order to be compatible. We do not want to fail
194 // with an IR verification error - thus check that before generating the
195 // copy operation.
196 if (destMemref && srcMemref &&
197 destMemref.getType().getElementType() ==
198 srcMemref.getType().getElementType() &&
199 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
200 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
201 return true;
202 }
203
204 return false;
205 }
206};
207
208struct LLVMPointerPointerLikeModel
209 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
210 LLVM::LLVMPointerType> {
211 Type getElementType(Type pointer) const { return Type(); }
212};
213
214/// Helper function for any of the times we need to modify an ArrayAttr based on
215/// a device type list. Returns a new ArrayAttr with all of the
216/// existingDeviceTypes, plus the effective new ones(or an added none if hte new
217/// list is empty).
218mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
219 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
220 llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
222 if (existingDeviceTypes)
223 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
224
225 if (newDeviceTypes.empty())
226 deviceTypes.push_back(
227 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
228
229 for (DeviceType dt : newDeviceTypes)
230 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
231
232 return mlir::ArrayAttr::get(context, deviceTypes);
233}
234
235/// Helper function for any of the times we need to add operands that are
236/// affected by a device type list. Returns a new ArrayAttr with all of the
237/// existingDeviceTypes, plus the effective new ones (or an added none, if the
238/// new list is empty). Additionally, adds the arguments to the argCollection
239/// the correct number of times. This will also update a 'segments' array, even
240/// if it won't be used.
241mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
242 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
243 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
244 mlir::MutableOperandRange argCollection,
245 llvm::SmallVector<int32_t> &segments) {
247 if (existingDeviceTypes)
248 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
249
250 if (newDeviceTypes.empty()) {
251 argCollection.append(arguments);
252 segments.push_back(arguments.size());
253 deviceTypes.push_back(
254 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
255 }
256
257 for (DeviceType dt : newDeviceTypes) {
258 argCollection.append(arguments);
259 segments.push_back(arguments.size());
260 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
261 }
262
263 return mlir::ArrayAttr::get(context, deviceTypes);
264}
265
266/// Overload for when the 'segments' aren't needed.
267mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
268 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
269 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
270 mlir::MutableOperandRange argCollection) {
272 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
273 newDeviceTypes, arguments,
274 argCollection, segments);
275}
276} // namespace
277
278//===----------------------------------------------------------------------===//
279// OpenACC operations
280//===----------------------------------------------------------------------===//
281
282void OpenACCDialect::initialize() {
283 addOperations<
284#define GET_OP_LIST
285#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
286 >();
287 addAttributes<
288#define GET_ATTRDEF_LIST
289#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
290 >();
291 addTypes<
292#define GET_TYPEDEF_LIST
293#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
294 >();
295
296 // By attaching interfaces here, we make the OpenACC dialect dependent on
297 // the other dialects. This is probably better than having dialects like LLVM
298 // and memref be dependent on OpenACC.
299 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
300 *getContext());
301 UnrankedMemRefType::attachInterface<
302 MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
303 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
304 *getContext());
305}
306
307//===----------------------------------------------------------------------===//
308// device_type support helpers
309//===----------------------------------------------------------------------===//
310
311static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
312 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
313}
314
315static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
316 mlir::acc::DeviceType deviceType) {
317 if (!hasDeviceTypeValues(arrayAttr))
318 return false;
319
320 for (auto attr : *arrayAttr) {
321 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
322 if (deviceTypeAttr.getValue() == deviceType)
323 return true;
324 }
325
326 return false;
327}
328
330 std::optional<mlir::ArrayAttr> deviceTypes) {
331 if (!hasDeviceTypeValues(deviceTypes))
332 return;
333
334 p << "[";
335 llvm::interleaveComma(*deviceTypes, p,
336 [&](mlir::Attribute attr) { p << attr; });
337 p << "]";
338}
339
340static std::optional<unsigned> findSegment(ArrayAttr segments,
341 mlir::acc::DeviceType deviceType) {
342 unsigned segmentIdx = 0;
343 for (auto attr : segments) {
344 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
345 if (deviceTypeAttr.getValue() == deviceType)
346 return std::make_optional(segmentIdx);
347 ++segmentIdx;
348 }
349 return std::nullopt;
350}
351
353getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
355 std::optional<llvm::ArrayRef<int32_t>> segments,
356 mlir::acc::DeviceType deviceType) {
357 if (!arrayAttr)
358 return range.take_front(0);
359 if (auto pos = findSegment(*arrayAttr, deviceType)) {
360 int32_t nbOperandsBefore = 0;
361 for (unsigned i = 0; i < *pos; ++i)
362 nbOperandsBefore += (*segments)[i];
363 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
364 }
365 return range.take_front(0);
366}
367
368static mlir::Value
369getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
371 std::optional<llvm::ArrayRef<int32_t>> segments,
372 std::optional<mlir::ArrayAttr> hasWaitDevnum,
373 mlir::acc::DeviceType deviceType) {
374 if (!hasDeviceTypeValues(deviceTypeAttr))
375 return {};
376 if (auto pos = findSegment(*deviceTypeAttr, deviceType))
377 if (hasWaitDevnum->getValue()[*pos])
378 return getValuesFromSegments(deviceTypeAttr, operands, segments,
379 deviceType)
380 .front();
381 return {};
382}
383
385getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
387 std::optional<llvm::ArrayRef<int32_t>> segments,
388 std::optional<mlir::ArrayAttr> hasWaitDevnum,
389 mlir::acc::DeviceType deviceType) {
390 auto range =
391 getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
392 if (range.empty())
393 return range;
394 if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
395 if (hasWaitDevnum && *hasWaitDevnum) {
396 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
397 if (boolAttr.getValue())
398 return range.drop_front(1); // first value is devnum
399 }
400 }
401 return range;
402}
403
404template <typename Op>
405static LogicalResult checkWaitAndAsyncConflict(Op op) {
406 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
407 ++dtypeInt) {
408 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
409
410 // The asyncOnly attribute represent the async clause without value.
411 // Therefore the attribute and operand cannot appear at the same time.
412 if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
413 op.hasAsyncOnly(dtype))
414 return op.emitError(
415 "asyncOnly attribute cannot appear with asyncOperand");
416
417 // The wait attribute represent the wait clause without values. Therefore
418 // the attribute and operands cannot appear at the same time.
419 if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
420 op.hasWaitOnly(dtype))
421 return op.emitError("wait attribute cannot appear with waitOperands");
422 }
423 return success();
424}
425
426template <typename Op>
427static LogicalResult checkVarAndVarType(Op op) {
428 if (!op.getVar())
429 return op.emitError("must have var operand");
430
431 // A variable must have a type that is either pointer-like or mappable.
432 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
433 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
434 return op.emitError("var must be mappable or pointer-like");
435
436 // When it is a pointer-like type, the varType must capture the target type.
437 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
438 op.getVarType() == op.getVar().getType())
439 return op.emitError("varType must capture the element type of var");
440
441 return success();
442}
443
444template <typename Op>
445static LogicalResult checkVarAndAccVar(Op op) {
446 if (op.getVar().getType() != op.getAccVar().getType())
447 return op.emitError("input and output types must match");
448
449 return success();
450}
451
452template <typename Op>
453static LogicalResult checkNoModifier(Op op) {
454 if (op.getModifiers() != acc::DataClauseModifier::none)
455 return op.emitError("no data clause modifiers are allowed");
456 return success();
457}
458
459template <typename Op>
460static LogicalResult
461checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
462 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
463 return op.emitError(
464 "invalid data clause modifiers: " +
465 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
466
467 return success();
468}
469
470static ParseResult parseVar(mlir::OpAsmParser &parser,
472 // Either `var` or `varPtr` keyword is required.
473 if (failed(parser.parseOptionalKeyword("varPtr"))) {
474 if (failed(parser.parseKeyword("var")))
475 return failure();
476 }
477 if (failed(parser.parseLParen()))
478 return failure();
479 if (failed(parser.parseOperand(var)))
480 return failure();
481
482 return success();
483}
484
486 mlir::Value var) {
487 if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
488 p << "varPtr(";
489 else
490 p << "var(";
491 p.printOperand(var);
492}
493
494static ParseResult parseAccVar(mlir::OpAsmParser &parser,
496 mlir::Type &accVarType) {
497 // Either `accVar` or `accPtr` keyword is required.
498 if (failed(parser.parseOptionalKeyword("accPtr"))) {
499 if (failed(parser.parseKeyword("accVar")))
500 return failure();
501 }
502 if (failed(parser.parseLParen()))
503 return failure();
504 if (failed(parser.parseOperand(var)))
505 return failure();
506 if (failed(parser.parseColon()))
507 return failure();
508 if (failed(parser.parseType(accVarType)))
509 return failure();
510 if (failed(parser.parseRParen()))
511 return failure();
512
513 return success();
514}
515
517 mlir::Value accVar, mlir::Type accVarType) {
518 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
519 p << "accPtr(";
520 else
521 p << "accVar(";
522 p.printOperand(accVar);
523 p << " : ";
524 p.printType(accVarType);
525 p << ")";
526}
527
528static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
529 mlir::Type &varPtrType,
530 mlir::TypeAttr &varTypeAttr) {
531 if (failed(parser.parseType(varPtrType)))
532 return failure();
533 if (failed(parser.parseRParen()))
534 return failure();
535
536 if (succeeded(parser.parseOptionalKeyword("varType"))) {
537 if (failed(parser.parseLParen()))
538 return failure();
539 mlir::Type varType;
540 if (failed(parser.parseType(varType)))
541 return failure();
542 varTypeAttr = mlir::TypeAttr::get(varType);
543 if (failed(parser.parseRParen()))
544 return failure();
545 } else {
546 // Set `varType` from the element type of the type of `varPtr`.
547 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
548 varTypeAttr = mlir::TypeAttr::get(
549 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
550 else
551 varTypeAttr = mlir::TypeAttr::get(varPtrType);
552 }
553
554 return success();
555}
556
558 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
559 p.printType(varPtrType);
560 p << ")";
561
562 // Print the `varType` only if it differs from the element type of
563 // `varPtr`'s type.
564 mlir::Type varType = varTypeAttr.getValue();
565 mlir::Type typeToCheckAgainst =
566 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
567 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
568 : varPtrType;
569 if (typeToCheckAgainst != varType) {
570 p << " varType(";
571 p.printType(varType);
572 p << ")";
573 }
574}
575
576//===----------------------------------------------------------------------===//
577// DataBoundsOp
578//===----------------------------------------------------------------------===//
579LogicalResult acc::DataBoundsOp::verify() {
580 auto extent = getExtent();
581 auto upperbound = getUpperbound();
582 if (!extent && !upperbound)
583 return emitError("expected extent or upperbound.");
584 return success();
585}
586
587//===----------------------------------------------------------------------===//
588// PrivateOp
589//===----------------------------------------------------------------------===//
590LogicalResult acc::PrivateOp::verify() {
591 if (getDataClause() != acc::DataClause::acc_private)
592 return emitError(
593 "data clause associated with private operation must match its intent");
594 if (failed(checkVarAndVarType(*this)))
595 return failure();
596 if (failed(checkNoModifier(*this)))
597 return failure();
598 return success();
599}
600
601//===----------------------------------------------------------------------===//
602// FirstprivateOp
603//===----------------------------------------------------------------------===//
604LogicalResult acc::FirstprivateOp::verify() {
605 if (getDataClause() != acc::DataClause::acc_firstprivate)
606 return emitError("data clause associated with firstprivate operation must "
607 "match its intent");
608 if (failed(checkVarAndVarType(*this)))
609 return failure();
610 if (failed(checkNoModifier(*this)))
611 return failure();
612 return success();
613}
614
615//===----------------------------------------------------------------------===//
616// FirstprivateMapInitialOp
617//===----------------------------------------------------------------------===//
618LogicalResult acc::FirstprivateMapInitialOp::verify() {
619 if (getDataClause() != acc::DataClause::acc_firstprivate)
620 return emitError("data clause associated with firstprivate operation must "
621 "match its intent");
622 if (failed(checkVarAndVarType(*this)))
623 return failure();
624 if (failed(checkNoModifier(*this)))
625 return failure();
626 return success();
627}
628
629//===----------------------------------------------------------------------===//
630// ReductionOp
631//===----------------------------------------------------------------------===//
632LogicalResult acc::ReductionOp::verify() {
633 if (getDataClause() != acc::DataClause::acc_reduction)
634 return emitError("data clause associated with reduction operation must "
635 "match its intent");
636 if (failed(checkVarAndVarType(*this)))
637 return failure();
638 if (failed(checkNoModifier(*this)))
639 return failure();
640 return success();
641}
642
643//===----------------------------------------------------------------------===//
644// DevicePtrOp
645//===----------------------------------------------------------------------===//
646LogicalResult acc::DevicePtrOp::verify() {
647 if (getDataClause() != acc::DataClause::acc_deviceptr)
648 return emitError("data clause associated with deviceptr operation must "
649 "match its intent");
650 if (failed(checkVarAndVarType(*this)))
651 return failure();
652 if (failed(checkVarAndAccVar(*this)))
653 return failure();
654 if (failed(checkNoModifier(*this)))
655 return failure();
656 return success();
657}
658
659//===----------------------------------------------------------------------===//
660// PresentOp
661//===----------------------------------------------------------------------===//
662LogicalResult acc::PresentOp::verify() {
663 if (getDataClause() != acc::DataClause::acc_present)
664 return emitError(
665 "data clause associated with present operation must match its intent");
666 if (failed(checkVarAndVarType(*this)))
667 return failure();
668 if (failed(checkVarAndAccVar(*this)))
669 return failure();
670 if (failed(checkNoModifier(*this)))
671 return failure();
672 return success();
673}
674
675//===----------------------------------------------------------------------===//
676// CopyinOp
677//===----------------------------------------------------------------------===//
678LogicalResult acc::CopyinOp::verify() {
679 // Test for all clauses this operation can be decomposed from:
680 if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
681 getDataClause() != acc::DataClause::acc_copyin_readonly &&
682 getDataClause() != acc::DataClause::acc_copy &&
683 getDataClause() != acc::DataClause::acc_reduction)
684 return emitError(
685 "data clause associated with copyin operation must match its intent"
686 " or specify original clause this operation was decomposed from");
687 if (failed(checkVarAndVarType(*this)))
688 return failure();
689 if (failed(checkVarAndAccVar(*this)))
690 return failure();
691 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly |
692 acc::DataClauseModifier::always |
693 acc::DataClauseModifier::capture)))
694 return failure();
695 return success();
696}
697
698bool acc::CopyinOp::isCopyinReadonly() {
699 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
700 acc::bitEnumContainsAny(getModifiers(),
701 acc::DataClauseModifier::readonly);
702}
703
704//===----------------------------------------------------------------------===//
705// CreateOp
706//===----------------------------------------------------------------------===//
707LogicalResult acc::CreateOp::verify() {
708 // Test for all clauses this operation can be decomposed from:
709 if (getDataClause() != acc::DataClause::acc_create &&
710 getDataClause() != acc::DataClause::acc_create_zero &&
711 getDataClause() != acc::DataClause::acc_copyout &&
712 getDataClause() != acc::DataClause::acc_copyout_zero)
713 return emitError(
714 "data clause associated with create operation must match its intent"
715 " or specify original clause this operation was decomposed from");
716 if (failed(checkVarAndVarType(*this)))
717 return failure();
718 if (failed(checkVarAndAccVar(*this)))
719 return failure();
720 // this op is the entry part of copyout, so it also needs to allow all
721 // modifiers allowed on copyout.
722 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
723 acc::DataClauseModifier::always |
724 acc::DataClauseModifier::capture)))
725 return failure();
726 return success();
727}
728
729bool acc::CreateOp::isCreateZero() {
730 // The zero modifier is encoded in the data clause.
731 return getDataClause() == acc::DataClause::acc_create_zero ||
732 getDataClause() == acc::DataClause::acc_copyout_zero ||
733 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
734}
735
736//===----------------------------------------------------------------------===//
737// NoCreateOp
738//===----------------------------------------------------------------------===//
739LogicalResult acc::NoCreateOp::verify() {
740 if (getDataClause() != acc::DataClause::acc_no_create)
741 return emitError("data clause associated with no_create operation must "
742 "match its intent");
743 if (failed(checkVarAndVarType(*this)))
744 return failure();
745 if (failed(checkVarAndAccVar(*this)))
746 return failure();
747 if (failed(checkNoModifier(*this)))
748 return failure();
749 return success();
750}
751
752//===----------------------------------------------------------------------===//
753// AttachOp
754//===----------------------------------------------------------------------===//
755LogicalResult acc::AttachOp::verify() {
756 if (getDataClause() != acc::DataClause::acc_attach)
757 return emitError(
758 "data clause associated with attach operation must match its intent");
759 if (failed(checkVarAndVarType(*this)))
760 return failure();
761 if (failed(checkVarAndAccVar(*this)))
762 return failure();
763 if (failed(checkNoModifier(*this)))
764 return failure();
765 return success();
766}
767
768//===----------------------------------------------------------------------===//
769// DeclareDeviceResidentOp
770//===----------------------------------------------------------------------===//
771
772LogicalResult acc::DeclareDeviceResidentOp::verify() {
773 if (getDataClause() != acc::DataClause::acc_declare_device_resident)
774 return emitError("data clause associated with device_resident operation "
775 "must match its intent");
776 if (failed(checkVarAndVarType(*this)))
777 return failure();
778 if (failed(checkVarAndAccVar(*this)))
779 return failure();
780 if (failed(checkNoModifier(*this)))
781 return failure();
782 return success();
783}
784
785//===----------------------------------------------------------------------===//
786// DeclareLinkOp
787//===----------------------------------------------------------------------===//
788
789LogicalResult acc::DeclareLinkOp::verify() {
790 if (getDataClause() != acc::DataClause::acc_declare_link)
791 return emitError(
792 "data clause associated with link operation must match its intent");
793 if (failed(checkVarAndVarType(*this)))
794 return failure();
795 if (failed(checkVarAndAccVar(*this)))
796 return failure();
797 if (failed(checkNoModifier(*this)))
798 return failure();
799 return success();
800}
801
802//===----------------------------------------------------------------------===//
803// CopyoutOp
804//===----------------------------------------------------------------------===//
805LogicalResult acc::CopyoutOp::verify() {
806 // Test for all clauses this operation can be decomposed from:
807 if (getDataClause() != acc::DataClause::acc_copyout &&
808 getDataClause() != acc::DataClause::acc_copyout_zero &&
809 getDataClause() != acc::DataClause::acc_copy &&
810 getDataClause() != acc::DataClause::acc_reduction)
811 return emitError(
812 "data clause associated with copyout operation must match its intent"
813 " or specify original clause this operation was decomposed from");
814 if (!getVar() || !getAccVar())
815 return emitError("must have both host and device pointers");
816 if (failed(checkVarAndVarType(*this)))
817 return failure();
818 if (failed(checkVarAndAccVar(*this)))
819 return failure();
820 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
821 acc::DataClauseModifier::always |
822 acc::DataClauseModifier::capture)))
823 return failure();
824 return success();
825}
826
827bool acc::CopyoutOp::isCopyoutZero() {
828 return getDataClause() == acc::DataClause::acc_copyout_zero ||
829 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
830}
831
832//===----------------------------------------------------------------------===//
833// DeleteOp
834//===----------------------------------------------------------------------===//
835LogicalResult acc::DeleteOp::verify() {
836 // Test for all clauses this operation can be decomposed from:
837 if (getDataClause() != acc::DataClause::acc_delete &&
838 getDataClause() != acc::DataClause::acc_create &&
839 getDataClause() != acc::DataClause::acc_create_zero &&
840 getDataClause() != acc::DataClause::acc_copyin &&
841 getDataClause() != acc::DataClause::acc_copyin_readonly &&
842 getDataClause() != acc::DataClause::acc_present &&
843 getDataClause() != acc::DataClause::acc_no_create &&
844 getDataClause() != acc::DataClause::acc_declare_device_resident &&
845 getDataClause() != acc::DataClause::acc_declare_link)
846 return emitError(
847 "data clause associated with delete operation must match its intent"
848 " or specify original clause this operation was decomposed from");
849 if (!getAccVar())
850 return emitError("must have device pointer");
851 // This op is the exit part of copyin and create - thus allow all modifiers
852 // allowed on either case.
853 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
854 acc::DataClauseModifier::readonly |
855 acc::DataClauseModifier::always |
856 acc::DataClauseModifier::capture)))
857 return failure();
858 return success();
859}
860
861//===----------------------------------------------------------------------===//
862// DetachOp
863//===----------------------------------------------------------------------===//
864LogicalResult acc::DetachOp::verify() {
865 // Test for all clauses this operation can be decomposed from:
866 if (getDataClause() != acc::DataClause::acc_detach &&
867 getDataClause() != acc::DataClause::acc_attach)
868 return emitError(
869 "data clause associated with detach operation must match its intent"
870 " or specify original clause this operation was decomposed from");
871 if (!getAccVar())
872 return emitError("must have device pointer");
873 if (failed(checkNoModifier(*this)))
874 return failure();
875 return success();
876}
877
878//===----------------------------------------------------------------------===//
// UpdateHostOp
880//===----------------------------------------------------------------------===//
881LogicalResult acc::UpdateHostOp::verify() {
882 // Test for all clauses this operation can be decomposed from:
883 if (getDataClause() != acc::DataClause::acc_update_host &&
884 getDataClause() != acc::DataClause::acc_update_self)
885 return emitError(
886 "data clause associated with host operation must match its intent"
887 " or specify original clause this operation was decomposed from");
888 if (!getVar() || !getAccVar())
889 return emitError("must have both host and device pointers");
890 if (failed(checkVarAndVarType(*this)))
891 return failure();
892 if (failed(checkVarAndAccVar(*this)))
893 return failure();
894 if (failed(checkNoModifier(*this)))
895 return failure();
896 return success();
897}
898
899//===----------------------------------------------------------------------===//
// UpdateDeviceOp
901//===----------------------------------------------------------------------===//
902LogicalResult acc::UpdateDeviceOp::verify() {
903 // Test for all clauses this operation can be decomposed from:
904 if (getDataClause() != acc::DataClause::acc_update_device)
905 return emitError(
906 "data clause associated with device operation must match its intent"
907 " or specify original clause this operation was decomposed from");
908 if (failed(checkVarAndVarType(*this)))
909 return failure();
910 if (failed(checkVarAndAccVar(*this)))
911 return failure();
912 if (failed(checkNoModifier(*this)))
913 return failure();
914 return success();
915}
916
917//===----------------------------------------------------------------------===//
918// UseDeviceOp
919//===----------------------------------------------------------------------===//
920LogicalResult acc::UseDeviceOp::verify() {
921 // Test for all clauses this operation can be decomposed from:
922 if (getDataClause() != acc::DataClause::acc_use_device)
923 return emitError(
924 "data clause associated with use_device operation must match its intent"
925 " or specify original clause this operation was decomposed from");
926 if (failed(checkVarAndVarType(*this)))
927 return failure();
928 if (failed(checkVarAndAccVar(*this)))
929 return failure();
930 if (failed(checkNoModifier(*this)))
931 return failure();
932 return success();
933}
934
935//===----------------------------------------------------------------------===//
936// CacheOp
937//===----------------------------------------------------------------------===//
938LogicalResult acc::CacheOp::verify() {
939 // Test for all clauses this operation can be decomposed from:
940 if (getDataClause() != acc::DataClause::acc_cache &&
941 getDataClause() != acc::DataClause::acc_cache_readonly)
942 return emitError(
943 "data clause associated with cache operation must match its intent"
944 " or specify original clause this operation was decomposed from");
945 if (failed(checkVarAndVarType(*this)))
946 return failure();
947 if (failed(checkVarAndAccVar(*this)))
948 return failure();
949 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly)))
950 return failure();
951 return success();
952}
953
954bool acc::CacheOp::isCacheReadonly() {
955 return getDataClause() == acc::DataClause::acc_cache_readonly ||
956 acc::bitEnumContainsAny(getModifiers(),
957 acc::DataClauseModifier::readonly);
958}
959
960template <typename StructureOp>
961static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
962 unsigned nRegions = 1) {
963
965 for (unsigned i = 0; i < nRegions; ++i)
966 regions.push_back(state.addRegion());
967
968 for (Region *region : regions)
969 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
970 return failure();
971
972 return success();
973}
974
976 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
977}
978
979namespace {
980/// Pattern to remove operation without region that have constant false `ifCond`
981/// and remove the condition from the operation if the `ifCond` is a true
982/// constant.
983template <typename OpTy>
984struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
985 using OpRewritePattern<OpTy>::OpRewritePattern;
986
987 LogicalResult matchAndRewrite(OpTy op,
988 PatternRewriter &rewriter) const override {
989 // Early return if there is no condition.
990 Value ifCond = op.getIfCond();
991 if (!ifCond)
992 return failure();
993
994 IntegerAttr constAttr;
995 if (!matchPattern(ifCond, m_Constant(&constAttr)))
996 return failure();
997 if (constAttr.getInt())
998 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
999 else
1000 rewriter.eraseOp(op);
1001
1002 return success();
1003 }
1004};
1005
1006/// Replaces the given op with the contents of the given single-block region,
1007/// using the operands of the block terminator to replace operation results.
1008static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
1009 Region &region, ValueRange blockArgs = {}) {
1010 assert(region.hasOneBlock() && "expected single-block region");
1011 Block *block = &region.front();
1012 Operation *terminator = block->getTerminator();
1013 ValueRange results = terminator->getOperands();
1014 rewriter.inlineBlockBefore(block, op, blockArgs);
1015 rewriter.replaceOp(op, results);
1016 rewriter.eraseOp(terminator);
1017}
1018
1019/// Pattern to remove operation with region that have constant false `ifCond`
1020/// and remove the condition from the operation if the `ifCond` is constant
1021/// true.
1022template <typename OpTy>
1023struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
1024 using OpRewritePattern<OpTy>::OpRewritePattern;
1025
1026 LogicalResult matchAndRewrite(OpTy op,
1027 PatternRewriter &rewriter) const override {
1028 // Early return if there is no condition.
1029 Value ifCond = op.getIfCond();
1030 if (!ifCond)
1031 return failure();
1032
1033 IntegerAttr constAttr;
1034 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1035 return failure();
1036 if (constAttr.getInt())
1037 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1038 else
1039 replaceOpWithRegion(rewriter, op, op.getRegion());
1040
1041 return success();
1042 }
1043};
1044
1045/// Remove empty acc.kernel_environment operations. If the operation has wait
1046/// operands, create a acc.wait operation to preserve synchronization.
1047struct RemoveEmptyKernelEnvironment
1048 : public OpRewritePattern<acc::KernelEnvironmentOp> {
1049 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
1050
1051 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
1052 PatternRewriter &rewriter) const override {
1053 assert(op->getNumRegions() == 1 && "expected op to have one region");
1054
1055 Block &block = op.getRegion().front();
1056 if (!block.empty())
1057 return failure();
1058
1059 // Conservatively disable canonicalization of empty acc.kernel_environment
1060 // operations if the wait operands in the kernel_environment cannot be fully
1061 // represented by acc.wait operation.
1062
1063 // Disable canonicalization if device type is not the default
1064 if (auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
1065 for (auto attr : deviceTypeAttr) {
1066 if (auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
1067 if (dtAttr.getValue() != mlir::acc::DeviceType::None)
1068 return failure();
1069 }
1070 }
1071 }
1072
1073 // Disable canonicalization if any wait segment has a devnum
1074 if (auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
1075 for (auto attr : hasDevnumAttr) {
1076 if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
1077 if (boolAttr.getValue())
1078 return failure();
1079 }
1080 }
1081 }
1082
1083 // Disable canonicalization if there are multiple wait segments
1084 if (auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
1085 if (segmentsAttr.size() > 1)
1086 return failure();
1087 }
1088
1089 // Remove empty kernel environment.
1090 // Preserve synchronization by creating acc.wait operation if needed.
1091 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
1092 rewriter.replaceOpWithNewOp<acc::WaitOp>(op, op.getWaitOperands(),
1093 /*asyncOperand=*/Value(),
1094 /*waitDevnum=*/Value(),
1095 /*async=*/nullptr,
1096 /*ifCond=*/Value());
1097 else
1098 rewriter.eraseOp(op);
1099
1100 return success();
1101 }
1102};
1103
1104//===----------------------------------------------------------------------===//
1105// Recipe Region Helpers
1106//===----------------------------------------------------------------------===//
1107
1108/// Create and populate an init region for privatization recipes.
1109/// Returns success if the region is populated, failure otherwise.
1110/// Sets needsFree to indicate if the allocated memory requires deallocation.
1111static LogicalResult createInitRegion(OpBuilder &builder, Location loc,
1112 Region &initRegion, Type varType,
1113 StringRef varName, ValueRange bounds,
1114 bool &needsFree) {
1115 // Create init block with arguments: original value + bounds
1116 SmallVector<Type> argTypes{varType};
1117 SmallVector<Location> argLocs{loc};
1118 for (Value bound : bounds) {
1119 argTypes.push_back(bound.getType());
1120 argLocs.push_back(loc);
1121 }
1122
1123 Block *initBlock = builder.createBlock(&initRegion);
1124 initBlock->addArguments(argTypes, argLocs);
1125 builder.setInsertionPointToStart(initBlock);
1126
1127 Value privatizedValue;
1128
1129 // Get the block argument that represents the original variable
1130 Value blockArgVar = initBlock->getArgument(0);
1131
1132 // Generate init region body based on variable type
1133 if (isa<MappableType>(varType)) {
1134 auto mappableTy = cast<MappableType>(varType);
1135 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1136 privatizedValue = mappableTy.generatePrivateInit(
1137 builder, loc, typedVar, varName, bounds, {}, needsFree);
1138 if (!privatizedValue)
1139 return failure();
1140 } else {
1141 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1142 auto pointerLikeTy = cast<PointerLikeType>(varType);
1143 // Use PointerLikeType's allocation API with the block argument
1144 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1145 blockArgVar, needsFree);
1146 if (!privatizedValue)
1147 return failure();
1148 }
1149
1150 // Add yield operation to init block
1151 acc::YieldOp::create(builder, loc, privatizedValue);
1152
1153 return success();
1154}
1155
1156/// Create and populate a copy region for firstprivate recipes.
1157/// Returns success if the region is populated, failure otherwise.
1158/// TODO: Handle MappableType - it does not yet have a copy API.
1159static LogicalResult createCopyRegion(OpBuilder &builder, Location loc,
1160 Region &copyRegion, Type varType,
1161 ValueRange bounds) {
1162 // Create copy block with arguments: original value + privatized value +
1163 // bounds
1164 SmallVector<Type> copyArgTypes{varType, varType};
1165 SmallVector<Location> copyArgLocs{loc, loc};
1166 for (Value bound : bounds) {
1167 copyArgTypes.push_back(bound.getType());
1168 copyArgLocs.push_back(loc);
1169 }
1170
1171 Block *copyBlock = builder.createBlock(&copyRegion);
1172 copyBlock->addArguments(copyArgTypes, copyArgLocs);
1173 builder.setInsertionPointToStart(copyBlock);
1174
1175 bool isMappable = isa<MappableType>(varType);
1176 bool isPointerLike = isa<PointerLikeType>(varType);
1177 // TODO: Handle MappableType - it does not yet have a copy API.
1178 // Otherwise, for now just fallback to pointer-like behavior.
1179 if (isMappable && !isPointerLike)
1180 return failure();
1181
1182 // Generate copy region body based on variable type
1183 if (isPointerLike) {
1184 auto pointerLikeTy = cast<PointerLikeType>(varType);
1185 Value originalArg = copyBlock->getArgument(0);
1186 Value privatizedArg = copyBlock->getArgument(1);
1187
1188 // Generate copy operation using PointerLikeType interface
1189 if (!pointerLikeTy.genCopy(
1190 builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
1191 cast<TypedValue<PointerLikeType>>(originalArg), varType))
1192 return failure();
1193 }
1194
1195 // Add terminator to copy block
1196 acc::TerminatorOp::create(builder, loc);
1197
1198 return success();
1199}
1200
1201/// Create and populate a destroy region for privatization recipes.
1202/// Returns success if the region is populated, failure otherwise.
1203static LogicalResult createDestroyRegion(OpBuilder &builder, Location loc,
1204 Region &destroyRegion, Type varType,
1205 Value allocRes, ValueRange bounds) {
1206 // Create destroy block with arguments: original value + privatized value +
1207 // bounds
1208 SmallVector<Type> destroyArgTypes{varType, varType};
1209 SmallVector<Location> destroyArgLocs{loc, loc};
1210 for (Value bound : bounds) {
1211 destroyArgTypes.push_back(bound.getType());
1212 destroyArgLocs.push_back(loc);
1213 }
1214
1215 Block *destroyBlock = builder.createBlock(&destroyRegion);
1216 destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
1217 builder.setInsertionPointToStart(destroyBlock);
1218
1219 auto varToFree =
1220 cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
1221 if (isa<MappableType>(varType)) {
1222 auto mappableTy = cast<MappableType>(varType);
1223 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree))
1224 return failure();
1225 } else {
1226 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1227 auto pointerLikeTy = cast<PointerLikeType>(varType);
1228 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1229 return failure();
1230 }
1231
1232 acc::TerminatorOp::create(builder, loc);
1233 return success();
1234}
1235
1236} // namespace
1237
1238//===----------------------------------------------------------------------===//
1239// PrivateRecipeOp
1240//===----------------------------------------------------------------------===//
1241
1243 Operation *op, Region &region, StringRef regionType, StringRef regionName,
1244 Type type, bool verifyYield, bool optional = false) {
1245 if (optional && region.empty())
1246 return success();
1247
1248 if (region.empty())
1249 return op->emitOpError() << "expects non-empty " << regionName << " region";
1250 Block &firstBlock = region.front();
1251 if (firstBlock.getNumArguments() < 1 ||
1252 firstBlock.getArgument(0).getType() != type)
1253 return op->emitOpError() << "expects " << regionName
1254 << " region first "
1255 "argument of the "
1256 << regionType << " type";
1257
1258 if (verifyYield) {
1259 for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
1260 if (yieldOp.getOperands().size() != 1 ||
1261 yieldOp.getOperands().getTypes()[0] != type)
1262 return op->emitOpError() << "expects " << regionName
1263 << " region to "
1264 "yield a value of the "
1265 << regionType << " type";
1266 }
1267 }
1268 return success();
1269}
1270
1271LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1272 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1273 "privatization", "init", getType(),
1274 /*verifyYield=*/false)))
1275 return failure();
1277 *this, getDestroyRegion(), "privatization", "destroy", getType(),
1278 /*verifyYield=*/false, /*optional=*/true)))
1279 return failure();
1280 return success();
1281}
1282
1283std::optional<PrivateRecipeOp>
1284PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1285 StringRef recipeName, Type varType,
1286 StringRef varName, ValueRange bounds) {
1287 // First, validate that we can handle this variable type
1288 bool isMappable = isa<MappableType>(varType);
1289 bool isPointerLike = isa<PointerLikeType>(varType);
1290
1291 // Unsupported type
1292 if (!isMappable && !isPointerLike)
1293 return std::nullopt;
1294
1295 OpBuilder::InsertionGuard guard(builder);
1296
1297 // Create the recipe operation first so regions have proper parent context
1298 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1299
1300 // Populate the init region
1301 bool needsFree = false;
1302 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1303 varName, bounds, needsFree))) {
1304 recipe.erase();
1305 return std::nullopt;
1306 }
1307
1308 // Only create destroy region if the allocation needs deallocation
1309 if (needsFree) {
1310 // Extract the allocated value from the init block's yield operation
1311 auto yieldOp =
1312 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1313 Value allocRes = yieldOp.getOperand(0);
1314
1315 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1316 varType, allocRes, bounds))) {
1317 recipe.erase();
1318 return std::nullopt;
1319 }
1320 }
1321
1322 return recipe;
1323}
1324
1325//===----------------------------------------------------------------------===//
1326// FirstprivateRecipeOp
1327//===----------------------------------------------------------------------===//
1328
1329LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1330 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1331 "privatization", "init", getType(),
1332 /*verifyYield=*/false)))
1333 return failure();
1334
1335 if (getCopyRegion().empty())
1336 return emitOpError() << "expects non-empty copy region";
1337
1338 Block &firstBlock = getCopyRegion().front();
1339 if (firstBlock.getNumArguments() < 2 ||
1340 firstBlock.getArgument(0).getType() != getType())
1341 return emitOpError() << "expects copy region with two arguments of the "
1342 "privatization type";
1343
1344 if (getDestroyRegion().empty())
1345 return success();
1346
1347 if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
1348 "privatization", "destroy",
1349 getType(), /*verifyYield=*/false)))
1350 return failure();
1351
1352 return success();
1353}
1354
1355std::optional<FirstprivateRecipeOp>
1356FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1357 StringRef recipeName, Type varType,
1358 StringRef varName, ValueRange bounds) {
1359 // First, validate that we can handle this variable type
1360 bool isMappable = isa<MappableType>(varType);
1361 bool isPointerLike = isa<PointerLikeType>(varType);
1362
1363 // Unsupported type
1364 if (!isMappable && !isPointerLike)
1365 return std::nullopt;
1366
1367 OpBuilder::InsertionGuard guard(builder);
1368
1369 // Create the recipe operation first so regions have proper parent context
1370 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1371
1372 // Populate the init region
1373 bool needsFree = false;
1374 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1375 varName, bounds, needsFree))) {
1376 recipe.erase();
1377 return std::nullopt;
1378 }
1379
1380 // Populate the copy region
1381 if (failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1382 bounds))) {
1383 recipe.erase();
1384 return std::nullopt;
1385 }
1386
1387 // Only create destroy region if the allocation needs deallocation
1388 if (needsFree) {
1389 // Extract the allocated value from the init block's yield operation
1390 auto yieldOp =
1391 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1392 Value allocRes = yieldOp.getOperand(0);
1393
1394 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1395 varType, allocRes, bounds))) {
1396 recipe.erase();
1397 return std::nullopt;
1398 }
1399 }
1400
1401 return recipe;
1402}
1403
1404//===----------------------------------------------------------------------===//
1405// ReductionRecipeOp
1406//===----------------------------------------------------------------------===//
1407
1408LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1409 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
1410 "init", getType(),
1411 /*verifyYield=*/false)))
1412 return failure();
1413
1414 if (getCombinerRegion().empty())
1415 return emitOpError() << "expects non-empty combiner region";
1416
1417 Block &reductionBlock = getCombinerRegion().front();
1418 if (reductionBlock.getNumArguments() < 2 ||
1419 reductionBlock.getArgument(0).getType() != getType() ||
1420 reductionBlock.getArgument(1).getType() != getType())
1421 return emitOpError() << "expects combiner region with the first two "
1422 << "arguments of the reduction type";
1423
1424 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1425 if (yieldOp.getOperands().size() != 1 ||
1426 yieldOp.getOperands().getTypes()[0] != getType())
1427 return emitOpError() << "expects combiner region to yield a value "
1428 "of the reduction type";
1429 }
1430
1431 return success();
1432}
1433
1434//===----------------------------------------------------------------------===//
1435// Custom parser and printer verifier for private clause
1436//===----------------------------------------------------------------------===//
1437
1438static ParseResult parseSymOperandList(
1439 mlir::OpAsmParser &parser,
1441 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
1443 if (failed(parser.parseCommaSeparatedList([&]() {
1444 if (parser.parseAttribute(attributes.emplace_back()) ||
1445 parser.parseArrow() ||
1446 parser.parseOperand(operands.emplace_back()) ||
1447 parser.parseColonType(types.emplace_back()))
1448 return failure();
1449 return success();
1450 })))
1451 return failure();
1452 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1453 attributes.end());
1454 symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
1455 return success();
1456}
1457
1459 mlir::OperandRange operands,
1460 mlir::TypeRange types,
1461 std::optional<mlir::ArrayAttr> attributes) {
1462 llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
1463 p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
1464 << std::get<1>(it).getType();
1465 });
1466}
1467
1468//===----------------------------------------------------------------------===//
1469// ParallelOp
1470//===----------------------------------------------------------------------===//
1471
1472/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
1473template <typename Op>
1474static LogicalResult checkDataOperands(Op op,
1475 const mlir::ValueRange &operands) {
1476 for (mlir::Value operand : operands)
1477 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1478 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1479 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1480 operand.getDefiningOp()))
1481 return op.emitError(
1482 "expect data entry/exit operation or acc.getdeviceptr "
1483 "as defining op");
1484 return success();
1485}
1486
1487template <typename Op>
1488static LogicalResult
1489checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
1490 mlir::OperandRange operands, llvm::StringRef operandName,
1491 llvm::StringRef symbolName, bool checkOperandType = true) {
1492 if (!operands.empty()) {
1493 if (!attributes || attributes->size() != operands.size())
1494 return op->emitOpError()
1495 << "expected as many " << symbolName << " symbol reference as "
1496 << operandName << " operands";
1497 } else {
1498 if (attributes)
1499 return op->emitOpError()
1500 << "unexpected " << symbolName << " symbol reference";
1501 return success();
1502 }
1503
1505 for (auto args : llvm::zip(operands, *attributes)) {
1506 mlir::Value operand = std::get<0>(args);
1507
1508 if (!set.insert(operand).second)
1509 return op->emitOpError()
1510 << operandName << " operand appears more than once";
1511
1512 mlir::Type varType = operand.getType();
1513 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1514 auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
1515 if (!decl)
1516 return op->emitOpError()
1517 << "expected symbol reference " << symbolRef << " to point to a "
1518 << operandName << " declaration";
1519
1520 if (checkOperandType && decl.getType() && decl.getType() != varType)
1521 return op->emitOpError() << "expected " << operandName << " (" << varType
1522 << ") to be the same type as " << operandName
1523 << " declaration (" << decl.getType() << ")";
1524 }
1525
1526 return success();
1527}
1528
1529unsigned ParallelOp::getNumDataOperands() {
1530 return getReductionOperands().size() + getPrivateOperands().size() +
1531 getFirstprivateOperands().size() + getDataClauseOperands().size();
1532}
1533
1534Value ParallelOp::getDataOperand(unsigned i) {
1535 unsigned numOptional = getAsyncOperands().size();
1536 numOptional += getNumGangs().size();
1537 numOptional += getNumWorkers().size();
1538 numOptional += getVectorLength().size();
1539 numOptional += getIfCond() ? 1 : 0;
1540 numOptional += getSelfCond() ? 1 : 0;
1541 return getOperand(getWaitOperands().size() + numOptional + i);
1542}
1543
1544template <typename Op>
1545static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
1546 ArrayAttr deviceTypes,
1547 llvm::StringRef keyword) {
1548 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1549 return op.emitOpError() << keyword << " operands count must match "
1550 << keyword << " device_type count";
1551 return success();
1552}
1553
1554template <typename Op>
1556 Op op, OperandRange operands, DenseI32ArrayAttr segments,
1557 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1558 std::size_t numOperandsInSegments = 0;
1559 std::size_t nbOfSegments = 0;
1560
1561 if (segments) {
1562 for (auto segCount : segments.asArrayRef()) {
1563 if (maxInSegment != 0 && segCount > maxInSegment)
1564 return op.emitOpError() << keyword << " expects a maximum of "
1565 << maxInSegment << " values per segment";
1566 numOperandsInSegments += segCount;
1567 ++nbOfSegments;
1568 }
1569 }
1570
1571 if ((numOperandsInSegments != operands.size()) ||
1572 (!deviceTypes && !operands.empty()))
1573 return op.emitOpError()
1574 << keyword << " operand count does not match count in segments";
1575 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1576 return op.emitOpError()
1577 << keyword << " segment count does not match device_type count";
1578 return success();
1579}
1580
1581LogicalResult acc::ParallelOp::verify() {
1583 *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
1584 "privatizations", /*checkOperandType=*/false)))
1585 return failure();
1587 *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
1588 "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
1589 return failure();
1591 *this, getReductionRecipes(), getReductionOperands(), "reduction",
1592 "reductions", false)))
1593 return failure();
1594
1596 *this, getNumGangs(), getNumGangsSegmentsAttr(),
1597 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1598 return failure();
1599
1601 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1602 getWaitOperandsDeviceTypeAttr(), "wait")))
1603 return failure();
1604
1605 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1606 getNumWorkersDeviceTypeAttr(),
1607 "num_workers")))
1608 return failure();
1609
1610 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1611 getVectorLengthDeviceTypeAttr(),
1612 "vector_length")))
1613 return failure();
1614
1616 getAsyncOperandsDeviceTypeAttr(),
1617 "async")))
1618 return failure();
1619
1621 return failure();
1622
1623 return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
1624}
1625
1626static mlir::Value
1627getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
1629 mlir::acc::DeviceType deviceType) {
1630 if (!arrayAttr)
1631 return {};
1632 if (auto pos = findSegment(*arrayAttr, deviceType))
1633 return range[*pos];
1634 return {};
1635}
1636
1637bool acc::ParallelOp::hasAsyncOnly() {
1638 return hasAsyncOnly(mlir::acc::DeviceType::None);
1639}
1640
1641bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1642 return hasDeviceType(getAsyncOnly(), deviceType);
1643}
1644
1645mlir::Value acc::ParallelOp::getAsyncValue() {
1646 return getAsyncValue(mlir::acc::DeviceType::None);
1647}
1648
1649mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1651 getAsyncOperands(), deviceType);
1652}
1653
1654mlir::Value acc::ParallelOp::getNumWorkersValue() {
1655 return getNumWorkersValue(mlir::acc::DeviceType::None);
1656}
1657
1659acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1660 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1661 deviceType);
1662}
1663
1664mlir::Value acc::ParallelOp::getVectorLengthValue() {
1665 return getVectorLengthValue(mlir::acc::DeviceType::None);
1666}
1667
1669acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1670 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1671 getVectorLength(), deviceType);
1672}
1673
1674mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
1675 return getNumGangsValues(mlir::acc::DeviceType::None);
1676}
1677
1679ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1680 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1681 getNumGangsSegments(), deviceType);
1682}
1683
1684bool acc::ParallelOp::hasWaitOnly() {
1685 return hasWaitOnly(mlir::acc::DeviceType::None);
1686}
1687
1688bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1689 return hasDeviceType(getWaitOnly(), deviceType);
1690}
1691
1692mlir::Operation::operand_range ParallelOp::getWaitValues() {
1693 return getWaitValues(mlir::acc::DeviceType::None);
1694}
1695
1697ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1699 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1700 getHasWaitDevnum(), deviceType);
1701}
1702
1703mlir::Value ParallelOp::getWaitDevnum() {
1704 return getWaitDevnum(mlir::acc::DeviceType::None);
1705}
1706
1707mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1708 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1709 getWaitOperandsSegments(), getHasWaitDevnum(),
1710 deviceType);
1711}
1712
1713void ParallelOp::build(mlir::OpBuilder &odsBuilder,
1714 mlir::OperationState &odsState,
1715 mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
1716 mlir::ValueRange vectorLength,
1717 mlir::ValueRange asyncOperands,
1718 mlir::ValueRange waitOperands, mlir::Value ifCond,
1719 mlir::Value selfCond, mlir::ValueRange reductionOperands,
1720 mlir::ValueRange gangPrivateOperands,
1721 mlir::ValueRange gangFirstPrivateOperands,
1722 mlir::ValueRange dataClauseOperands) {
1723
1724 ParallelOp::build(
1725 odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
1726 /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
1727 /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
1728 /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
1729 /*numGangsDeviceType=*/nullptr, numWorkers,
1730 /*numWorkersDeviceType=*/nullptr, vectorLength,
1731 /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
1732 /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
1733 gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
1734 /*firstprivatizations=*/nullptr, dataClauseOperands,
1735 /*defaultAttr=*/nullptr, /*combined=*/nullptr);
1736}
1737
1738void acc::ParallelOp::addNumWorkersOperand(
1739 MLIRContext *context, mlir::Value newValue,
1740 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1741 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1742 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1743 getNumWorkersMutable()));
1744}
/// Hands `newValue`, `effectiveDeviceTypes`, the current vector_length
/// device-type attribute and the mutable vector_length operand range to
/// `addDeviceTypeAffectedOperandHelper`, then stores the attribute it returns
/// as the new vector_length device-type attribute.
1745void acc::ParallelOp::addVectorLengthOperand(
1746 MLIRContext *context, mlir::Value newValue,
1747 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1748 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1749 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1750 getVectorLengthMutable()));
1751}
1752
/// Replaces the async-only attribute with the result of
/// `addDeviceTypeAffectedOperandHelper` applied to the current attribute and
/// `effectiveDeviceTypes`. No operand is involved.
1753void acc::ParallelOp::addAsyncOnly(
1754 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1755 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1756 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1757}
1758
/// Hands `newValue`, `effectiveDeviceTypes`, the current async-operands
/// device-type attribute and the mutable async operand range to
/// `addDeviceTypeAffectedOperandHelper`, then stores the attribute it returns
/// as the new async-operands device-type attribute.
1759void acc::ParallelOp::addAsyncOperand(
1760 MLIRContext *context, mlir::Value newValue,
1761 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1762 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1763 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1764 getAsyncOperandsMutable()));
1765}
1766
/// Adds a group of num_gangs values for `effectiveDeviceTypes`.
///
/// Starts from a copy of the existing num_gangs segment sizes (if any), lets
/// `addDeviceTypeAffectedOperandHelper` process `newValues` against the
/// mutable num_gangs operand range and `segments`, and then writes both the
/// returned device-type attribute and `segments` back onto the op.
1767void acc::ParallelOp::addNumGangsOperands(
1768 MLIRContext *context, mlir::ValueRange newValues,
1769 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
     // NOTE(review): listing line 1770 is absent from this extract; the
     // declaration of the local `segments` container presumably lives there.
1771 if (getNumGangsSegments())
1772 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1773
1774 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1775 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1776 getNumGangsMutable(), segments));
1777
1778 setNumGangsSegments(segments);
1779}
/// Replaces the wait-only attribute with the result of
/// `addDeviceTypeAffectedOperandHelper` applied to the current attribute and
/// `effectiveDeviceTypes`. No operand is involved.
1780void acc::ParallelOp::addWaitOnly(
1781 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1782 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1783 effectiveDeviceTypes));
1784}
/// Adds a group of wait operands for `effectiveDeviceTypes` and records
/// whether that group carries a devnum.
///
/// Updates three pieces of state on the op: the wait-operands device-type
/// attribute, the wait-operands segment sizes, and the has-wait-devnum array.
1785void acc::ParallelOp::addWaitOperands(
1786 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
1787 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1788
     // NOTE(review): listing line 1789 is absent from this extract; the
     // declaration of the local `segments` container presumably lives there.
1790 if (getWaitOperandsSegments())
1791 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1792
1793 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1794 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1795 getWaitOperandsMutable(), segments));
1796 setWaitOperandsSegments(segments);
1797
     // NOTE(review): listing line 1798 is absent from this extract; the
     // declaration of the local `hasDevnums` container presumably lives there.
1799 if (getHasWaitDevnumAttr())
1800 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
     // Append one `hasDevnum` flag per effective device type, and at least one
     // even when `effectiveDeviceTypes` is empty.
1801 hasDevnums.insert(
1802 hasDevnums.end(),
1803 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
1804 mlir::BoolAttr::get(context, hasDevnum));
1805 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
1806}
1807
/// Appends the result of `op` to the private operands and appends a symbol
/// reference to `recipe` (by its symbol name) to the privatization recipes
/// array attribute, preserving any recipes already present.
1808void acc::ParallelOp::addPrivatization(MLIRContext *context,
1809 mlir::acc::PrivateOp op,
1810 mlir::acc::PrivateRecipeOp recipe) {
1811 getPrivateOperandsMutable().append(op.getResult());
1812
     // NOTE(review): listing line 1813 is absent from this extract; the
     // declaration of the local `recipes` container presumably lives there.
1814
1815 if (getPrivatizationRecipesAttr())
1816 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
1817
1818 recipes.push_back(
1819 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1820 setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1821}
1822
/// Appends the result of `op` to the firstprivate operands and appends a
/// symbol reference to `recipe` (by its symbol name) to the
/// firstprivatization recipes array attribute, preserving any recipes already
/// present.
1823void acc::ParallelOp::addFirstPrivatization(
1824 MLIRContext *context, mlir::acc::FirstprivateOp op,
1825 mlir::acc::FirstprivateRecipeOp recipe) {
1826 getFirstprivateOperandsMutable().append(op.getResult());
1827
     // NOTE(review): listing line 1828 is absent from this extract; the
     // declaration of the local `recipes` container presumably lives there.
1829
1830 if (getFirstprivatizationRecipesAttr())
1831 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
1832
1833 recipes.push_back(
1834 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1835 setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1836}
1837
/// Appends the result of `op` to the reduction operands and appends a symbol
/// reference to `recipe` (by its symbol name) to the reduction recipes array
/// attribute, preserving any recipes already present.
1838void acc::ParallelOp::addReduction(MLIRContext *context,
1839 mlir::acc::ReductionOp op,
1840 mlir::acc::ReductionRecipeOp recipe) {
1841 getReductionOperandsMutable().append(op.getResult());
1842
     // NOTE(review): listing line 1843 is absent from this extract; the
     // declaration of the local `recipes` container presumably lives there.
1844
1845 if (getReductionRecipesAttr())
1846 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
1847
1848 recipes.push_back(
1849 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1850 setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1851}
1852
/// Parses a comma-separated sequence of groups of the form
/// `{ %v : type, ... }` each optionally followed by `[ attr ]`.
///
/// For every group, the operands and types are appended to `operands` and
/// `types`, the number of operands in the group is recorded, and one
/// attribute is recorded: the one parsed between `[` and `]`, or
/// `DeviceType::None` when no `[` follows the group. On success
/// `deviceTypes` receives the recorded attributes and `segments` the
/// per-group operand counts.
///
/// NOTE(review): listing lines 1855, 1858-1859 and 1867 are absent from this
/// extract; they presumably hold the `operands` parameter, the declarations of
/// the locals `attributes` and `seg`, and the opening of the lambda passed to
/// `parseCommaSeparatedList`.
1853static ParseResult parseNumGangs(
1854 mlir::OpAsmParser &parser,
1856 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1857 mlir::DenseI32ArrayAttr &segments) {
1860
1861 do {
1862 if (failed(parser.parseLBrace()))
1863 return failure();
1864
     // Remember how many operands existed before this group so the group's
     // size can be derived afterwards.
1865 int32_t crtOperandsSize = operands.size();
1866 if (failed(parser.parseCommaSeparatedList(
1868 if (parser.parseOperand(operands.emplace_back()) ||
1869 parser.parseColonType(types.emplace_back()))
1870 return failure();
1871 return success();
1872 })))
1873 return failure();
1874 seg.push_back(operands.size() - crtOperandsSize);
1875
1876 if (failed(parser.parseRBrace()))
1877 return failure();
1878
     // Optional `[attr]` suffix; absent means DeviceType::None.
1879 if (succeeded(parser.parseOptionalLSquare())) {
1880 if (parser.parseAttribute(attributes.emplace_back()) ||
1881 parser.parseRSquare())
1882 return failure();
1883 } else {
1884 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1885 parser.getContext(), mlir::acc::DeviceType::None));
1886 }
1887 } while (succeeded(parser.parseOptionalComma()));
1888
1889 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1890 attributes.end());
1891 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1892 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1893
1894 return success();
1895}
1896
// Prints ` [<attr>]` to `p` when `attr` is a device-type attribute whose value
// is not `DeviceType::None`; prints nothing for `DeviceType::None`.
// NOTE(review): the declaration line (listing line 1897) is absent from this
// extract; other functions in this file call this as
// `printSingleDeviceType(p, attr)`.
// NOTE(review): the `dyn_cast` result is used without a null check, so `attr`
// is presumably always a DeviceTypeAttr here — confirm, or use `cast`.
1898 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1899 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
1900 p << " [" << attr << "]";
1901}
1902
// Prints comma-separated groups `{%v : type, ...}`, one per entry of
// `deviceTypes`, each followed by the output of `printSingleDeviceType` for
// that entry. Group i consumes `(*segments)[i]` consecutive operands.
// NOTE(review): the first declaration line (listing line 1903) is absent from
// this extract, so the function name and leading parameters are not visible.
// `deviceTypes` and `segments` are dereferenced without a presence check.
1904 mlir::OperandRange operands, mlir::TypeRange types,
1905 std::optional<mlir::ArrayAttr> deviceTypes,
1906 std::optional<mlir::DenseI32ArrayAttr> segments) {
     // Running index into `operands`, shared across all groups.
1907 unsigned opIdx = 0;
1908 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1909 p << "{";
1910 llvm::interleaveComma(
1911 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1912 p << operands[opIdx] << " : " << operands[opIdx].getType();
1913 ++opIdx;
1914 });
1915 p << "}";
1916 printSingleDeviceType(p, it.value());
1917 });
1918}
1919
// Parses a comma-separated sequence of groups `{ %v : type, ... }`, each
// optionally followed by `[ attr ]`. Operands and types are accumulated into
// `operands` / `types`; on success `deviceTypes` receives one attribute per
// group (`DeviceType::None` when no `[...]` was given) and `segments` the
// number of operands in each group.
// NOTE(review): listing lines 1920, 1922, 1925-1926 and 1935 are absent from
// this extract; they presumably hold the function name, the `operands`
// parameter, the declarations of the locals `attributes` and `seg`, and the
// opening of the lambda passed to `parseCommaSeparatedList`.
1921 mlir::OpAsmParser &parser,
1923 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1924 mlir::DenseI32ArrayAttr &segments) {
1927
1928 do {
1929 if (failed(parser.parseLBrace()))
1930 return failure();
1931
     // Operand count before this group; used to compute the group's size.
1932 int32_t crtOperandsSize = operands.size();
1933
1934 if (failed(parser.parseCommaSeparatedList(
1936 if (parser.parseOperand(operands.emplace_back()) ||
1937 parser.parseColonType(types.emplace_back()))
1938 return failure();
1939 return success();
1940 })))
1941 return failure();
1942
1943 seg.push_back(operands.size() - crtOperandsSize);
1944
1945 if (failed(parser.parseRBrace()))
1946 return failure();
1947
     // Optional `[attr]` suffix; absent means DeviceType::None.
1948 if (succeeded(parser.parseOptionalLSquare())) {
1949 if (parser.parseAttribute(attributes.emplace_back()) ||
1950 parser.parseRSquare())
1951 return failure();
1952 } else {
1953 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1954 parser.getContext(), mlir::acc::DeviceType::None));
1955 }
1956 } while (succeeded(parser.parseOptionalComma()));
1957
1958 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1959 attributes.end());
1960 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1961 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1962
1963 return success();
1964}
1965
// Prints comma-separated groups `{%v : type, ...}`, one per entry of
// `deviceTypes`, each followed by the output of `printSingleDeviceType` for
// that entry. Group i consumes `(*segments)[i]` consecutive operands.
// NOTE(review): listing lines 1966-1967 are absent from this extract, so the
// function name and leading parameters are not visible. `deviceTypes` and
// `segments` are dereferenced without a presence check.
1968 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1969 std::optional<mlir::DenseI32ArrayAttr> segments) {
     // Running index into `operands`, shared across all groups.
1970 unsigned opIdx = 0;
1971 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1972 p << "{";
1973 llvm::interleaveComma(
1974 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1975 p << operands[opIdx] << " : " << operands[opIdx].getType();
1976 ++opIdx;
1977 });
1978 p << "}";
1979 printSingleDeviceType(p, it.value());
1980 });
1981}
1982
/// Parses the argument part of a wait clause.
///
/// Accepted shapes, as implemented below:
///   * no `(` at all: keyword-only form; `keywordOnly` is set to a single
///     `DeviceType::None` attribute and nothing else is populated;
///   * `( [attr, ...] , {group} [attr]?, ... )`: a bracketed list of
///     keyword-only attributes, a mandatory comma, then operand groups;
///   * `( {group} [attr]?, ... )`: operand groups only.
/// Each group is `{ (devnum :)? %v : type, ... }`. Per group the function
/// records whether `devnum` was present, how many operands it holds, and its
/// device-type attribute (`DeviceType::None` if no `[attr]` follows).
///
/// NOTE(review): listing lines 1985, 1990 and 2033 are absent from this
/// extract; they presumably hold the `operands` parameter, the declaration of
/// the local `seg`, and the opening of the lambda passed to
/// `parseCommaSeparatedList`.
1983static ParseResult parseWaitClause(
1984 mlir::OpAsmParser &parser,
1986 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1987 mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
1988 mlir::ArrayAttr &keywordOnly) {
1989 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
1991
1992 bool needCommaBeforeOperands = false;
1993
1994 // Keyword only
1995 if (failed(parser.parseOptionalLParen())) {
1996 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1997 parser.getContext(), mlir::acc::DeviceType::None));
1998 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1999 return success();
2000 }
2001
2002 // Parse keyword only attributes
2003 if (succeeded(parser.parseOptionalLSquare())) {
2004 if (failed(parser.parseCommaSeparatedList([&]() {
2005 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2006 return failure();
2007 return success();
2008 })))
2009 return failure();
2010 if (parser.parseRSquare())
2011 return failure();
2012 needCommaBeforeOperands = true;
2013 }
2014
     // NOTE(review): once a `[...]` list was parsed, a comma and at least one
     // operand group are required. A form with only `[...]` inside the
     // parentheses is rejected here — confirm this matches what the printer
     // can emit.
2015 if (needCommaBeforeOperands && failed(parser.parseComma()))
2016 return failure();
2017
2018 do {
2019 if (failed(parser.parseLBrace()))
2020 return failure();
2021
     // Operand count before this group; used to compute the group's size.
2022 int32_t crtOperandsSize = operands.size();
2023
     // Record one devnum flag per group.
2024 if (succeeded(parser.parseOptionalKeyword("devnum"))) {
2025 if (failed(parser.parseColon()))
2026 return failure();
2027 devnum.push_back(BoolAttr::get(parser.getContext(), true));
2028 } else {
2029 devnum.push_back(BoolAttr::get(parser.getContext(), false));
2030 }
2031
2032 if (failed(parser.parseCommaSeparatedList(
2034 if (parser.parseOperand(operands.emplace_back()) ||
2035 parser.parseColonType(types.emplace_back()))
2036 return failure();
2037 return success();
2038 })))
2039 return failure();
2040
2041 seg.push_back(operands.size() - crtOperandsSize);
2042
2043 if (failed(parser.parseRBrace()))
2044 return failure();
2045
     // Optional `[attr]` suffix; absent means DeviceType::None.
2046 if (succeeded(parser.parseOptionalLSquare())) {
2047 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2048 parser.parseRSquare())
2049 return failure();
2050 } else {
2051 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2052 parser.getContext(), mlir::acc::DeviceType::None));
2053 }
2054 } while (succeeded(parser.parseOptionalComma()));
2055
2056 if (failed(parser.parseRParen()))
2057 return failure();
2058
2059 deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2060 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2061 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2062 hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
2063
2064 return success();
2065}
2066
/// Returns true only when `attrs` passes `hasDeviceTypeValues`, contains
/// exactly one element, and that element is a `DeviceTypeAttr` whose value is
/// `DeviceType::None`. Every other case yields false.
2067static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
2068 if (!hasDeviceTypeValues(attrs))
2069 return false;
2070 if (attrs->size() != 1)
2071 return false;
2072 if (auto deviceTypeAttr =
2073 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2074 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2075 return false;
2076}
2077
// Prints the argument part of a wait clause.
//
// Nothing is printed when there are no operands and `keywordOnly` consists of
// exactly one `DeviceType::None` entry. Otherwise the output is wrapped in
// parentheses: first whatever `printDeviceTypes` emits for `keywordOnly`, a
// `, ` separator if both `keywordOnly` and `deviceTypes` have values, then one
// `{...}` group per entry of `deviceTypes`. A group is prefixed with
// `devnum: ` when the matching `hasDevNum` entry is a true BoolAttr.
// NOTE(review): the first declaration line (listing line 2078) is absent from
// this extract, so the function name and leading parameters are not visible.
2079 mlir::OperandRange operands, mlir::TypeRange types,
2080 std::optional<mlir::ArrayAttr> deviceTypes,
2081 std::optional<mlir::DenseI32ArrayAttr> segments,
2082 std::optional<mlir::ArrayAttr> hasDevNum,
2083 std::optional<mlir::ArrayAttr> keywordOnly) {
2084
2085 if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
2086 return;
2087
2088 p << "(";
2089
2090 printDeviceTypes(p, keywordOnly);
2091 if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
2092 p << ", ";
2093
2094 if (hasDeviceTypeValues(deviceTypes)) {
     // Running index into `operands`, shared across all groups.
2095 unsigned opIdx = 0;
2096 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2097 p << "{";
     // `hasDevNum` and `segments` are dereferenced without a presence check;
     // presumably they are always set whenever `deviceTypes` has values.
2098 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2099 if (boolAttr && boolAttr.getValue())
2100 p << "devnum: ";
2101 llvm::interleaveComma(
2102 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2103 p << operands[opIdx] << " : " << operands[opIdx].getType();
2104 ++opIdx;
2105 });
2106 p << "}";
2107 printSingleDeviceType(p, it.value());
2108 });
2109 }
2110
2111 p << ")";
2112}
2113
/// Parses a comma-separated list of `%v : type` entries, each optionally
/// followed by `[ attr ]`. Operands and types are appended to `operands` and
/// `types`; `deviceTypes` receives one attribute per operand, defaulting to
/// `DeviceType::None` when no `[attr]` follows.
///
/// NOTE(review): listing lines 2116 and 2118 are absent from this extract;
/// they presumably hold the `operands` parameter and the declaration of the
/// local `attributes`.
2114static ParseResult parseDeviceTypeOperands(
2115 mlir::OpAsmParser &parser,
2117 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
2119 if (failed(parser.parseCommaSeparatedList([&]() {
2120 if (parser.parseOperand(operands.emplace_back()) ||
2121 parser.parseColonType(types.emplace_back()))
2122 return failure();
2123 if (succeeded(parser.parseOptionalLSquare())) {
2124 if (parser.parseAttribute(attributes.emplace_back()) ||
2125 parser.parseRSquare())
2126 return failure();
2127 } else {
2128 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2129 parser.getContext(), mlir::acc::DeviceType::None));
2130 }
2131 return success();
2132 })))
2133 return failure();
2134 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2135 attributes.end());
2136 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2137 return success();
2138}
2139
/// Prints a comma-separated list of `%v : type` entries, pairing each operand
/// with the entry of `deviceTypes` at the same position and appending the
/// output of `printSingleDeviceType` for it. Prints nothing when
/// `deviceTypes` has no values. The operand's own type is printed; the
/// `types` parameter is not consulted.
///
/// NOTE(review): listing line 2141 (the line carrying the function name) is
/// absent from this extract; a later function in this file calls
/// `printDeviceTypeOperands(p, op, operands, types, deviceTypes)`.
2140static void
2142 mlir::OperandRange operands, mlir::TypeRange types,
2143 std::optional<mlir::ArrayAttr> deviceTypes) {
2144 if (!hasDeviceTypeValues(deviceTypes))
2145 return;
2146 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
2147 p << std::get<1>(it) << " : " << std::get<1>(it).getType();
2148 printSingleDeviceType(p, std::get<0>(it));
2149 });
2150}
2151
// Parses either a keyword-only form or a parenthesised operand list.
//
//   * no `(`: `keywordOnlyDeviceType` is set to a single `DeviceType::None`
//     attribute and the function succeeds without touching `operands`;
//   * `( [attr, ...] , %v : type [attr]?, ... )` or
//     `( %v : type [attr]?, ... )`: an optional bracketed list of keyword-only
//     attributes (followed by a mandatory comma), then operands, each
//     optionally followed by `[attr]` (default `DeviceType::None`).
// NOTE(review): listing lines 2152, 2154 and 2188 are absent from this
// extract; they presumably hold the function name, the `operands` parameter
// and the declaration of the local `attributes`.
// NOTE(review): in the parenthesised form the attributes collected in
// `keywordOnlyDeviceTypeAttributes` are never stored into
// `keywordOnlyDeviceType` within the visible lines — confirm this is intended.
2153 mlir::OpAsmParser &parser,
2155 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2156 mlir::ArrayAttr &keywordOnlyDeviceType) {
2157
2158 llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
2159 bool needCommaBeforeOperands = false;
2160
2161 if (failed(parser.parseOptionalLParen())) {
2162 // Keyword only
2163 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2164 parser.getContext(), mlir::acc::DeviceType::None));
2165 keywordOnlyDeviceType =
2166 ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
2167 return success();
2168 }
2169
2170 // Parse keyword only attributes
2171 if (succeeded(parser.parseOptionalLSquare())) {
2172 // Parse keyword only attributes
2173 if (failed(parser.parseCommaSeparatedList([&]() {
2174 if (parser.parseAttribute(
2175 keywordOnlyDeviceTypeAttributes.emplace_back()))
2176 return failure();
2177 return success();
2178 })))
2179 return failure();
2180 if (parser.parseRSquare())
2181 return failure();
2182 needCommaBeforeOperands = true;
2183 }
2184
2185 if (needCommaBeforeOperands && failed(parser.parseComma()))
2186 return failure();
2187
2189 if (failed(parser.parseCommaSeparatedList([&]() {
2190 if (parser.parseOperand(operands.emplace_back()) ||
2191 parser.parseColonType(types.emplace_back()))
2192 return failure();
2193 if (succeeded(parser.parseOptionalLSquare())) {
2194 if (parser.parseAttribute(attributes.emplace_back()) ||
2195 parser.parseRSquare())
2196 return failure();
2197 } else {
2198 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2199 parser.getContext(), mlir::acc::DeviceType::None));
2200 }
2201 return success();
2202 })))
2203 return failure();
2204
2205 if (failed(parser.parseRParen()))
2206 return failure();
2207
2208 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2209 attributes.end());
2210 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2211 return success();
2212}
2213
// Prints nothing when there are no operands and `keywordOnlyDeviceTypes` is
// exactly one `DeviceType::None` entry. Otherwise prints, inside parentheses,
// the output of `printDeviceTypes` for `keywordOnlyDeviceTypes`, a `, `
// separator if both attribute arrays have values, and the output of
// `printDeviceTypeOperands`.
// NOTE(review): listing lines 2214-2215 are absent from this extract, so the
// function name and leading parameters are not visible.
2216 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2217 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2218
2219 if (operands.begin() == operands.end() &&
2220 hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
2221 return;
2222 }
2223
2224 p << "(";
2225 printDeviceTypes(p, keywordOnlyDeviceTypes);
2226 if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
2227 hasDeviceTypeValues(deviceTypes))
2228 p << ", ";
2229 printDeviceTypeOperands(p, op, operands, types, deviceTypes);
2230 p << ")";
2231}
2232
// Parses either nothing (keyword-only form, which sets `attr` to a UnitAttr)
// or `( %v : type )`, which stores the operand in `operand` and its type in
// `operandType` and leaves `attr` untouched.
// NOTE(review): listing lines 2233 and 2243 are absent from this extract; they
// presumably hold the function name and the declaration of the local `op`.
2234 mlir::OpAsmParser &parser,
2235 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2236 mlir::Type &operandType, mlir::UnitAttr &attr) {
2237 // Keyword only
2238 if (failed(parser.parseOptionalLParen())) {
2239 attr = mlir::UnitAttr::get(parser.getContext());
2240 return success();
2241 }
2242
2244 if (failed(parser.parseOperand(op)))
2245 return failure();
2246 operand = op;
2247 if (failed(parser.parseColon()))
2248 return failure();
2249 if (failed(parser.parseType(operandType)))
2250 return failure();
2251 if (failed(parser.parseRParen()))
2252 return failure();
2253
2254 return success();
2255}
2256
// Prints `(%v : type)` for the given operand, or nothing at all when `attr`
// is set (keyword-only form).
// NOTE(review): the first declaration line (listing line 2257) is absent from
// this extract. `operand` is dereferenced without a presence check, so it is
// presumably always set whenever `attr` is null.
2258 mlir::Operation *op,
2259 std::optional<mlir::Value> operand,
2260 mlir::Type operandType,
2261 mlir::UnitAttr attr) {
2262 if (attr)
2263 return;
2264
2265 p << "(";
2266 p.printOperand(*operand);
2267 p << " : ";
2268 p.printType(operandType);
2269 p << ")";
2270}
2271
// Parses either nothing (keyword-only form, which sets `attr` to a UnitAttr)
// or `( %a, %b, ... : typeA, typeB, ... )`: a comma-separated operand list, a
// colon, and a comma-separated type list. The two lists are parsed
// independently; their lengths are not compared here.
// NOTE(review): listing lines 2272 and 2274 are absent from this extract; they
// presumably hold the function name and the `operands` parameter.
2273 mlir::OpAsmParser &parser,
2275 llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
2276 // Keyword only
2277 if (failed(parser.parseOptionalLParen())) {
2278 attr = mlir::UnitAttr::get(parser.getContext());
2279 return success();
2280 }
2281
2282 if (failed(parser.parseCommaSeparatedList([&]() {
2283 if (parser.parseOperand(operands.emplace_back()))
2284 return failure();
2285 return success();
2286 })))
2287 return failure();
2288 if (failed(parser.parseColon()))
2289 return failure();
2290 if (failed(parser.parseCommaSeparatedList([&]() {
2291 if (parser.parseType(types.emplace_back()))
2292 return failure();
2293 return success();
2294 })))
2295 return failure();
2296 if (failed(parser.parseRParen()))
2297 return failure();
2298
2299 return success();
2300}
2301
// Prints `(%a, %b, ... : typeA, typeB, ...)`, or nothing at all when `attr`
// is set (keyword-only form).
// NOTE(review): the first declaration line (listing line 2302) is absent from
// this extract, so the function name and leading parameters are not visible.
2303 mlir::Operation *op,
2304 mlir::OperandRange operands,
2305 mlir::TypeRange types,
2306 mlir::UnitAttr attr) {
2307 if (attr)
2308 return;
2309
2310 p << "(";
2311 llvm::interleaveComma(operands, p, [&](auto it) { p << it; });
2312 p << " : ";
2313 llvm::interleaveComma(types, p, [&](auto it) { p << it; });
2314 p << ")";
2315}
2316
/// Parses one of the keywords `kernels`, `parallel` or `serial` and sets
/// `attr` to the matching `CombinedConstructsType` (`KernelsLoop`,
/// `ParallelLoop`, `SerialLoop`). Emits "expected compute construct name" and
/// fails if none of them is present.
///
/// NOTE(review): listing line 2318 (function name and first parameter) is
/// absent from this extract.
2317static ParseResult
2319 mlir::acc::CombinedConstructsTypeAttr &attr) {
2320 if (succeeded(parser.parseOptionalKeyword("kernels"))) {
2321 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2322 parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2323 } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
2324 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2325 parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2326 } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
2327 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2328 parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2329 } else {
2330 parser.emitError(parser.getCurrentLocation(),
2331 "expected compute construct name");
2332 return failure();
2333 }
2334 return success();
2335}
2336
/// Prints `kernels`, `parallel` or `serial` according to the value of `attr`;
/// prints nothing when `attr` is null.
///
/// NOTE(review): listing line 2338 (function name and leading parameters) is
/// absent from this extract.
2337static void
2339 mlir::acc::CombinedConstructsTypeAttr attr) {
2340 if (attr) {
2341 switch (attr.getValue()) {
2342 case mlir::acc::CombinedConstructsType::KernelsLoop:
2343 p << "kernels";
2344 break;
2345 case mlir::acc::CombinedConstructsType::ParallelLoop:
2346 p << "parallel";
2347 break;
2348 case mlir::acc::CombinedConstructsType::SerialLoop:
2349 p << "serial";
2350 break;
2351 };
2352 }
2353}
2354
2355//===----------------------------------------------------------------------===//
2356// SerialOp
2357//===----------------------------------------------------------------------===//
2358
/// Returns the combined count of reduction, private, firstprivate and data
/// clause operands.
2359unsigned SerialOp::getNumDataOperands() {
2360 return getReductionOperands().size() + getPrivateOperands().size() +
2361 getFirstprivateOperands().size() + getDataClauseOperands().size();
2362}
2363
/// Returns the operand at position `i` past the async operands, the wait
/// operands and the if/self conditions (each condition counted only when
/// present).
/// NOTE(review): this relies on those operands preceding the data operands in
/// the op's operand list; that ordering is defined outside this function.
2364Value SerialOp::getDataOperand(unsigned i) {
2365 unsigned numOptional = getAsyncOperands().size();
2366 numOptional += getIfCond() ? 1 : 0;
2367 numOptional += getSelfCond() ? 1 : 0;
2368 return getOperand(getWaitOperands().size() + numOptional + i);
2369}
2370
/// Convenience overload: queries the async-only attribute for
/// `DeviceType::None`.
2371bool acc::SerialOp::hasAsyncOnly() {
2372 return hasAsyncOnly(mlir::acc::DeviceType::None);
2373}
2374
/// Returns the result of `hasDeviceType` on the async-only attribute for
/// `deviceType`.
2375bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2376 return hasDeviceType(getAsyncOnly(), deviceType);
2377}
2378
/// Convenience overload: returns the async value for `DeviceType::None`.
2379mlir::Value acc::SerialOp::getAsyncValue() {
2380 return getAsyncValue(mlir::acc::DeviceType::None);
2381}
2382
/// Looks up the async value for `deviceType` among the async operands.
/// NOTE(review): listing line 2384, which opens the return expression, is
/// absent from this extract, so the helper being called is not visible here.
2383mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2385 getAsyncOperands(), deviceType);
2386}
2387
/// Convenience overload: queries the wait-only attribute for
/// `DeviceType::None`.
2388bool acc::SerialOp::hasWaitOnly() {
2389 return hasWaitOnly(mlir::acc::DeviceType::None);
2390}
2391
/// Returns the result of `hasDeviceType` on the wait-only attribute for
/// `deviceType`.
2392bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2393 return hasDeviceType(getWaitOnly(), deviceType);
2394}
2395
/// Convenience overload: returns the wait values for `DeviceType::None`.
2396mlir::Operation::operand_range SerialOp::getWaitValues() {
2397 return getWaitValues(mlir::acc::DeviceType::None);
2398}
2399
// Returns the wait values for `deviceType`, computed from the wait operands
// and their device-type, segment and has-devnum attributes.
// NOTE(review): listing lines 2400 and 2402 (the return type and the start of
// the return expression) are absent from this extract, so the helper being
// called is not visible here.
2401SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2403 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2404 getHasWaitDevnum(), deviceType);
2405}
2406
/// Convenience overload: returns the wait devnum for `DeviceType::None`.
2407mlir::Value SerialOp::getWaitDevnum() {
2408 return getWaitDevnum(mlir::acc::DeviceType::None);
2409}
2410
/// Returns whatever `getWaitDevnumValue` yields for `deviceType`, given this
/// op's wait operands together with their device-type, segment and
/// has-devnum attributes.
2411mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2412 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2413 getWaitOperandsSegments(), getHasWaitDevnum(),
2414 deviceType);
2415}
2416
/// Verifier for `acc.serial`. Runs a sequence of checks and fails on the
/// first one that fails: private, firstprivate and reduction recipes against
/// their operands; the wait operands against their segments and device types;
/// the async operands against their device types; one further check; and
/// finally `checkDataOperands` on the data clause operands.
///
/// NOTE(review): listing lines 2418, 2422, 2426, 2431, 2436 and 2441 — the
/// lines that open each `if (failed(...` call — are absent from this extract,
/// so the names of the helper functions are not visible here.
2417LogicalResult acc::SerialOp::verify() {
     // private clause: recipes vs. operands.
2419 *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
2420 "privatizations", /*checkOperandType=*/false)))
2421 return failure();
     // firstprivate clause: recipes vs. operands.
2423 *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
2424 "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
2425 return failure();
     // reduction clause: recipes vs. operands.
2427 *this, getReductionRecipes(), getReductionOperands(), "reduction",
2428 "reductions", false)))
2429 return failure();
2430
     // wait clause: operands vs. segments and device types.
2432 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2433 getWaitOperandsDeviceTypeAttr(), "wait")))
2434 return failure();
2435
     // async clause: operands vs. device types.
2437 getAsyncOperandsDeviceTypeAttr(),
2438 "async")))
2439 return failure();
2440
2442 return failure();
2443
2444 return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
2445}
2446
/// Replaces the async-only attribute with the result of
/// `addDeviceTypeAffectedOperandHelper` applied to the current attribute and
/// `effectiveDeviceTypes`. No operand is involved.
2447void acc::SerialOp::addAsyncOnly(
2448 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2449 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2450 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2451}
2452
/// Hands `newValue`, `effectiveDeviceTypes`, the current async-operands
/// device-type attribute and the mutable async operand range to
/// `addDeviceTypeAffectedOperandHelper`, then stores the attribute it returns
/// as the new async-operands device-type attribute.
2453void acc::SerialOp::addAsyncOperand(
2454 MLIRContext *context, mlir::Value newValue,
2455 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2456 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2457 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2458 getAsyncOperandsMutable()));
2459}
2460
/// Replaces the wait-only attribute with the result of
/// `addDeviceTypeAffectedOperandHelper` applied to the current attribute and
/// `effectiveDeviceTypes`. No operand is involved.
2461void acc::SerialOp::addWaitOnly(
2462 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2463 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2464 effectiveDeviceTypes));
2465}
/// Adds a group of wait operands for `effectiveDeviceTypes` and records
/// whether that group carries a devnum.
///
/// Updates three pieces of state on the op: the wait-operands device-type
/// attribute, the wait-operands segment sizes, and the has-wait-devnum array.
2466void acc::SerialOp::addWaitOperands(
2467 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2468 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2469
     // NOTE(review): listing line 2470 is absent from this extract; the
     // declaration of the local `segments` container presumably lives there.
2471 if (getWaitOperandsSegments())
2472 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2473
2474 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2475 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2476 getWaitOperandsMutable(), segments));
2477 setWaitOperandsSegments(segments);
2478
     // NOTE(review): listing line 2479 is absent from this extract; the
     // declaration of the local `hasDevnums` container presumably lives there.
2480 if (getHasWaitDevnumAttr())
2481 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
     // Append one `hasDevnum` flag per effective device type, and at least one
     // even when `effectiveDeviceTypes` is empty.
2482 hasDevnums.insert(
2483 hasDevnums.end(),
2484 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2485 mlir::BoolAttr::get(context, hasDevnum));
2486 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2487}
2488
/// Appends the result of `op` to the private operands and appends a symbol
/// reference to `recipe` (by its symbol name) to the privatization recipes
/// array attribute, preserving any recipes already present.
2489void acc::SerialOp::addPrivatization(MLIRContext *context,
2490 mlir::acc::PrivateOp op,
2491 mlir::acc::PrivateRecipeOp recipe) {
2492 getPrivateOperandsMutable().append(op.getResult());
2493
     // NOTE(review): listing line 2494 is absent from this extract; the
     // declaration of the local `recipes` container presumably lives there.
2495
2496 if (getPrivatizationRecipesAttr())
2497 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
2498
2499 recipes.push_back(
2500 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2501 setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2502}
2503
/// Appends the result of `op` to the firstprivate operands and appends a
/// symbol reference to `recipe` (by its symbol name) to the
/// firstprivatization recipes array attribute, preserving any recipes already
/// present.
2504void acc::SerialOp::addFirstPrivatization(
2505 MLIRContext *context, mlir::acc::FirstprivateOp op,
2506 mlir::acc::FirstprivateRecipeOp recipe) {
2507 getFirstprivateOperandsMutable().append(op.getResult());
2508
     // NOTE(review): listing line 2509 is absent from this extract; the
     // declaration of the local `recipes` container presumably lives there.
2510
2511 if (getFirstprivatizationRecipesAttr())
2512 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
2513
2514 recipes.push_back(
2515 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2516 setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2517}
2518
/// Appends the result of `op` to the reduction operands and appends a symbol
/// reference to `recipe` (by its symbol name) to the reduction recipes array
/// attribute, preserving any recipes already present.
2519void acc::SerialOp::addReduction(MLIRContext *context,
2520 mlir::acc::ReductionOp op,
2521 mlir::acc::ReductionRecipeOp recipe) {
2522 getReductionOperandsMutable().append(op.getResult());
2523
     // NOTE(review): listing line 2524 is absent from this extract; the
     // declaration of the local `recipes` container presumably lives there.
2525
2526 if (getReductionRecipesAttr())
2527 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
2528
2529 recipes.push_back(
2530 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2531 setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2532}
2533
2534//===----------------------------------------------------------------------===//
2535// KernelsOp
2536//===----------------------------------------------------------------------===//
2537
/// Returns the number of data clause operands.
2538unsigned KernelsOp::getNumDataOperands() {
2539 return getDataClauseOperands().size();
2540}
2541
/// Returns the operand at position `i` past the async, wait, num_gangs,
/// num_workers and vector_length operands and the if/self conditions (each
/// condition counted only when present).
/// NOTE(review): this relies on those operands preceding the data operands in
/// the op's operand list; that ordering is defined outside this function.
2542Value KernelsOp::getDataOperand(unsigned i) {
2543 unsigned numOptional = getAsyncOperands().size();
2544 numOptional += getWaitOperands().size();
2545 numOptional += getNumGangs().size();
2546 numOptional += getNumWorkers().size();
2547 numOptional += getVectorLength().size();
2548 numOptional += getIfCond() ? 1 : 0;
2549 numOptional += getSelfCond() ? 1 : 0;
2550 return getOperand(numOptional + i);
2551}
2552
/// Convenience overload: queries the async-only attribute for
/// `DeviceType::None`.
2553bool acc::KernelsOp::hasAsyncOnly() {
2554 return hasAsyncOnly(mlir::acc::DeviceType::None);
2555}
2556
/// Returns the result of `hasDeviceType` on the async-only attribute for
/// `deviceType`.
2557bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2558 return hasDeviceType(getAsyncOnly(), deviceType);
2559}
2560
/// Convenience overload: returns the async value for `DeviceType::None`.
2561mlir::Value acc::KernelsOp::getAsyncValue() {
2562 return getAsyncValue(mlir::acc::DeviceType::None);
2563}
2564
/// Looks up the async value for `deviceType` among the async operands.
/// NOTE(review): listing line 2566, which opens the return expression, is
/// absent from this extract, so the helper being called is not visible here.
2565mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2567 getAsyncOperands(), deviceType);
2568}
2569
/// Convenience overload: returns the num_workers value for
/// `DeviceType::None`.
2570mlir::Value acc::KernelsOp::getNumWorkersValue() {
2571 return getNumWorkersValue(mlir::acc::DeviceType::None);
2572}
2573
// Returns the result of `getValueInDeviceTypeSegment` for `deviceType` over
// the num_workers operands and their device-type attribute.
// NOTE(review): listing line 2574 (presumably the return type) is absent from
// this extract.
2575acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2576 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
2577 deviceType);
2578}
2579
/// Convenience overload: returns the vector_length value for
/// `DeviceType::None`.
2580mlir::Value acc::KernelsOp::getVectorLengthValue() {
2581 return getVectorLengthValue(mlir::acc::DeviceType::None);
2582}
2583
// Returns the result of `getValueInDeviceTypeSegment` for `deviceType` over
// the vector_length operands and their device-type attribute.
// NOTE(review): listing line 2584 (presumably the return type) is absent from
// this extract.
2585acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2586 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
2587 getVectorLength(), deviceType);
2588}
2589
/// Convenience overload: returns the num_gangs values for
/// `DeviceType::None`.
2590mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
2591 return getNumGangsValues(mlir::acc::DeviceType::None);
2592}
2593
// Returns the result of `getValuesFromSegments` for `deviceType` over the
// num_gangs operands, their device-type attribute and their segment sizes.
// NOTE(review): listing line 2594 (presumably the return type) is absent from
// this extract.
2595KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2596 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2597 getNumGangsSegments(), deviceType);
2598}
2599
/// Convenience overload: queries the wait-only attribute for
/// `DeviceType::None`.
2600bool acc::KernelsOp::hasWaitOnly() {
2601 return hasWaitOnly(mlir::acc::DeviceType::None);
2602}
2603
/// Returns the result of `hasDeviceType` on the wait-only attribute for
/// `deviceType`.
2604bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2605 return hasDeviceType(getWaitOnly(), deviceType);
2606}
2607
/// Convenience overload: returns the wait values for `DeviceType::None`.
2608mlir::Operation::operand_range KernelsOp::getWaitValues() {
2609 return getWaitValues(mlir::acc::DeviceType::None);
2610}
2611
// Returns the wait values for `deviceType`, computed from the wait operands
// and their device-type, segment and has-devnum attributes.
// NOTE(review): listing lines 2612 and 2614 (the return type and the start of
// the return expression) are absent from this extract, so the helper being
// called is not visible here.
2613KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2615 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2616 getHasWaitDevnum(), deviceType);
2617}
2618
/// Convenience overload: returns the wait devnum for `DeviceType::None`.
2619mlir::Value KernelsOp::getWaitDevnum() {
2620 return getWaitDevnum(mlir::acc::DeviceType::None);
2621}
2622
/// Returns whatever `getWaitDevnumValue` yields for `deviceType`, given this
/// op's wait operands together with their device-type, segment and
/// has-devnum attributes.
2623mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2624 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2625 getWaitOperandsSegments(), getHasWaitDevnum(),
2626 deviceType);
2627}
2628
/// Verifier for `acc.kernels`. Runs a sequence of checks and fails on the
/// first one that fails: num_gangs operands against their segments and device
/// types (with an extra argument of 3); wait operands against their segments
/// and device types; num_workers, vector_length and async operands against
/// their device-type attributes; one further check; and finally
/// `checkDataOperands` on the data clause operands.
///
/// NOTE(review): listing lines 2630, 2635, 2650 and 2655 — lines that open
/// `if (failed(...` calls — are absent from this extract, so the names of
/// those helper functions are not visible here.
2629LogicalResult acc::KernelsOp::verify() {
     // num_gangs clause. NOTE(review): the trailing `3` looks like a maximum
     // group size; confirm against the helper's signature.
2631 *this, getNumGangs(), getNumGangsSegmentsAttr(),
2632 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
2633 return failure();
2634
     // wait clause: operands vs. segments and device types.
2636 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2637 getWaitOperandsDeviceTypeAttr(), "wait")))
2638 return failure();
2639
2640 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
2641 getNumWorkersDeviceTypeAttr(),
2642 "num_workers")))
2643 return failure();
2644
2645 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
2646 getVectorLengthDeviceTypeAttr(),
2647 "vector_length")))
2648 return failure();
2649
     // async clause: operands vs. device types.
2651 getAsyncOperandsDeviceTypeAttr(),
2652 "async")))
2653 return failure();
2654
2656 return failure();
2657
2658 return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
2659}
2660
/// Hands `newValue`, `effectiveDeviceTypes`, the current num_workers
/// device-type attribute and the mutable num_workers operand range to
/// `addDeviceTypeAffectedOperandHelper`, then stores the attribute it returns
/// as the new num_workers device-type attribute.
2661void acc::KernelsOp::addNumWorkersOperand(
2662 MLIRContext *context, mlir::Value newValue,
2663 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2664 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2665 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2666 getNumWorkersMutable()));
2667}
2668
/// Hands `newValue`, `effectiveDeviceTypes`, the current vector_length
/// device-type attribute and the mutable vector_length operand range to
/// `addDeviceTypeAffectedOperandHelper`, then stores the attribute it returns
/// as the new vector_length device-type attribute.
2669void acc::KernelsOp::addVectorLengthOperand(
2670 MLIRContext *context, mlir::Value newValue,
2671 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2672 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2673 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2674 getVectorLengthMutable()));
2675}
/// Replaces the async-only attribute with the result of
/// `addDeviceTypeAffectedOperandHelper` applied to the current attribute and
/// `effectiveDeviceTypes`. No operand is involved.
2676void acc::KernelsOp::addAsyncOnly(
2677 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2678 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2679 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2680}
2681
/// Hands `newValue`, `effectiveDeviceTypes`, the current async-operands
/// device-type attribute and the mutable async operand range to
/// `addDeviceTypeAffectedOperandHelper`, then stores the attribute it returns
/// as the new async-operands device-type attribute.
2682void acc::KernelsOp::addAsyncOperand(
2683 MLIRContext *context, mlir::Value newValue,
2684 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2685 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2686 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2687 getAsyncOperandsMutable()));
2688}
2689
/// Adds a group of num_gangs values for `effectiveDeviceTypes`.
///
/// Starts from a copy of the existing num_gangs segment sizes (if the segments
/// attribute is set), lets `addDeviceTypeAffectedOperandHelper` process
/// `newValues` against the mutable num_gangs operand range and `segments`,
/// and then writes both the returned device-type attribute and `segments`
/// back onto the op.
2690void acc::KernelsOp::addNumGangsOperands(
2691 MLIRContext *context, mlir::ValueRange newValues,
2692 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
     // NOTE(review): listing line 2693 is absent from this extract; the
     // declaration of the local `segments` container presumably lives there.
2694 if (getNumGangsSegmentsAttr())
2695 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2696
2697 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2698 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2699 getNumGangsMutable(), segments));
2700
2701 setNumGangsSegments(segments);
2702}
2703
/// Replaces the wait-only attribute with the result of
/// `addDeviceTypeAffectedOperandHelper` applied to the current attribute and
/// `effectiveDeviceTypes`. No operand is involved.
2704void acc::KernelsOp::addWaitOnly(
2705 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2706 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2707 effectiveDeviceTypes));
2708}
/// Adds a group of wait operands for `effectiveDeviceTypes` and records
/// whether that group carries a devnum.
///
/// Updates three pieces of state on the op: the wait-operands device-type
/// attribute, the wait-operands segment sizes, and the has-wait-devnum array.
2709void acc::KernelsOp::addWaitOperands(
2710 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2711 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2712
     // NOTE(review): listing line 2713 is absent from this extract; the
     // declaration of the local `segments` container presumably lives there.
2714 if (getWaitOperandsSegments())
2715 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2716
2717 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2718 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2719 getWaitOperandsMutable(), segments));
2720 setWaitOperandsSegments(segments);
2721
     // NOTE(review): listing line 2722 is absent from this extract; the
     // declaration of the local `hasDevnums` container presumably lives there.
2723 if (getHasWaitDevnumAttr())
2724 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
     // Append one `hasDevnum` flag per effective device type, and at least one
     // even when `effectiveDeviceTypes` is empty.
2725 hasDevnums.insert(
2726 hasDevnums.end(),
2727 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2728 mlir::BoolAttr::get(context, hasDevnum));
2729 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2730}
2731
2732//===----------------------------------------------------------------------===//
2733// HostDataOp
2734//===----------------------------------------------------------------------===//
2735
2736LogicalResult acc::HostDataOp::verify() {
2737 if (getDataClauseOperands().empty())
2738 return emitError("at least one operand must appear on the host_data "
2739 "operation");
2740
2741 for (mlir::Value operand : getDataClauseOperands())
2742 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2743 return emitError("expect data entry operation as defining op");
2744 return success();
2745}
2746
2747void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2748 MLIRContext *context) {
2749 results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2750}
2751
2752//===----------------------------------------------------------------------===//
2753// KernelEnvironmentOp
2754//===----------------------------------------------------------------------===//
2755
2756void acc::KernelEnvironmentOp::getCanonicalizationPatterns(
2757 RewritePatternSet &results, MLIRContext *context) {
2758 results.add<RemoveEmptyKernelEnvironment>(context);
2759}
2760
2761//===----------------------------------------------------------------------===//
2762// LoopOp
2763//===----------------------------------------------------------------------===//
2764
2765static ParseResult parseGangValue(
2766 OpAsmParser &parser, llvm::StringRef keyword,
2769 llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
2770 bool &needCommaBetweenValues, bool &newValue) {
2771 if (succeeded(parser.parseOptionalKeyword(keyword))) {
2772 if (parser.parseEqual())
2773 return failure();
2774 if (parser.parseOperand(operands.emplace_back()) ||
2775 parser.parseColonType(types.emplace_back()))
2776 return failure();
2777 attributes.push_back(gangArgType);
2778 needCommaBetweenValues = true;
2779 newValue = true;
2780 }
2781 return success();
2782}
2783
2784static ParseResult parseGangClause(
2785 OpAsmParser &parser,
2787 llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
2788 mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
2789 mlir::ArrayAttr &gangOnlyDeviceType) {
2790 llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
2791 llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
2792 llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
2794 bool needCommaBetweenValues = false;
2795 bool needCommaBeforeOperands = false;
2796
2797 if (failed(parser.parseOptionalLParen())) {
2798 // Gang only keyword
2799 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2800 parser.getContext(), mlir::acc::DeviceType::None));
2801 gangOnlyDeviceType =
2802 ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
2803 return success();
2804 }
2805
2806 // Parse gang only attributes
2807 if (succeeded(parser.parseOptionalLSquare())) {
2808 // Parse gang only attributes
2809 if (failed(parser.parseCommaSeparatedList([&]() {
2810 if (parser.parseAttribute(
2811 gangOnlyDeviceTypeAttributes.emplace_back()))
2812 return failure();
2813 return success();
2814 })))
2815 return failure();
2816 if (parser.parseRSquare())
2817 return failure();
2818 needCommaBeforeOperands = true;
2819 }
2820
2821 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2822 mlir::acc::GangArgType::Num);
2823 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2824 mlir::acc::GangArgType::Dim);
2825 auto argStatic = mlir::acc::GangArgTypeAttr::get(
2826 parser.getContext(), mlir::acc::GangArgType::Static);
2827
2828 do {
2829 if (needCommaBeforeOperands) {
2830 needCommaBeforeOperands = false;
2831 continue;
2832 }
2833
2834 if (failed(parser.parseLBrace()))
2835 return failure();
2836
2837 int32_t crtOperandsSize = gangOperands.size();
2838 while (true) {
2839 bool newValue = false;
2840 bool needValue = false;
2841 if (needCommaBetweenValues) {
2842 if (succeeded(parser.parseOptionalComma()))
2843 needValue = true; // expect a new value after comma.
2844 else
2845 break;
2846 }
2847
2848 if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
2849 gangOperands, gangOperandsType,
2850 gangArgTypeAttributes, argNum,
2851 needCommaBetweenValues, newValue)))
2852 return failure();
2853 if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
2854 gangOperands, gangOperandsType,
2855 gangArgTypeAttributes, argDim,
2856 needCommaBetweenValues, newValue)))
2857 return failure();
2858 if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
2859 gangOperands, gangOperandsType,
2860 gangArgTypeAttributes, argStatic,
2861 needCommaBetweenValues, newValue)))
2862 return failure();
2863
2864 if (!newValue && needValue) {
2865 parser.emitError(parser.getCurrentLocation(),
2866 "new value expected after comma");
2867 return failure();
2868 }
2869
2870 if (!newValue)
2871 break;
2872 }
2873
2874 if (gangOperands.empty())
2875 return parser.emitError(
2876 parser.getCurrentLocation(),
2877 "expect at least one of num, dim or static values");
2878
2879 if (failed(parser.parseRBrace()))
2880 return failure();
2881
2882 if (succeeded(parser.parseOptionalLSquare())) {
2883 if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
2884 parser.parseRSquare())
2885 return failure();
2886 } else {
2887 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2888 parser.getContext(), mlir::acc::DeviceType::None));
2889 }
2890
2891 seg.push_back(gangOperands.size() - crtOperandsSize);
2892
2893 } while (succeeded(parser.parseOptionalComma()));
2894
2895 if (failed(parser.parseRParen()))
2896 return failure();
2897
2898 llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
2899 gangArgTypeAttributes.end());
2900 gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
2901 deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
2902
2904 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
2905 gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
2906
2907 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2908 return success();
2909}
2910
// Printer counterpart of the gang clause parser.
// NOTE(review): the line that opens this definition (function name and the
// leading parameters, including the printer `p`) is missing from this
// listing; only the trailing parameters are visible below.
2912 mlir::OperandRange operands, mlir::TypeRange types,
2913 std::optional<mlir::ArrayAttr> gangArgTypes,
2914 std::optional<mlir::ArrayAttr> deviceTypes,
2915 std::optional<mlir::DenseI32ArrayAttr> segments,
2916 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
2917
// Nothing is printed when there are no operands and the gang-only list
// satisfies `hasOnlyDeviceTypeNone`.
2918 if (operands.begin() == operands.end() &&
2919 hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
2920 return;
2921 }
2922
2923 p << "(";
2924
2925 printDeviceTypes(p, gangOnlyDeviceTypes);
2926
// Separator between the gang-only list and the operand groups, emitted only
// when both are reported as having values.
2927 if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
2928 hasDeviceTypeValues(deviceTypes))
2929 p << ", ";
2930
2931 if (hasDeviceTypeValues(deviceTypes)) {
// `opIdx` runs over all operands across groups; it indexes both `operands`
// and `gangArgTypes`.
2932 unsigned opIdx = 0;
// One `{...}` group per device-type entry; `segments` gives the number of
// operands in each group.
2933 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2934 p << "{";
2935 llvm::interleaveComma(
2936 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
// NOTE(review): the result of `dyn_cast` is used without a null check.
2937 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2938 (*gangArgTypes)[opIdx]);
2939 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
2940 p << LoopOp::getGangNumKeyword();
2941 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
2942 p << LoopOp::getGangDimKeyword();
2943 else if (gangArgTypeAttr.getValue() ==
2944 mlir::acc::GangArgType::Static)
2945 p << LoopOp::getGangStaticKeyword();
2946 p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
2947 ++opIdx;
2948 });
2949 p << "}";
2950 printSingleDeviceType(p, it.value());
2951 });
2952 }
2953 p << ")";
2954}
2955
2957 std::optional<mlir::ArrayAttr> segments,
2958 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2959 if (!segments)
2960 return false;
2961 for (auto attr : *segments) {
2962 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2963 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2964 return true;
2965 }
2966 return false;
2967}
2968
2969/// Check for duplicates in the DeviceType array attribute.
2970LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
2971 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2972 if (!deviceTypes)
2973 return success();
2974 for (auto attr : deviceTypes) {
2975 auto deviceTypeAttr =
2976 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
2977 if (!deviceTypeAttr)
2978 return failure();
2979 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
2980 return failure();
2981 }
2982 return success();
2983}
2984
// Verifier for `acc.loop`. Checks, in order: loop bound/step counts, the
// collapse attributes, the gang/worker/vector/tile operands and their
// device-type attributes, the auto/independent/seq attributes, the
// private/firstprivate/reduction recipe lists, the combined-construct
// attribute, and finally the shape of the body region.
2985LogicalResult acc::LoopOp::verify() {
2986 if (getUpperbound().size() != getStep().size())
2987 return emitError() << "number of upperbounds expected to be the same as "
2988 "number of steps";
2989
2990 if (getUpperbound().size() != getLowerbound().size())
2991 return emitError() << "number of upperbounds expected to be the same as "
2992 "number of lowerbounds";
2993
2994 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
2995 (getUpperbound().size() != getInclusiveUpperbound()->size()))
2996 return emitError() << "inclusiveUpperbound size is expected to be the same"
2997 << " as upperbound size";
2998
2999 // Check collapse
3000 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3001 return emitOpError() << "collapse device_type attr must be define when"
3002 << " collapse attr is present";
3003
3004 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3005 getCollapseAttr().getValue().size() !=
3006 getCollapseDeviceTypeAttr().getValue().size())
3007 return emitOpError() << "collapse attribute count must match collapse"
3008 << " device_type count";
3009 if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
3010 return emitOpError()
3011 << "duplicate device_type found in collapseDeviceType attribute";
3012
3013 // Check gang
3014 if (!getGangOperands().empty()) {
3015 if (!getGangOperandsArgType())
3016 return emitOpError() << "gangOperandsArgType attribute must be defined"
3017 << " when gang operands are present";
3018
3019 if (getGangOperands().size() !=
3020 getGangOperandsArgTypeAttr().getValue().size())
3021 return emitOpError() << "gangOperandsArgType attribute count must match"
3022 << " gangOperands count";
3023 }
3024 if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
3025 return emitOpError() << "duplicate device_type found in gang attribute";
3026
// NOTE(review): the opening of the next check (the `if` and the helper it
// calls) is missing from this listing; only its arguments are visible.
3028 *this, getGangOperands(), getGangOperandsSegmentsAttr(),
3029 getGangOperandsDeviceTypeAttr(), "gang")))
3030 return failure();
3031
3032 // Check worker
3033 if (failed(checkDeviceTypes(getWorkerAttr())))
3034 return emitOpError() << "duplicate device_type found in worker attribute";
3035 if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
3036 return emitOpError() << "duplicate device_type found in "
3037 "workerNumOperandsDeviceType attribute";
3038 if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
3039 getWorkerNumOperandsDeviceTypeAttr(),
3040 "worker")))
3041 return failure();
3042
3043 // Check vector
3044 if (failed(checkDeviceTypes(getVectorAttr())))
3045 return emitOpError() << "duplicate device_type found in vector attribute";
3046 if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
3047 return emitOpError() << "duplicate device_type found in "
3048 "vectorOperandsDeviceType attribute";
3049 if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
3050 getVectorOperandsDeviceTypeAttr(),
3051 "vector")))
3052 return failure();
3053
// NOTE(review): as for gang above, the opening of this tile check is missing
// from this listing.
3055 *this, getTileOperands(), getTileOperandsSegmentsAttr(),
3056 getTileOperandsDeviceTypeAttr(), "tile")))
3057 return failure();
3058
3059 // auto, independent and seq attribute are mutually exclusive.
// The same set is threaded through all three calls, so a device type listed
// by more than one of the attributes is reported as a duplicate.
3060 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3061 if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
3062 hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
3063 hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
3064 return emitError() << "only one of auto, independent, seq can be present "
3065 "at the same time";
3066 }
3067
3068 // Check that at least one of auto, independent, or seq is present
3069 // for the device-independent default clauses.
3070 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
3071 return attr.getValue() == mlir::acc::DeviceType::None;
3072 };
3073 bool hasDefaultSeq =
3074 getSeqAttr()
3075 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3076 hasDeviceNone)
3077 : false;
3078 bool hasDefaultIndependent =
3079 getIndependentAttr()
3080 ? llvm::any_of(
3081 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3082 hasDeviceNone)
3083 : false;
3084 bool hasDefaultAuto =
3085 getAuto_Attr()
3086 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3087 hasDeviceNone)
3088 : false;
3089 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3090 return emitError()
3091 << "at least one of auto, independent, seq must be present";
3092 }
3093
3094 // Gang, worker and vector are incompatible with seq.
3095 if (getSeqAttr()) {
3096 for (auto attr : getSeqAttr()) {
3097 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3098 if (hasVector(deviceTypeAttr.getValue()) ||
3099 getVectorValue(deviceTypeAttr.getValue()) ||
3100 hasWorker(deviceTypeAttr.getValue()) ||
3101 getWorkerValue(deviceTypeAttr.getValue()) ||
3102 hasGang(deviceTypeAttr.getValue()) ||
3103 getGangValue(mlir::acc::GangArgType::Num,
3104 deviceTypeAttr.getValue()) ||
3105 getGangValue(mlir::acc::GangArgType::Dim,
3106 deviceTypeAttr.getValue()) ||
3107 getGangValue(mlir::acc::GangArgType::Static,
3108 deviceTypeAttr.getValue()))
3109 return emitError() << "gang, worker or vector cannot appear with seq";
3110 }
3111 }
3112
// NOTE(review): the openings of the next three checks (private, firstprivate
// and reduction recipe lists) are missing from this listing; only their
// arguments are visible.
3114 *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
3115 "privatizations", false)))
3116 return failure();
3117
3119 *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
3120 "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
3121 return failure();
3122
3124 *this, getReductionRecipes(), getReductionOperands(), "reduction",
3125 "reductions", false)))
3126 return failure();
3127
// Only the three combined loop constructs are accepted as `combined` values.
3128 if (getCombined().has_value() &&
3129 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3130 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3131 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3132 return emitError("unexpected combined constructs attribute");
3133 }
3134
3135 // Check non-empty body().
3136 if (getRegion().empty())
3137 return emitError("expected non-empty body.");
3138
3139 if (getUnstructured()) {
3140 if (!isContainerLike())
3141 return emitError(
3142 "unstructured acc.loop must not have induction variables");
3143 } else if (isContainerLike()) {
3144 // When it is container-like - it is expected to hold a loop-like operation.
3145 // Obtain the maximum collapse count - we use this to check that there
3146 // are enough loops contained.
3147 uint64_t collapseCount = getCollapseValue().value_or(1);
3148 if (getCollapseAttr()) {
3149 for (auto collapseEntry : getCollapseAttr()) {
3150 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3151 if (intAttr.getValue().getZExtValue() > collapseCount)
3152 collapseCount = intAttr.getValue().getZExtValue();
3153 }
3154 }
3155
3156 // We want to check that we find enough loop-like operations inside.
3157 // PreOrder walk allows us to walk in a breadth-first manner at each nesting
3158 // level.
3159 mlir::Operation *expectedParent = this->getOperation();
3160 bool foundSibling = false;
3161 getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
3162 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3163 // This effectively checks that we are not looking at a sibling loop.
3164 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3165 expectedParent) {
3166 foundSibling = true;
// NOTE(review): the statement that follows `foundSibling = true;` is missing
// from this listing.
3168 }
3169
3170 collapseCount--;
3171 expectedParent = op;
3172 }
3173 // We found enough contained loops.
3174 if (collapseCount == 0)
// NOTE(review): the statements controlled by this `if`, and the lambda's
// final return, are missing from this listing.
3177 });
3178
3179 if (foundSibling)
3180 return emitError("found sibling loops inside container-like acc.loop");
3181 if (collapseCount != 0)
3182 return emitError("failed to find enough loop-like operations inside "
3183 "container-like acc.loop");
3184 }
3185
3186 return success();
3187}
3188
3189unsigned LoopOp::getNumDataOperands() {
3190 return getReductionOperands().size() + getPrivateOperands().size() +
3191 getFirstprivateOperands().size();
3192}
3193
3194Value LoopOp::getDataOperand(unsigned i) {
3195 unsigned numOptional =
3196 getLowerbound().size() + getUpperbound().size() + getStep().size();
3197 numOptional += getGangOperands().size();
3198 numOptional += getVectorOperands().size();
3199 numOptional += getWorkerNumOperands().size();
3200 numOptional += getTileOperands().size();
3201 numOptional += getCacheOperands().size();
3202 return getOperand(numOptional + i);
3203}
3204
3205bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
3206
3207bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3208 return hasDeviceType(getAuto_(), deviceType);
3209}
3210
3211bool LoopOp::hasIndependent() {
3212 return hasIndependent(mlir::acc::DeviceType::None);
3213}
3214
3215bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3216 return hasDeviceType(getIndependent(), deviceType);
3217}
3218
3219bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
3220
3221bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3222 return hasDeviceType(getSeq(), deviceType);
3223}
3224
3225mlir::Value LoopOp::getVectorValue() {
3226 return getVectorValue(mlir::acc::DeviceType::None);
3227}
3228
3229mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3230 return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
3231 getVectorOperands(), deviceType);
3232}
3233
3234bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
3235
3236bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3237 return hasDeviceType(getVector(), deviceType);
3238}
3239
3240mlir::Value LoopOp::getWorkerValue() {
3241 return getWorkerValue(mlir::acc::DeviceType::None);
3242}
3243
3244mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3245 return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
3246 getWorkerNumOperands(), deviceType);
3247}
3248
3249bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
3250
3251bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3252 return hasDeviceType(getWorker(), deviceType);
3253}
3254
3255mlir::Operation::operand_range LoopOp::getTileValues() {
3256 return getTileValues(mlir::acc::DeviceType::None);
3257}
3258
3260LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3261 return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
3262 getTileOperandsSegments(), deviceType);
3263}
3264
3265std::optional<int64_t> LoopOp::getCollapseValue() {
3266 return getCollapseValue(mlir::acc::DeviceType::None);
3267}
3268
3269std::optional<int64_t>
3270LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3271 if (!getCollapseAttr())
3272 return std::nullopt;
3273 if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3274 auto intAttr =
3275 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3276 return intAttr.getValue().getZExtValue();
3277 }
3278 return std::nullopt;
3279}
3280
3281mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3282 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3283}
3284
3285mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3286 mlir::acc::DeviceType deviceType) {
3287 if (getGangOperands().empty())
3288 return {};
3289 if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
3290 int32_t nbOperandsBefore = 0;
3291 for (unsigned i = 0; i < *pos; ++i)
3292 nbOperandsBefore += (*getGangOperandsSegments())[i];
3294 getGangOperands()
3295 .drop_front(nbOperandsBefore)
3296 .take_front((*getGangOperandsSegments())[*pos]);
3297
3298 int32_t argTypeIdx = nbOperandsBefore;
3299 for (auto value : values) {
3300 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3301 (*getGangOperandsArgType())[argTypeIdx]);
3302 if (gangArgTypeAttr.getValue() == gangArgType)
3303 return value;
3304 ++argTypeIdx;
3305 }
3306 }
3307 return {};
3308}
3309
3310bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
3311
3312bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3313 return hasDeviceType(getGang(), deviceType);
3314}
3315
3316llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
3317 return {&getRegion()};
3318}
3319
3320/// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
3321/// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
3322/// `(` ssa-id-and-type-list `)`
3323/// region
///
/// The `control` part is optional; the region is always parsed, with the
/// parsed induction variables (possibly none) as its entry arguments.
// NOTE(review): several lines of this definition are missing from this
// listing: the function name and leading parameters, the operand-list
// parameters paired with each `...Type` parameter, the declaration of
// `inductionVars`, and the trailing argument of each `parseOperandList` call.
3324ParseResult
3327 SmallVectorImpl<Type> &lowerboundType,
3329 SmallVectorImpl<Type> &upperboundType,
3331 SmallVectorImpl<Type> &stepType) {
3332
3334 if (succeeded(
3335 parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
// Each operand list is parsed with `inductionVars.size()` as the required
// operand count, followed by a colon type list.
3336 if (parser.parseLParen() ||
3337 parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
3338 /*allowType=*/true) ||
3339 parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
3340 parser.parseOperandList(lowerbound, inductionVars.size(),
3342 parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
3343 parser.parseKeyword("to") || parser.parseLParen() ||
3344 parser.parseOperandList(upperbound, inductionVars.size(),
3346 parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
3347 parser.parseKeyword("step") || parser.parseLParen() ||
3348 parser.parseOperandList(step, inductionVars.size(),
3350 parser.parseColonTypeList(stepType) || parser.parseRParen())
3351 return failure();
3352 }
3353 return parser.parseRegion(region, inductionVars);
3354}
3355
// Printer counterpart of the loop-control parser: prints the `control(...)`
// header when the region's entry block has arguments, then the region itself.
// NOTE(review): the line that opens this definition (function name and the
// leading parameters, including the printer `p` and `region`) is missing
// from this listing.
3357 ValueRange lowerbound, TypeRange lowerboundType,
3358 ValueRange upperbound, TypeRange upperboundType,
3359 ValueRange steps, TypeRange stepType) {
3360 ValueRange regionArgs = region.front().getArguments();
3361 if (!regionArgs.empty()) {
3362 p << acc::LoopOp::getControlKeyword() << "(";
3363 llvm::interleaveComma(regionArgs, p,
3364 [&p](Value v) { p << v << " : " << v.getType(); });
// Note: ") " followed by " step (" yields two spaces before `step`.
3365 p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
3366 << upperbound << " : " << upperboundType << ") " << " step (" << steps
3367 << " : " << stepType << ") ";
3368 }
// Entry block arguments are not printed with the region; when present they
// were already printed in the `control(...)` header above.
3369 p.printRegion(region, /*printEntryBlockArgs=*/false);
3370}
3371
3372void acc::LoopOp::addSeq(MLIRContext *context,
3373 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3374 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3375 effectiveDeviceTypes));
3376}
3377
3378void acc::LoopOp::addIndependent(
3379 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3380 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3381 context, getIndependentAttr(), effectiveDeviceTypes));
3382}
3383
3384void acc::LoopOp::addAuto(MLIRContext *context,
3385 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3386 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3387 effectiveDeviceTypes));
3388}
3389
3390void acc::LoopOp::setCollapseForDeviceTypes(
3391 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3392 llvm::APInt value) {
3395
3396 assert((getCollapseAttr() == nullptr) ==
3397 (getCollapseDeviceTypeAttr() == nullptr));
3398 assert(value.getBitWidth() == 64);
3399
3400 if (getCollapseAttr()) {
3401 for (const auto &existing :
3402 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3403 newValues.push_back(std::get<0>(existing));
3404 newDeviceTypes.push_back(std::get<1>(existing));
3405 }
3406 }
3407
3408 if (effectiveDeviceTypes.empty()) {
3409 // If the effective device-types list is empty, this is before there are any
3410 // being applied by device_type, so this should be added as a 'none'.
3411 newValues.push_back(
3412 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3413 newDeviceTypes.push_back(
3414 acc::DeviceTypeAttr::get(context, DeviceType::None));
3415 } else {
3416 for (DeviceType dt : effectiveDeviceTypes) {
3417 newValues.push_back(
3418 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3419 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3420 }
3421 }
3422
3423 setCollapseAttr(ArrayAttr::get(context, newValues));
3424 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3425}
3426
3427void acc::LoopOp::setTileForDeviceTypes(
3428 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3429 ValueRange values) {
3431 if (getTileOperandsSegments())
3432 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3433
3434 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3435 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3436 getTileOperandsMutable(), segments));
3437
3438 setTileOperandsSegments(segments);
3439}
3440
3441void acc::LoopOp::addVectorOperand(
3442 MLIRContext *context, mlir::Value newValue,
3443 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3444 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3445 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3446 newValue, getVectorOperandsMutable()));
3447}
3448
3449void acc::LoopOp::addEmptyVector(
3450 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3451 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3452 effectiveDeviceTypes));
3453}
3454
3455void acc::LoopOp::addWorkerNumOperand(
3456 MLIRContext *context, mlir::Value newValue,
3457 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3458 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3459 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3460 newValue, getWorkerNumOperandsMutable()));
3461}
3462
3463void acc::LoopOp::addEmptyWorker(
3464 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3465 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3466 effectiveDeviceTypes));
3467}
3468
3469void acc::LoopOp::addEmptyGang(
3470 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3471 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3472 effectiveDeviceTypes));
3473}
3474
3475bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3476 auto hasDevice = [=](DeviceTypeAttr attr) -> bool {
3477 return attr.getValue() == dt;
3478 };
3479 auto testFromArr = [=](ArrayAttr arr) -> bool {
3480 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3481 };
3482
3483 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3484 return true;
3485 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3486 return true;
3487 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3488 return true;
3489
3490 return false;
3491}
3492
3493bool acc::LoopOp::hasDefaultGangWorkerVector() {
3494 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3495 hasGang() || getGangValue(GangArgType::Num) ||
3496 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3497}
3498
3499acc::LoopParMode
3500acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3501 if (hasSeq(deviceType))
3502 return LoopParMode::loop_seq;
3503 if (hasAuto(deviceType))
3504 return LoopParMode::loop_auto;
3505 if (hasIndependent(deviceType))
3506 return LoopParMode::loop_independent;
3507 if (hasSeq())
3508 return LoopParMode::loop_seq;
3509 if (hasAuto())
3510 return LoopParMode::loop_auto;
3511 assert(hasIndependent() &&
3512 "loop must have default auto, seq, or independent");
3513 return LoopParMode::loop_independent;
3514}
3515
// Adds gang operands for `effectiveDeviceTypes`, updating the gang
// device-type attribute, the segment sizes and the gang arg-type attribute.
// NOTE(review): the end of the parameter list (which declares `values` and
// `argTypes`) and the declaration of `segments` are missing from this
// listing.
3516void acc::LoopOp::addGangOperands(
3517 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3520 if (std::optional<ArrayRef<int32_t>> existingSegments =
3521 getGangOperandsSegments())
3522 llvm::copy(*existingSegments, std::back_inserter(segments));
3523
// Remember how many segments existed so the number added by the helper can
// be computed afterwards.
3524 unsigned beforeCount = segments.size();
3525
3526 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3527 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3528 getGangOperandsMutable(), segments));
3529
3530 setGangOperandsSegments(segments);
3531
3532 // This is a bit of extra work to make sure we update the 'types' correctly by
3533 // adding to the types collection the correct number of times. We could
3534 // potentially add something similar to the
3535 // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
3536 // excessive for a one-off case.
3537 unsigned numAdded = segments.size() - beforeCount;
3538
3539 if (numAdded > 0) {
// NOTE(review): the declaration of `gangTypes` is missing from this listing.
3541 if (getGangOperandsArgTypeAttr())
3542 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3543
// The whole `argTypes` list is appended once per newly added segment.
3544 for (auto i : llvm::index_range(0u, numAdded)) {
3545 llvm::transform(argTypes, std::back_inserter(gangTypes),
3546 [=](mlir::acc::GangArgType gangTy) {
3547 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3548 });
3549 (void)i;
3550 }
3551
3552 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3553 }
3554}
3555
3556void acc::LoopOp::addPrivatization(MLIRContext *context,
3557 mlir::acc::PrivateOp op,
3558 mlir::acc::PrivateRecipeOp recipe) {
3559 getPrivateOperandsMutable().append(op.getResult());
3560
3562
3563 if (getPrivatizationRecipesAttr())
3564 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
3565
3566 recipes.push_back(
3567 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
3568 setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
3569}
3570
3571void acc::LoopOp::addFirstPrivatization(
3572 MLIRContext *context, mlir::acc::FirstprivateOp op,
3573 mlir::acc::FirstprivateRecipeOp recipe) {
3574 getFirstprivateOperandsMutable().append(op.getResult());
3575
3577
3578 if (getFirstprivatizationRecipesAttr())
3579 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
3580
3581 recipes.push_back(
3582 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
3583 setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
3584}
3585
3586void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
3587 mlir::acc::ReductionRecipeOp recipe) {
3588 getReductionOperandsMutable().append(op.getResult());
3589
3591
3592 if (getReductionRecipesAttr())
3593 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
3594
3595 recipes.push_back(
3596 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
3597 setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
3598}
3599
3600//===----------------------------------------------------------------------===//
3601// DataOp
3602//===----------------------------------------------------------------------===//
3603
// Verifier for `acc.data`: requires at least one operand or the default
// attribute, and requires every data clause operand to be the result of one
// of the data entry/exit operations listed below.
3604LogicalResult acc::DataOp::verify() {
3605 // 2.6.5. Data Construct restriction
3606 // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
3607 // attach, or default clause must appear on a data construct.
3608 if (getOperands().empty() && !getDefaultAttr())
3609 return emitError("at least one operand or the default attribute "
3610 "must appear on the data operation");
3611
// Block arguments are rejected first: they have no defining op to inspect.
3612 for (mlir::Value operand : getDataClauseOperands())
3613 if (isa<BlockArgument>(operand) ||
3614 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3615 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3616 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3617 operand.getDefiningOp()))
3618 return emitError("expect data entry/exit operation or acc.getdeviceptr "
3619 "as defining op");
3620
// NOTE(review): the condition guarding the next `return failure();` is
// missing from this listing.
3622 return failure();
3623
3624 return success();
3625}
3626
3627unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
3628
3629Value DataOp::getDataOperand(unsigned i) {
3630 unsigned numOptional = getIfCond() ? 1 : 0;
3631 numOptional += getAsyncOperands().size() ? 1 : 0;
3632 numOptional += getWaitOperands().size();
3633 return getOperand(numOptional + i);
3634}
3635
3636bool acc::DataOp::hasAsyncOnly() {
3637 return hasAsyncOnly(mlir::acc::DeviceType::None);
3638}
3639
3640bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3641 return hasDeviceType(getAsyncOnly(), deviceType);
3642}
3643
3644mlir::Value DataOp::getAsyncValue() {
3645 return getAsyncValue(mlir::acc::DeviceType::None);
3646}
3647
3648mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3650 getAsyncOperands(), deviceType);
3651}
3652
3653bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
3654
3655bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3656 return hasDeviceType(getWaitOnly(), deviceType);
3657}
3658
3659mlir::Operation::operand_range DataOp::getWaitValues() {
3660 return getWaitValues(mlir::acc::DeviceType::None);
3661}
3662
3664DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3666 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3667 getHasWaitDevnum(), deviceType);
3668}
3669
3670mlir::Value DataOp::getWaitDevnum() {
3671 return getWaitDevnum(mlir::acc::DeviceType::None);
3672}
3673
3674mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3675 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3676 getWaitOperandsSegments(), getHasWaitDevnum(),
3677 deviceType);
3678}
3679
3680void acc::DataOp::addAsyncOnly(
3681 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3682 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3683 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3684}
3685
3686void acc::DataOp::addAsyncOperand(
3687 MLIRContext *context, mlir::Value newValue,
3688 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3689 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3690 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3691 getAsyncOperandsMutable()));
3692}
3693
3694void acc::DataOp::addWaitOnly(MLIRContext *context,
3695 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3696 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3697 effectiveDeviceTypes));
3698}
3699
3700void acc::DataOp::addWaitOperands(
3701 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3702 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3703
3705 if (getWaitOperandsSegments())
3706 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3707
3708 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3709 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3710 getWaitOperandsMutable(), segments));
3711 setWaitOperandsSegments(segments);
3712
3714 if (getHasWaitDevnumAttr())
3715 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3716 hasDevnums.insert(
3717 hasDevnums.end(),
3718 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
3719 mlir::BoolAttr::get(context, hasDevnum));
3720 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3721}
3722
3723//===----------------------------------------------------------------------===//
3724// ExitDataOp
3725//===----------------------------------------------------------------------===//
3726
3727LogicalResult acc::ExitDataOp::verify() {
3728 // 2.6.6. Data Exit Directive restriction
3729 // At least one copyout, delete, or detach clause must appear on an exit data
3730 // directive.
3731 if (getDataClauseOperands().empty())
3732 return emitError("at least one operand must be present in dataOperands on "
3733 "the exit data operation");
3734
3735 // The async attribute represent the async clause without value. Therefore the
3736 // attribute and operand cannot appear at the same time.
3737 if (getAsyncOperand() && getAsync())
3738 return emitError("async attribute cannot appear with asyncOperand");
3739
3740 // The wait attribute represent the wait clause without values. Therefore the
3741 // attribute and operands cannot appear at the same time.
3742 if (!getWaitOperands().empty() && getWait())
3743 return emitError("wait attribute cannot appear with waitOperands");
3744
3745 if (getWaitDevnum() && getWaitOperands().empty())
3746 return emitError("wait_devnum cannot appear without waitOperands");
3747
3748 return success();
3749}
3750
3751unsigned ExitDataOp::getNumDataOperands() {
3752 return getDataClauseOperands().size();
3753}
3754
3755Value ExitDataOp::getDataOperand(unsigned i) {
3756 unsigned numOptional = getIfCond() ? 1 : 0;
3757 numOptional += getAsyncOperand() ? 1 : 0;
3758 numOptional += getWaitDevnum() ? 1 : 0;
3759 return getOperand(getWaitOperands().size() + numOptional + i);
3760}
3761
3762void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3763 MLIRContext *context) {
3764 results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
3765}
3766
3767void ExitDataOp::addAsyncOnly(MLIRContext *context,
3768 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3769 assert(effectiveDeviceTypes.empty());
3770 assert(!getAsyncAttr());
3771 assert(!getAsyncOperand());
3772
3773 setAsyncAttr(mlir::UnitAttr::get(context));
3774}
3775
3776void ExitDataOp::addAsyncOperand(
3777 MLIRContext *context, mlir::Value newValue,
3778 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3779 assert(effectiveDeviceTypes.empty());
3780 assert(!getAsyncAttr());
3781 assert(!getAsyncOperand());
3782
3783 getAsyncOperandMutable().append(newValue);
3784}
3785
3786void ExitDataOp::addWaitOnly(MLIRContext *context,
3787 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3788 assert(effectiveDeviceTypes.empty());
3789 assert(!getWaitAttr());
3790 assert(getWaitOperands().empty());
3791 assert(!getWaitDevnum());
3792
3793 setWaitAttr(mlir::UnitAttr::get(context));
3794}
3795
3796void ExitDataOp::addWaitOperands(
3797 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3798 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3799 assert(effectiveDeviceTypes.empty());
3800 assert(!getWaitAttr());
3801 assert(getWaitOperands().empty());
3802 assert(!getWaitDevnum());
3803
3804 // if hasDevnum, the first value is the devnum. The 'rest' go into the
3805 // operands list.
3806 if (hasDevnum) {
3807 getWaitDevnumMutable().append(newValues.front());
3808 newValues = newValues.drop_front();
3809 }
3810
3811 getWaitOperandsMutable().append(newValues);
3812}
3813
3814//===----------------------------------------------------------------------===//
3815// EnterDataOp
3816//===----------------------------------------------------------------------===//
3817
3818LogicalResult acc::EnterDataOp::verify() {
3819 // 2.6.6. Data Enter Directive restriction
3820 // At least one copyin, create, or attach clause must appear on an enter data
3821 // directive.
3822 if (getDataClauseOperands().empty())
3823 return emitError("at least one operand must be present in dataOperands on "
3824 "the enter data operation");
3825
3826 // The async attribute represent the async clause without value. Therefore the
3827 // attribute and operand cannot appear at the same time.
3828 if (getAsyncOperand() && getAsync())
3829 return emitError("async attribute cannot appear with asyncOperand");
3830
3831 // The wait attribute represent the wait clause without values. Therefore the
3832 // attribute and operands cannot appear at the same time.
3833 if (!getWaitOperands().empty() && getWait())
3834 return emitError("wait attribute cannot appear with waitOperands");
3835
3836 if (getWaitDevnum() && getWaitOperands().empty())
3837 return emitError("wait_devnum cannot appear without waitOperands");
3838
3839 for (mlir::Value operand : getDataClauseOperands())
3840 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3841 operand.getDefiningOp()))
3842 return emitError("expect data entry operation as defining op");
3843
3844 return success();
3845}
3846
3847unsigned EnterDataOp::getNumDataOperands() {
3848 return getDataClauseOperands().size();
3849}
3850
3851Value EnterDataOp::getDataOperand(unsigned i) {
3852 unsigned numOptional = getIfCond() ? 1 : 0;
3853 numOptional += getAsyncOperand() ? 1 : 0;
3854 numOptional += getWaitDevnum() ? 1 : 0;
3855 return getOperand(getWaitOperands().size() + numOptional + i);
3856}
3857
3858void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3859 MLIRContext *context) {
3860 results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
3861}
3862
3863void EnterDataOp::addAsyncOnly(
3864 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3865 assert(effectiveDeviceTypes.empty());
3866 assert(!getAsyncAttr());
3867 assert(!getAsyncOperand());
3868
3869 setAsyncAttr(mlir::UnitAttr::get(context));
3870}
3871
3872void EnterDataOp::addAsyncOperand(
3873 MLIRContext *context, mlir::Value newValue,
3874 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3875 assert(effectiveDeviceTypes.empty());
3876 assert(!getAsyncAttr());
3877 assert(!getAsyncOperand());
3878
3879 getAsyncOperandMutable().append(newValue);
3880}
3881
3882void EnterDataOp::addWaitOnly(MLIRContext *context,
3883 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3884 assert(effectiveDeviceTypes.empty());
3885 assert(!getWaitAttr());
3886 assert(getWaitOperands().empty());
3887 assert(!getWaitDevnum());
3888
3889 setWaitAttr(mlir::UnitAttr::get(context));
3890}
3891
3892void EnterDataOp::addWaitOperands(
3893 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3894 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3895 assert(effectiveDeviceTypes.empty());
3896 assert(!getWaitAttr());
3897 assert(getWaitOperands().empty());
3898 assert(!getWaitDevnum());
3899
3900 // if hasDevnum, the first value is the devnum. The 'rest' go into the
3901 // operands list.
3902 if (hasDevnum) {
3903 getWaitDevnumMutable().append(newValues.front());
3904 newValues = newValues.drop_front();
3905 }
3906
3907 getWaitOperandsMutable().append(newValues);
3908}
3909
3910//===----------------------------------------------------------------------===//
3911// AtomicReadOp
3912//===----------------------------------------------------------------------===//
3913
3914LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
3915
3916//===----------------------------------------------------------------------===//
3917// AtomicWriteOp
3918//===----------------------------------------------------------------------===//
3919
3920LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
3921
3922//===----------------------------------------------------------------------===//
3923// AtomicUpdateOp
3924//===----------------------------------------------------------------------===//
3925
3926LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3927 PatternRewriter &rewriter) {
3928 if (op.isNoOp()) {
3929 rewriter.eraseOp(op);
3930 return success();
3931 }
3932
3933 if (Value writeVal = op.getWriteOpVal()) {
3934 rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
3935 op.getIfCond());
3936 return success();
3937 }
3938
3939 return failure();
3940}
3941
3942LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
3943
3944LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3945
3946//===----------------------------------------------------------------------===//
3947// AtomicCaptureOp
3948//===----------------------------------------------------------------------===//
3949
3950AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3951 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3952 return op;
3953 return dyn_cast<AtomicReadOp>(getSecondOp());
3954}
3955
3956AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3957 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3958 return op;
3959 return dyn_cast<AtomicWriteOp>(getSecondOp());
3960}
3961
3962AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3963 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3964 return op;
3965 return dyn_cast<AtomicUpdateOp>(getSecondOp());
3966}
3967
3968LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
3969
3970//===----------------------------------------------------------------------===//
3971// DeclareEnterOp
3972//===----------------------------------------------------------------------===//
3973
3974template <typename Op>
3975static LogicalResult
3977 bool requireAtLeastOneOperand = true) {
3978 if (operands.empty() && requireAtLeastOneOperand)
3979 return emitError(
3980 op->getLoc(),
3981 "at least one operand must appear on the declare operation");
3982
3983 for (mlir::Value operand : operands) {
3984 if (isa<BlockArgument>(operand) ||
3985 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3986 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
3987 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
3988 operand.getDefiningOp()))
3989 return op.emitError(
3990 "expect valid declare data entry operation or acc.getdeviceptr "
3991 "as defining op");
3992
3993 mlir::Value var{getVar(operand.getDefiningOp())};
3994 assert(var && "declare operands can only be data entry operations which "
3995 "must have var");
3996 (void)var;
3997 std::optional<mlir::acc::DataClause> dataClauseOptional{
3998 getDataClause(operand.getDefiningOp())};
3999 assert(dataClauseOptional.has_value() &&
4000 "declare operands can only be data entry operations which must have "
4001 "dataClause");
4002 (void)dataClauseOptional;
4003 }
4004
4005 return success();
4006}
4007
4008LogicalResult acc::DeclareEnterOp::verify() {
4009 return checkDeclareOperands(*this, this->getDataClauseOperands());
4010}
4011
4012//===----------------------------------------------------------------------===//
4013// DeclareExitOp
4014//===----------------------------------------------------------------------===//
4015
4016LogicalResult acc::DeclareExitOp::verify() {
4017 if (getToken())
4018 return checkDeclareOperands(*this, this->getDataClauseOperands(),
4019 /*requireAtLeastOneOperand=*/false);
4020 return checkDeclareOperands(*this, this->getDataClauseOperands());
4021}
4022
4023//===----------------------------------------------------------------------===//
4024// DeclareOp
4025//===----------------------------------------------------------------------===//
4026
4027LogicalResult acc::DeclareOp::verify() {
4028 return checkDeclareOperands(*this, this->getDataClauseOperands());
4029}
4030
4031//===----------------------------------------------------------------------===//
4032// RoutineOp
4033//===----------------------------------------------------------------------===//
4034
4035static unsigned getParallelismForDeviceType(acc::RoutineOp op,
4036 acc::DeviceType dtype) {
4037 unsigned parallelism = 0;
4038 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4039 parallelism += op.hasWorker(dtype) ? 1 : 0;
4040 parallelism += op.hasVector(dtype) ? 1 : 0;
4041 parallelism += op.hasSeq(dtype) ? 1 : 0;
4042 return parallelism;
4043}
4044
4045LogicalResult acc::RoutineOp::verify() {
4046 unsigned baseParallelism =
4047 getParallelismForDeviceType(*this, acc::DeviceType::None);
4048
4049 if (baseParallelism > 1)
4050 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4051 "be present at the same time";
4052
4053 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4054 ++dtypeInt) {
4055 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
4056 if (dtype == acc::DeviceType::None)
4057 continue;
4058 unsigned parallelism = getParallelismForDeviceType(*this, dtype);
4059
4060 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4061 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4062 "be present at the same time";
4063 }
4064
4065 return success();
4066}
4067
4068static ParseResult parseBindName(OpAsmParser &parser,
4069 mlir::ArrayAttr &bindIdName,
4070 mlir::ArrayAttr &bindStrName,
4071 mlir::ArrayAttr &deviceIdTypes,
4072 mlir::ArrayAttr &deviceStrTypes) {
4073 llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
4074 llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
4075 llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
4076 llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
4077
4078 if (failed(parser.parseCommaSeparatedList([&]() {
4079 mlir::Attribute newAttr;
4080 bool isSymbolRefAttr;
4081 auto parseResult = parser.parseAttribute(newAttr);
4082 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4083 bindIdNameAttrs.push_back(symbolRefAttr);
4084 isSymbolRefAttr = true;
4085 } else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4086 bindStrNameAttrs.push_back(stringAttr);
4087 isSymbolRefAttr = false;
4088 }
4089 if (parseResult)
4090 return failure();
4091 if (failed(parser.parseOptionalLSquare())) {
4092 if (isSymbolRefAttr) {
4093 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4094 parser.getContext(), mlir::acc::DeviceType::None));
4095 } else {
4096 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4097 parser.getContext(), mlir::acc::DeviceType::None));
4098 }
4099 } else {
4100 if (isSymbolRefAttr) {
4101 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4102 parser.parseRSquare())
4103 return failure();
4104 } else {
4105 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4106 parser.parseRSquare())
4107 return failure();
4108 }
4109 }
4110 return success();
4111 })))
4112 return failure();
4113
4114 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4115 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4116 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4117 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4118
4119 return success();
4120}
4121
4123 std::optional<mlir::ArrayAttr> bindIdName,
4124 std::optional<mlir::ArrayAttr> bindStrName,
4125 std::optional<mlir::ArrayAttr> deviceIdTypes,
4126 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4127 // Create combined vectors for all bind names and device types
4130
4131 // Append bindIdName and deviceIdTypes
4132 if (hasDeviceTypeValues(deviceIdTypes)) {
4133 allBindNames.append(bindIdName->begin(), bindIdName->end());
4134 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4135 }
4136
4137 // Append bindStrName and deviceStrTypes
4138 if (hasDeviceTypeValues(deviceStrTypes)) {
4139 allBindNames.append(bindStrName->begin(), bindStrName->end());
4140 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4141 }
4142
4143 // Print the combined sequence
4144 if (!allBindNames.empty())
4145 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4146 [&](const auto &pair) {
4147 p << std::get<0>(pair);
4148 printSingleDeviceType(p, std::get<1>(pair));
4149 });
4150}
4151
4152static ParseResult parseRoutineGangClause(OpAsmParser &parser,
4153 mlir::ArrayAttr &gang,
4154 mlir::ArrayAttr &gangDim,
4155 mlir::ArrayAttr &gangDimDeviceTypes) {
4156
4157 llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
4158 gangDimDeviceTypeAttrs;
4159 bool needCommaBeforeOperands = false;
4160
4161 // Gang keyword only
4162 if (failed(parser.parseOptionalLParen())) {
4163 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4164 parser.getContext(), mlir::acc::DeviceType::None));
4165 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4166 return success();
4167 }
4168
4169 // Parse keyword only attributes
4170 if (succeeded(parser.parseOptionalLSquare())) {
4171 if (failed(parser.parseCommaSeparatedList([&]() {
4172 if (parser.parseAttribute(gangAttrs.emplace_back()))
4173 return failure();
4174 return success();
4175 })))
4176 return failure();
4177 if (parser.parseRSquare())
4178 return failure();
4179 needCommaBeforeOperands = true;
4180 }
4181
4182 if (needCommaBeforeOperands && failed(parser.parseComma()))
4183 return failure();
4184
4185 if (failed(parser.parseCommaSeparatedList([&]() {
4186 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4187 parser.parseColon() ||
4188 parser.parseAttribute(gangDimAttrs.emplace_back()))
4189 return failure();
4190 if (succeeded(parser.parseOptionalLSquare())) {
4191 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4192 parser.parseRSquare())
4193 return failure();
4194 } else {
4195 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4196 parser.getContext(), mlir::acc::DeviceType::None));
4197 }
4198 return success();
4199 })))
4200 return failure();
4201
4202 if (failed(parser.parseRParen()))
4203 return failure();
4204
4205 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4206 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4207 gangDimDeviceTypes =
4208 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4209
4210 return success();
4211}
4212
4214 std::optional<mlir::ArrayAttr> gang,
4215 std::optional<mlir::ArrayAttr> gangDim,
4216 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4217
4218 if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
4219 gang->size() == 1) {
4220 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4221 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4222 return;
4223 }
4224
4225 p << "(";
4226
4227 printDeviceTypes(p, gang);
4228
4229 if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
4230 p << ", ";
4231
4232 if (hasDeviceTypeValues(gangDimDeviceTypes))
4233 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4234 [&](const auto &pair) {
4235 p << acc::RoutineOp::getGangDimKeyword() << ": ";
4236 p << std::get<0>(pair);
4237 printSingleDeviceType(p, std::get<1>(pair));
4238 });
4239
4240 p << ")";
4241}
4242
4243static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
4244 mlir::ArrayAttr &deviceTypes) {
4246 // Keyword only
4247 if (failed(parser.parseOptionalLParen())) {
4248 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4249 parser.getContext(), mlir::acc::DeviceType::None));
4250 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4251 return success();
4252 }
4253
4254 // Parse device type attributes
4255 if (succeeded(parser.parseOptionalLSquare())) {
4256 if (failed(parser.parseCommaSeparatedList([&]() {
4257 if (parser.parseAttribute(attributes.emplace_back()))
4258 return failure();
4259 return success();
4260 })))
4261 return failure();
4262 if (parser.parseRSquare() || parser.parseRParen())
4263 return failure();
4264 }
4265 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4266 return success();
4267}
4268
4269static void
4271 std::optional<mlir::ArrayAttr> deviceTypes) {
4272
4273 if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
4274 auto deviceTypeAttr =
4275 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4276 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4277 return;
4278 }
4279
4280 if (!hasDeviceTypeValues(deviceTypes))
4281 return;
4282
4283 p << "([";
4284 llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
4285 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4286 p << dTypeAttr;
4287 });
4288 p << "])";
4289}
4290
4291bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
4292
4293bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4294 return hasDeviceType(getWorker(), deviceType);
4295}
4296
4297bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
4298
4299bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4300 return hasDeviceType(getVector(), deviceType);
4301}
4302
4303bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
4304
4305bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4306 return hasDeviceType(getSeq(), deviceType);
4307}
4308
4309std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4310RoutineOp::getBindNameValue() {
4311 return getBindNameValue(mlir::acc::DeviceType::None);
4312}
4313
4314std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4315RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4316 if (!hasDeviceTypeValues(getBindIdNameDeviceType()) &&
4317 !hasDeviceTypeValues(getBindStrNameDeviceType())) {
4318 return std::nullopt;
4319 }
4320
4321 if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
4322 auto attr = (*getBindIdName())[*pos];
4323 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4324 assert(symbolRefAttr && "expected SymbolRef");
4325 return symbolRefAttr;
4326 }
4327
4328 if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
4329 auto attr = (*getBindStrName())[*pos];
4330 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4331 assert(stringAttr && "expected String");
4332 return stringAttr;
4333 }
4334
4335 return std::nullopt;
4336}
4337
4338bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
4339
4340bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4341 return hasDeviceType(getGang(), deviceType);
4342}
4343
4344std::optional<int64_t> RoutineOp::getGangDimValue() {
4345 return getGangDimValue(mlir::acc::DeviceType::None);
4346}
4347
4348std::optional<int64_t>
4349RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4350 if (!hasDeviceTypeValues(getGangDimDeviceType()))
4351 return std::nullopt;
4352 if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
4353 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4354 return intAttr.getInt();
4355 }
4356 return std::nullopt;
4357}
4358
4359//===----------------------------------------------------------------------===//
4360// InitOp
4361//===----------------------------------------------------------------------===//
4362
4363LogicalResult acc::InitOp::verify() {
4364 Operation *currOp = *this;
4365 while ((currOp = currOp->getParentOp()))
4366 if (isComputeOperation(currOp))
4367 return emitOpError("cannot be nested in a compute operation");
4368 return success();
4369}
4370
4371void acc::InitOp::addDeviceType(MLIRContext *context,
4372 mlir::acc::DeviceType deviceType) {
4374 if (getDeviceTypesAttr())
4375 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4376
4377 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4378 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4379}
4380
4381//===----------------------------------------------------------------------===//
4382// ShutdownOp
4383//===----------------------------------------------------------------------===//
4384
4385LogicalResult acc::ShutdownOp::verify() {
4386 Operation *currOp = *this;
4387 while ((currOp = currOp->getParentOp()))
4388 if (isComputeOperation(currOp))
4389 return emitOpError("cannot be nested in a compute operation");
4390 return success();
4391}
4392
4393void acc::ShutdownOp::addDeviceType(MLIRContext *context,
4394 mlir::acc::DeviceType deviceType) {
4396 if (getDeviceTypesAttr())
4397 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4398
4399 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4400 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4401}
4402
4403//===----------------------------------------------------------------------===//
4404// SetOp
4405//===----------------------------------------------------------------------===//
4406
4407LogicalResult acc::SetOp::verify() {
4408 Operation *currOp = *this;
4409 while ((currOp = currOp->getParentOp()))
4410 if (isComputeOperation(currOp))
4411 return emitOpError("cannot be nested in a compute operation");
4412 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4413 return emitOpError("at least one default_async, device_num, or device_type "
4414 "operand must appear");
4415 return success();
4416}
4417
4418//===----------------------------------------------------------------------===//
4419// UpdateOp
4420//===----------------------------------------------------------------------===//
4421
4422LogicalResult acc::UpdateOp::verify() {
4423 // At least one of host or device should have a value.
4424 if (getDataClauseOperands().empty())
4425 return emitError("at least one value must be present in dataOperands");
4426
4428 getAsyncOperandsDeviceTypeAttr(),
4429 "async")))
4430 return failure();
4431
4433 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4434 getWaitOperandsDeviceTypeAttr(), "wait")))
4435 return failure();
4436
4438 return failure();
4439
4440 for (mlir::Value operand : getDataClauseOperands())
4441 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4442 operand.getDefiningOp()))
4443 return emitError("expect data entry/exit operation or acc.getdeviceptr "
4444 "as defining op");
4445
4446 return success();
4447}
4448
4449unsigned UpdateOp::getNumDataOperands() {
4450 return getDataClauseOperands().size();
4451}
4452
4453Value UpdateOp::getDataOperand(unsigned i) {
4454 unsigned numOptional = getAsyncOperands().size();
4455 numOptional += getIfCond() ? 1 : 0;
4456 return getOperand(getWaitOperands().size() + numOptional + i);
4457}
4458
4459void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
4460 MLIRContext *context) {
4461 results.add<RemoveConstantIfCondition<UpdateOp>>(context);
4462}
4463
4464bool UpdateOp::hasAsyncOnly() {
4465 return hasAsyncOnly(mlir::acc::DeviceType::None);
4466}
4467
4468bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4469 return hasDeviceType(getAsyncOnly(), deviceType);
4470}
4471
4472mlir::Value UpdateOp::getAsyncValue() {
4473 return getAsyncValue(mlir::acc::DeviceType::None);
4474}
4475
4476mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4478 return {};
4479
4480 if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
4481 return getAsyncOperands()[*pos];
4482
4483 return {};
4484}
4485
4486bool UpdateOp::hasWaitOnly() {
4487 return hasWaitOnly(mlir::acc::DeviceType::None);
4488}
4489
4490bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4491 return hasDeviceType(getWaitOnly(), deviceType);
4492}
4493
4494mlir::Operation::operand_range UpdateOp::getWaitValues() {
4495 return getWaitValues(mlir::acc::DeviceType::None);
4496}
4497
4499UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4501 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4502 getHasWaitDevnum(), deviceType);
4503}
4504
4505mlir::Value UpdateOp::getWaitDevnum() {
4506 return getWaitDevnum(mlir::acc::DeviceType::None);
4507}
4508
4509mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4510 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
4511 getWaitOperandsSegments(), getHasWaitDevnum(),
4512 deviceType);
4513}
4514
4515void UpdateOp::addAsyncOnly(MLIRContext *context,
4516 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4517 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4518 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4519}
4520
4521void UpdateOp::addAsyncOperand(
4522 MLIRContext *context, mlir::Value newValue,
4523 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4524 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4525 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4526 getAsyncOperandsMutable()));
4527}
4528
4529void UpdateOp::addWaitOnly(MLIRContext *context,
4530 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4531 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4532 effectiveDeviceTypes));
4533}
4534
4535void UpdateOp::addWaitOperands(
4536 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4537 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4538
4540 if (getWaitOperandsSegments())
4541 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4542
4543 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4544 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4545 getWaitOperandsMutable(), segments));
4546 setWaitOperandsSegments(segments);
4547
4549 if (getHasWaitDevnumAttr())
4550 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4551 hasDevnums.insert(
4552 hasDevnums.end(),
4553 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
4554 mlir::BoolAttr::get(context, hasDevnum));
4555 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4556}
4557
4558//===----------------------------------------------------------------------===//
4559// WaitOp
4560//===----------------------------------------------------------------------===//
4561
4562LogicalResult acc::WaitOp::verify() {
4563 // The async attribute represent the async clause without value. Therefore the
4564 // attribute and operand cannot appear at the same time.
4565 if (getAsyncOperand() && getAsync())
4566 return emitError("async attribute cannot appear with asyncOperand");
4567
4568 if (getWaitDevnum() && getWaitOperands().empty())
4569 return emitError("wait_devnum cannot appear without waitOperands");
4570
4571 return success();
4572}
4573
4574#define GET_OP_CLASSES
4575#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
4576
4577#define GET_ATTRDEF_CLASSES
4578#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
4579
4580#define GET_TYPEDEF_CLASSES
4581#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
4582
4583//===----------------------------------------------------------------------===//
4584// acc dialect utilities
4585//===----------------------------------------------------------------------===//
4586
4589 auto varPtr{llvm::TypeSwitch<mlir::Operation *,
4591 accDataClauseOp)
4592 .Case<ACC_DATA_ENTRY_OPS>(
4593 [&](auto entry) { return entry.getVarPtr(); })
4594 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4595 [&](auto exit) { return exit.getVarPtr(); })
4596 .Default([&](mlir::Operation *) {
4598 })};
4599 return varPtr;
4600}
4601
4603 auto varPtr{
4605 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
4606 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4607 return varPtr;
4608}
4609
4611 auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
4612 .Case<ACC_DATA_ENTRY_OPS>(
4613 [&](auto entry) { return entry.getVarType(); })
4614 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4615 [&](auto exit) { return exit.getVarType(); })
4616 .Default([&](mlir::Operation *) { return mlir::Type(); })};
4617 return varType;
4618}
4619
4622 auto accPtr{llvm::TypeSwitch<mlir::Operation *,
4624 accDataClauseOp)
4625 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4626 [&](auto dataClause) { return dataClause.getAccPtr(); })
4627 .Default([&](mlir::Operation *) {
4629 })};
4630 return accPtr;
4631}
4632
4634 auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
4636 [&](auto dataClause) { return dataClause.getAccVar(); })
4637 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4638 return accPtr;
4639}
4640
4642 auto varPtrPtr{
4644 .Case<ACC_DATA_ENTRY_OPS>(
4645 [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
4646 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4647 return varPtrPtr;
4648}
4649
4654 accDataClauseOp)
4655 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4657 dataClause.getBounds().begin(), dataClause.getBounds().end());
4658 })
4659 .Default([&](mlir::Operation *) {
4661 })};
4662 return bounds;
4663}
4664
4668 accDataClauseOp)
4669 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4671 dataClause.getAsyncOperands().begin(),
4672 dataClause.getAsyncOperands().end());
4673 })
4674 .Default([&](mlir::Operation *) {
4676 });
4677}
4678
4679mlir::ArrayAttr
4682 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4683 return dataClause.getAsyncOperandsDeviceTypeAttr();
4684 })
4685 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4686}
4687
4688mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
4691 [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
4692 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4693}
4694
4695std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
4696 auto name{
4698 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
4699 .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
4700 return {};
4701 })};
4702 return name;
4703}
4704
4705std::optional<mlir::acc::DataClause>
4707 auto dataClause{
4709 accDataEntryOp)
4710 .Case<ACC_DATA_ENTRY_OPS>(
4711 [&](auto entry) { return entry.getDataClause(); })
4712 .Default([&](mlir::Operation *) { return std::nullopt; })};
4713 return dataClause;
4714}
4715
4717 auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
4718 .Case<ACC_DATA_ENTRY_OPS>(
4719 [&](auto entry) { return entry.getImplicit(); })
4720 .Default([&](mlir::Operation *) { return false; })};
4721 return implicit;
4722}
4723
4725 auto dataOperands{
4728 [&](auto entry) { return entry.getDataClauseOperands(); })
4729 .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
4730 return dataOperands;
4731}
4732
4735 auto dataOperands{
4738 [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
4739 .Default([&](mlir::Operation *) { return nullptr; })};
4740 return dataOperands;
4741}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition SCF.cpp:136
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
ArrayAttr()
if(!isCopyOut)
b getContext())
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition OpenACC.cpp:4213
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition OpenACC.cpp:961
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition OpenACC.cpp:2956
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition OpenACC.cpp:1545
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
Definition OpenACC.cpp:4068
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition OpenACC.cpp:2970
static bool isComputeOperation(Operation *op)
Definition OpenACC.cpp:975
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:385
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition OpenACC.cpp:2067
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition OpenACC.cpp:516
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:369
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition OpenACC.cpp:485
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition OpenACC.cpp:2078
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition OpenACC.cpp:1983
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition OpenACC.cpp:311
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:4270
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition OpenACC.cpp:2765
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition OpenACC.cpp:2318
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition OpenACC.cpp:3976
static LogicalResult checkVarAndAccVar(Op op)
Definition OpenACC.cpp:445
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2272
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:329
static LogicalResult checkVarAndVarType(Op op)
Definition OpenACC.cpp:427
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
Definition OpenACC.cpp:461
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition OpenACC.cpp:3325
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition OpenACC.cpp:1474
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:2114
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:1627
static LogicalResult checkNoModifier(Op op)
Definition OpenACC.cpp:453
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition OpenACC.cpp:494
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:340
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:353
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:1853
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition OpenACC.cpp:470
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition OpenACC.cpp:3356
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:4243
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition OpenACC.cpp:4152
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:1966
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:2141
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
Definition OpenACC.cpp:2257
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:1920
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
Definition OpenACC.cpp:1458
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2233
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition OpenACC.cpp:557
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition OpenACC.cpp:2784
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition OpenACC.cpp:1242
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
Definition OpenACC.cpp:2302
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition OpenACC.cpp:1897
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
Definition OpenACC.cpp:1489
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition OpenACC.cpp:2214
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:315
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition OpenACC.cpp:2911
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition OpenACC.cpp:2152
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition OpenACC.cpp:528
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition OpenACC.cpp:405
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition OpenACC.cpp:1555
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition OpenACC.cpp:4035
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:1903
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition OpenACC.cpp:2338
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
Definition OpenACC.cpp:1438
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
Definition OpenACC.cpp:4122
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition OpenACC.h:68
#define ACC_DATA_ENTRY_OPS
Definition OpenACC.h:45
#define ACC_DATA_EXIT_OPS
Definition OpenACC.h:53
false
Parses a map_entries map type from a string format back into its numeric value.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:148
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:160
Operation & front()
Definition Block.h:153
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgListType getArguments()
Definition Block.h:87
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
Definition Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperandRange operand_range
Definition Operation.h:371
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:120
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:4633
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:4602
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:4621
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition OpenACC.cpp:4706
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition OpenACC.cpp:4734
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition OpenACC.cpp:4651
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition OpenACC.cpp:4724
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition OpenACC.cpp:4695
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition OpenACC.cpp:4716
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition OpenACC.cpp:4666
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition OpenACC.cpp:4641
static constexpr StringLiteral getVarNameAttrName()
Definition OpenACC.h:180
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:4688
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation which records the type of variable.
Definition OpenACC.cpp:4610
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:4588
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:4680
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.