MLIR 22.0.0git
OpenACC.cpp
Go to the documentation of this file.
1//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// =============================================================================
8
14#include "mlir/IR/Builders.h"
18#include "mlir/IR/Matchers.h"
20#include "mlir/IR/SymbolTable.h"
21#include "mlir/Support/LLVM.h"
23#include "llvm/ADT/SmallSet.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/LogicalResult.h"
26#include <variant>
27
28using namespace mlir;
29using namespace acc;
30
31#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
32#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
33#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
35#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
36
37namespace {
38
39static bool isScalarLikeType(Type type) {
40 return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
41}
42
43/// Helper function to attach the `VarName` attribute to an operation
44/// if a variable name is provided.
45static void attachVarNameAttr(Operation *op, OpBuilder &builder,
46 StringRef varName) {
47 if (!varName.empty()) {
48 auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName);
49 op->setAttr(acc::getVarNameAttrName(), varNameAttr);
50 }
51}
52
53template <typename T>
54struct MemRefPointerLikeModel
55 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
56 Type getElementType(Type pointer) const {
57 return cast<T>(pointer).getElementType();
58 }
59
60 mlir::acc::VariableTypeCategory
61 getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
62 Type varType) const {
63 if (auto mappableTy = dyn_cast<MappableType>(varType)) {
64 return mappableTy.getTypeCategory(varPtr);
65 }
66 auto memrefTy = cast<T>(pointer);
67 if (!memrefTy.hasRank()) {
68 // This memref is unranked - aka it could have any rank, including a
69 // rank of 0 which could mean scalar. For now, return uncategorized.
70 return mlir::acc::VariableTypeCategory::uncategorized;
71 }
72
73 if (memrefTy.getRank() == 0) {
74 if (isScalarLikeType(memrefTy.getElementType())) {
75 return mlir::acc::VariableTypeCategory::scalar;
76 }
77 // Zero-rank non-scalar - need further analysis to determine the type
78 // category. For now, return uncategorized.
79 return mlir::acc::VariableTypeCategory::uncategorized;
80 }
81
82 // It has a rank - must be an array.
83 assert(memrefTy.getRank() > 0 && "rank expected to be positive");
84 return mlir::acc::VariableTypeCategory::array;
85 }
86
87 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
88 StringRef varName, Type varType, Value originalVar,
89 bool &needsFree) const {
90 auto memrefTy = cast<MemRefType>(pointer);
91
92 // Check if this is a static memref (all dimensions are known) - if yes
93 // then we can generate an alloca operation.
94 if (memrefTy.hasStaticShape()) {
95 needsFree = false; // alloca doesn't need deallocation
96 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
97 attachVarNameAttr(allocaOp, builder, varName);
98 return allocaOp.getResult();
99 }
100
101 // For dynamic memrefs, extract sizes from the original variable if
102 // provided. Otherwise they cannot be handled.
103 if (originalVar && originalVar.getType() == memrefTy &&
104 memrefTy.hasRank()) {
105 SmallVector<Value> dynamicSizes;
106 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
107 if (memrefTy.isDynamicDim(i)) {
108 // Extract the size of dimension i from the original variable
109 auto indexValue = arith::ConstantIndexOp::create(builder, loc, i);
110 auto dimSize =
111 memref::DimOp::create(builder, loc, originalVar, indexValue);
112 dynamicSizes.push_back(dimSize);
113 }
114 // Note: We only add dynamic sizes to the dynamicSizes array
115 // Static dimensions are handled automatically by AllocOp
116 }
117 needsFree = true; // alloc needs deallocation
118 auto allocOp =
119 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
120 attachVarNameAttr(allocOp, builder, varName);
121 return allocOp.getResult();
122 }
123
124 // TODO: Unranked not yet supported.
125 return {};
126 }
127
128 bool genFree(Type pointer, OpBuilder &builder, Location loc,
129 TypedValue<PointerLikeType> varToFree, Value allocRes,
130 Type varType) const {
131 if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
132 // Use allocRes if provided to determine the allocation type
133 Value valueToInspect = allocRes ? allocRes : memrefValue;
134
135 // Walk through casts to find the original allocation
136 Value currentValue = valueToInspect;
137 Operation *originalAlloc = nullptr;
138
139 // Follow the chain of operations to find the original allocation
140 // even if a casted result is provided.
141 while (currentValue) {
142 if (auto *definingOp = currentValue.getDefiningOp()) {
143 // Check if this is an allocation operation
144 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
145 originalAlloc = definingOp;
146 break;
147 }
148
149 // Check if this is a cast operation we can look through
150 if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
151 currentValue = castOp.getSource();
152 continue;
153 }
154
155 // Check for other cast-like operations
156 if (auto reinterpretCastOp =
157 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
158 currentValue = reinterpretCastOp.getSource();
159 continue;
160 }
161
162 // If we can't look through this operation, stop
163 break;
164 }
165 // This is a block argument or similar - can't trace further.
166 break;
167 }
168
169 if (originalAlloc) {
170 if (isa<memref::AllocaOp>(originalAlloc)) {
171 // This is an alloca - no dealloc needed, but return true (success)
172 return true;
173 }
174 if (isa<memref::AllocOp>(originalAlloc)) {
175 // This is an alloc - generate dealloc on varToFree
176 memref::DeallocOp::create(builder, loc, memrefValue);
177 return true;
178 }
179 }
180 }
181
182 return false;
183 }
184
185 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
186 TypedValue<PointerLikeType> destination,
187 TypedValue<PointerLikeType> source, Type varType) const {
188 // Generate a copy operation between two memrefs
189 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
190 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
191
192 // As per memref documentation, source and destination must have same
193 // element type and shape in order to be compatible. We do not want to fail
194 // with an IR verification error - thus check that before generating the
195 // copy operation.
196 if (destMemref && srcMemref &&
197 destMemref.getType().getElementType() ==
198 srcMemref.getType().getElementType() &&
199 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
200 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
201 return true;
202 }
203
204 return false;
205 }
206};
207
208struct LLVMPointerPointerLikeModel
209 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
210 LLVM::LLVMPointerType> {
211 Type getElementType(Type pointer) const { return Type(); }
212};
213
214struct MemrefAddressOfGlobalModel
215 : public AddressOfGlobalOpInterface::ExternalModel<
216 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
217 SymbolRefAttr getSymbol(Operation *op) const {
218 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
219 return getGlobalOp.getNameAttr();
220 }
221};
222
223struct MemrefGlobalVariableModel
224 : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
225 memref::GlobalOp> {
226 bool isConstant(Operation *op) const {
227 auto globalOp = cast<memref::GlobalOp>(op);
228 return globalOp.getConstant();
229 }
230};
231
232/// Helper function for any of the times we need to modify an ArrayAttr based on
233/// a device type list. Returns a new ArrayAttr with all of the
234/// existingDeviceTypes, plus the effective new ones(or an added none if hte new
235/// list is empty).
236mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
237 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
238 llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
240 if (existingDeviceTypes)
241 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
242
243 if (newDeviceTypes.empty())
244 deviceTypes.push_back(
245 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
246
247 for (DeviceType dt : newDeviceTypes)
248 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
249
250 return mlir::ArrayAttr::get(context, deviceTypes);
251}
252
253/// Helper function for any of the times we need to add operands that are
254/// affected by a device type list. Returns a new ArrayAttr with all of the
255/// existingDeviceTypes, plus the effective new ones (or an added none, if the
256/// new list is empty). Additionally, adds the arguments to the argCollection
257/// the correct number of times. This will also update a 'segments' array, even
258/// if it won't be used.
259mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
260 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
261 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
262 mlir::MutableOperandRange argCollection,
263 llvm::SmallVector<int32_t> &segments) {
265 if (existingDeviceTypes)
266 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
267
268 if (newDeviceTypes.empty()) {
269 argCollection.append(arguments);
270 segments.push_back(arguments.size());
271 deviceTypes.push_back(
272 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
273 }
274
275 for (DeviceType dt : newDeviceTypes) {
276 argCollection.append(arguments);
277 segments.push_back(arguments.size());
278 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
279 }
280
281 return mlir::ArrayAttr::get(context, deviceTypes);
282}
283
284/// Overload for when the 'segments' aren't needed.
285mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
286 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
287 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
288 mlir::MutableOperandRange argCollection) {
290 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
291 newDeviceTypes, arguments,
292 argCollection, segments);
293}
294} // namespace
295
296//===----------------------------------------------------------------------===//
297// OpenACC operations
298//===----------------------------------------------------------------------===//
299
300void OpenACCDialect::initialize() {
301 addOperations<
302#define GET_OP_LIST
303#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
304 >();
305 addAttributes<
306#define GET_ATTRDEF_LIST
307#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
308 >();
309 addTypes<
310#define GET_TYPEDEF_LIST
311#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
312 >();
313
314 // By attaching interfaces here, we make the OpenACC dialect dependent on
315 // the other dialects. This is probably better than having dialects like LLVM
316 // and memref be dependent on OpenACC.
317 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
318 *getContext());
319 UnrankedMemRefType::attachInterface<
320 MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
321 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
322 *getContext());
323
324 // Attach operation interfaces
325 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
326 *getContext());
327 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
328}
329
330//===----------------------------------------------------------------------===//
331// device_type support helpers
332//===----------------------------------------------------------------------===//
333
334static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
335 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
336}
337
338static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
339 mlir::acc::DeviceType deviceType) {
340 if (!hasDeviceTypeValues(arrayAttr))
341 return false;
342
343 for (auto attr : *arrayAttr) {
344 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
345 if (deviceTypeAttr.getValue() == deviceType)
346 return true;
347 }
348
349 return false;
350}
351
                             std::optional<mlir::ArrayAttr> deviceTypes) {
  // Prints nothing when `deviceTypes` is absent, null or empty (see
  // hasDeviceTypeValues). Otherwise prints the attributes as a bracketed,
  // comma-separated list.
  if (!hasDeviceTypeValues(deviceTypes))
    return;

  p << "[";
  llvm::interleaveComma(*deviceTypes, p,
                        [&](mlir::Attribute attr) { p << attr; });
  p << "]";
}
362
363static std::optional<unsigned> findSegment(ArrayAttr segments,
364 mlir::acc::DeviceType deviceType) {
365 unsigned segmentIdx = 0;
366 for (auto attr : segments) {
367 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
368 if (deviceTypeAttr.getValue() == deviceType)
369 return std::make_optional(segmentIdx);
370 ++segmentIdx;
371 }
372 return std::nullopt;
373}
374
376getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
378 std::optional<llvm::ArrayRef<int32_t>> segments,
379 mlir::acc::DeviceType deviceType) {
380 if (!arrayAttr)
381 return range.take_front(0);
382 if (auto pos = findSegment(*arrayAttr, deviceType)) {
383 int32_t nbOperandsBefore = 0;
384 for (unsigned i = 0; i < *pos; ++i)
385 nbOperandsBefore += (*segments)[i];
386 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
387 }
388 return range.take_front(0);
389}
390
391static mlir::Value
392getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
394 std::optional<llvm::ArrayRef<int32_t>> segments,
395 std::optional<mlir::ArrayAttr> hasWaitDevnum,
396 mlir::acc::DeviceType deviceType) {
397 if (!hasDeviceTypeValues(deviceTypeAttr))
398 return {};
399 if (auto pos = findSegment(*deviceTypeAttr, deviceType))
400 if (hasWaitDevnum->getValue()[*pos])
401 return getValuesFromSegments(deviceTypeAttr, operands, segments,
402 deviceType)
403 .front();
404 return {};
405}
406
408getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
410 std::optional<llvm::ArrayRef<int32_t>> segments,
411 std::optional<mlir::ArrayAttr> hasWaitDevnum,
412 mlir::acc::DeviceType deviceType) {
413 auto range =
414 getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
415 if (range.empty())
416 return range;
417 if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
418 if (hasWaitDevnum && *hasWaitDevnum) {
419 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
420 if (boolAttr.getValue())
421 return range.drop_front(1); // first value is devnum
422 }
423 }
424 return range;
425}
426
427template <typename Op>
428static LogicalResult checkWaitAndAsyncConflict(Op op) {
429 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
430 ++dtypeInt) {
431 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
432
433 // The asyncOnly attribute represent the async clause without value.
434 // Therefore the attribute and operand cannot appear at the same time.
435 if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
436 op.hasAsyncOnly(dtype))
437 return op.emitError(
438 "asyncOnly attribute cannot appear with asyncOperand");
439
440 // The wait attribute represent the wait clause without values. Therefore
441 // the attribute and operands cannot appear at the same time.
442 if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
443 op.hasWaitOnly(dtype))
444 return op.emitError("wait attribute cannot appear with waitOperands");
445 }
446 return success();
447}
448
449template <typename Op>
450static LogicalResult checkVarAndVarType(Op op) {
451 if (!op.getVar())
452 return op.emitError("must have var operand");
453
454 // A variable must have a type that is either pointer-like or mappable.
455 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
456 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
457 return op.emitError("var must be mappable or pointer-like");
458
459 // When it is a pointer-like type, the varType must capture the target type.
460 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
461 op.getVarType() == op.getVar().getType())
462 return op.emitError("varType must capture the element type of var");
463
464 return success();
465}
466
467template <typename Op>
468static LogicalResult checkVarAndAccVar(Op op) {
469 if (op.getVar().getType() != op.getAccVar().getType())
470 return op.emitError("input and output types must match");
471
472 return success();
473}
474
475template <typename Op>
476static LogicalResult checkNoModifier(Op op) {
477 if (op.getModifiers() != acc::DataClauseModifier::none)
478 return op.emitError("no data clause modifiers are allowed");
479 return success();
480}
481
482template <typename Op>
483static LogicalResult
484checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
485 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
486 return op.emitError(
487 "invalid data clause modifiers: " +
488 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
489
490 return success();
491}
492
static ParseResult parseVar(mlir::OpAsmParser &parser,
  // Either `var` or `varPtr` keyword is required.
  if (failed(parser.parseOptionalKeyword("varPtr"))) {
    if (failed(parser.parseKeyword("var")))
      return failure();
  }
  // Only `(` and the operand are consumed here; this function parses no `)`.
  // NOTE(review): parseVarPtrType parses a `)` after the type, presumably the
  // matching one - confirm against the assembly format.
  if (failed(parser.parseLParen()))
    return failure();
  if (failed(parser.parseOperand(var)))
    return failure();

  return success();
}
507
                     mlir::Value var) {
  // Prints `varPtr(` for pointer-like values and `var(` otherwise (the same
  // keywords parseVar accepts), followed by the operand. No closing `)` is
  // printed here.
  if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
    p << "varPtr(";
  else
    p << "var(";
  p.printOperand(var);
}
516
static ParseResult parseAccVar(mlir::OpAsmParser &parser,
                               mlir::Type &accVarType) {
  // Parses `accPtr(<operand> : <type>)` or `accVar(<operand> : <type>)`.
  // The operand goes into `var` and the type into `accVarType`. Unlike
  // parseVar, this function also consumes the closing `)`.
  // Either `accVar` or `accPtr` keyword is required.
  if (failed(parser.parseOptionalKeyword("accPtr"))) {
    if (failed(parser.parseKeyword("accVar")))
      return failure();
  }
  if (failed(parser.parseLParen()))
    return failure();
  if (failed(parser.parseOperand(var)))
    return failure();
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseType(accVarType)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();

  return success();
}
538
                        mlir::Value accVar, mlir::Type accVarType) {
  // Prints `accPtr(<operand> : <type>)` when accVar's own type is
  // pointer-like, and `accVar(<operand> : <type>)` otherwise. The keyword is
  // chosen from the value's type, while the printed type is `accVarType`.
  if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
    p << "accPtr(";
  else
    p << "accVar(";
  p.printOperand(accVar);
  p << " : ";
  p.printType(accVarType);
  p << ")";
}
550
551static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
552 mlir::Type &varPtrType,
553 mlir::TypeAttr &varTypeAttr) {
554 if (failed(parser.parseType(varPtrType)))
555 return failure();
556 if (failed(parser.parseRParen()))
557 return failure();
558
559 if (succeeded(parser.parseOptionalKeyword("varType"))) {
560 if (failed(parser.parseLParen()))
561 return failure();
562 mlir::Type varType;
563 if (failed(parser.parseType(varType)))
564 return failure();
565 varTypeAttr = mlir::TypeAttr::get(varType);
566 if (failed(parser.parseRParen()))
567 return failure();
568 } else {
569 // Set `varType` from the element type of the type of `varPtr`.
570 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
571 varTypeAttr = mlir::TypeAttr::get(
572 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
573 else
574 varTypeAttr = mlir::TypeAttr::get(varPtrType);
575 }
576
577 return success();
578}
579
                            mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
  // Prints `varPtrType` and a `)`. A ` varType(<type>)` suffix is printed
  // only when the stored varType differs from the default that
  // parseVarPtrType derives: the element type of pointer-like types, or the
  // type itself otherwise.
  p.printType(varPtrType);
  p << ")";

  // Print the `varType` only if it differs from the element type of
  // `varPtr`'s type.
  mlir::Type varType = varTypeAttr.getValue();
  mlir::Type typeToCheckAgainst =
      mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
          ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
          : varPtrType;
  if (typeToCheckAgainst != varType) {
    p << " varType(";
    p.printType(varType);
    p << ")";
  }
}
598
599//===----------------------------------------------------------------------===//
600// DataBoundsOp
601//===----------------------------------------------------------------------===//
602LogicalResult acc::DataBoundsOp::verify() {
603 auto extent = getExtent();
604 auto upperbound = getUpperbound();
605 if (!extent && !upperbound)
606 return emitError("expected extent or upperbound.");
607 return success();
608}
609
610//===----------------------------------------------------------------------===//
611// PrivateOp
612//===----------------------------------------------------------------------===//
613LogicalResult acc::PrivateOp::verify() {
614 if (getDataClause() != acc::DataClause::acc_private)
615 return emitError(
616 "data clause associated with private operation must match its intent");
617 if (failed(checkVarAndVarType(*this)))
618 return failure();
619 if (failed(checkNoModifier(*this)))
620 return failure();
621 return success();
622}
623
624//===----------------------------------------------------------------------===//
625// FirstprivateOp
626//===----------------------------------------------------------------------===//
627LogicalResult acc::FirstprivateOp::verify() {
628 if (getDataClause() != acc::DataClause::acc_firstprivate)
629 return emitError("data clause associated with firstprivate operation must "
630 "match its intent");
631 if (failed(checkVarAndVarType(*this)))
632 return failure();
633 if (failed(checkNoModifier(*this)))
634 return failure();
635 return success();
636}
637
638//===----------------------------------------------------------------------===//
639// FirstprivateMapInitialOp
640//===----------------------------------------------------------------------===//
641LogicalResult acc::FirstprivateMapInitialOp::verify() {
642 if (getDataClause() != acc::DataClause::acc_firstprivate)
643 return emitError("data clause associated with firstprivate operation must "
644 "match its intent");
645 if (failed(checkVarAndVarType(*this)))
646 return failure();
647 if (failed(checkNoModifier(*this)))
648 return failure();
649 return success();
650}
651
652//===----------------------------------------------------------------------===//
653// ReductionOp
654//===----------------------------------------------------------------------===//
655LogicalResult acc::ReductionOp::verify() {
656 if (getDataClause() != acc::DataClause::acc_reduction)
657 return emitError("data clause associated with reduction operation must "
658 "match its intent");
659 if (failed(checkVarAndVarType(*this)))
660 return failure();
661 if (failed(checkNoModifier(*this)))
662 return failure();
663 return success();
664}
665
666//===----------------------------------------------------------------------===//
667// DevicePtrOp
668//===----------------------------------------------------------------------===//
669LogicalResult acc::DevicePtrOp::verify() {
670 if (getDataClause() != acc::DataClause::acc_deviceptr)
671 return emitError("data clause associated with deviceptr operation must "
672 "match its intent");
673 if (failed(checkVarAndVarType(*this)))
674 return failure();
675 if (failed(checkVarAndAccVar(*this)))
676 return failure();
677 if (failed(checkNoModifier(*this)))
678 return failure();
679 return success();
680}
681
682//===----------------------------------------------------------------------===//
683// PresentOp
684//===----------------------------------------------------------------------===//
685LogicalResult acc::PresentOp::verify() {
686 if (getDataClause() != acc::DataClause::acc_present)
687 return emitError(
688 "data clause associated with present operation must match its intent");
689 if (failed(checkVarAndVarType(*this)))
690 return failure();
691 if (failed(checkVarAndAccVar(*this)))
692 return failure();
693 if (failed(checkNoModifier(*this)))
694 return failure();
695 return success();
696}
697
698//===----------------------------------------------------------------------===//
699// CopyinOp
700//===----------------------------------------------------------------------===//
701LogicalResult acc::CopyinOp::verify() {
702 // Test for all clauses this operation can be decomposed from:
703 if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
704 getDataClause() != acc::DataClause::acc_copyin_readonly &&
705 getDataClause() != acc::DataClause::acc_copy &&
706 getDataClause() != acc::DataClause::acc_reduction)
707 return emitError(
708 "data clause associated with copyin operation must match its intent"
709 " or specify original clause this operation was decomposed from");
710 if (failed(checkVarAndVarType(*this)))
711 return failure();
712 if (failed(checkVarAndAccVar(*this)))
713 return failure();
714 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly |
715 acc::DataClauseModifier::always |
716 acc::DataClauseModifier::capture)))
717 return failure();
718 return success();
719}
720
721bool acc::CopyinOp::isCopyinReadonly() {
722 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
723 acc::bitEnumContainsAny(getModifiers(),
724 acc::DataClauseModifier::readonly);
725}
726
727//===----------------------------------------------------------------------===//
728// CreateOp
729//===----------------------------------------------------------------------===//
730LogicalResult acc::CreateOp::verify() {
731 // Test for all clauses this operation can be decomposed from:
732 if (getDataClause() != acc::DataClause::acc_create &&
733 getDataClause() != acc::DataClause::acc_create_zero &&
734 getDataClause() != acc::DataClause::acc_copyout &&
735 getDataClause() != acc::DataClause::acc_copyout_zero)
736 return emitError(
737 "data clause associated with create operation must match its intent"
738 " or specify original clause this operation was decomposed from");
739 if (failed(checkVarAndVarType(*this)))
740 return failure();
741 if (failed(checkVarAndAccVar(*this)))
742 return failure();
743 // this op is the entry part of copyout, so it also needs to allow all
744 // modifiers allowed on copyout.
745 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
746 acc::DataClauseModifier::always |
747 acc::DataClauseModifier::capture)))
748 return failure();
749 return success();
750}
751
752bool acc::CreateOp::isCreateZero() {
753 // The zero modifier is encoded in the data clause.
754 return getDataClause() == acc::DataClause::acc_create_zero ||
755 getDataClause() == acc::DataClause::acc_copyout_zero ||
756 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
757}
758
759//===----------------------------------------------------------------------===//
760// NoCreateOp
761//===----------------------------------------------------------------------===//
762LogicalResult acc::NoCreateOp::verify() {
763 if (getDataClause() != acc::DataClause::acc_no_create)
764 return emitError("data clause associated with no_create operation must "
765 "match its intent");
766 if (failed(checkVarAndVarType(*this)))
767 return failure();
768 if (failed(checkVarAndAccVar(*this)))
769 return failure();
770 if (failed(checkNoModifier(*this)))
771 return failure();
772 return success();
773}
774
775//===----------------------------------------------------------------------===//
776// AttachOp
777//===----------------------------------------------------------------------===//
778LogicalResult acc::AttachOp::verify() {
779 if (getDataClause() != acc::DataClause::acc_attach)
780 return emitError(
781 "data clause associated with attach operation must match its intent");
782 if (failed(checkVarAndVarType(*this)))
783 return failure();
784 if (failed(checkVarAndAccVar(*this)))
785 return failure();
786 if (failed(checkNoModifier(*this)))
787 return failure();
788 return success();
789}
790
791//===----------------------------------------------------------------------===//
792// DeclareDeviceResidentOp
793//===----------------------------------------------------------------------===//
794
795LogicalResult acc::DeclareDeviceResidentOp::verify() {
796 if (getDataClause() != acc::DataClause::acc_declare_device_resident)
797 return emitError("data clause associated with device_resident operation "
798 "must match its intent");
799 if (failed(checkVarAndVarType(*this)))
800 return failure();
801 if (failed(checkVarAndAccVar(*this)))
802 return failure();
803 if (failed(checkNoModifier(*this)))
804 return failure();
805 return success();
806}
807
808//===----------------------------------------------------------------------===//
809// DeclareLinkOp
810//===----------------------------------------------------------------------===//
811
812LogicalResult acc::DeclareLinkOp::verify() {
813 if (getDataClause() != acc::DataClause::acc_declare_link)
814 return emitError(
815 "data clause associated with link operation must match its intent");
816 if (failed(checkVarAndVarType(*this)))
817 return failure();
818 if (failed(checkVarAndAccVar(*this)))
819 return failure();
820 if (failed(checkNoModifier(*this)))
821 return failure();
822 return success();
823}
824
825//===----------------------------------------------------------------------===//
826// CopyoutOp
827//===----------------------------------------------------------------------===//
828LogicalResult acc::CopyoutOp::verify() {
829 // Test for all clauses this operation can be decomposed from:
830 if (getDataClause() != acc::DataClause::acc_copyout &&
831 getDataClause() != acc::DataClause::acc_copyout_zero &&
832 getDataClause() != acc::DataClause::acc_copy &&
833 getDataClause() != acc::DataClause::acc_reduction)
834 return emitError(
835 "data clause associated with copyout operation must match its intent"
836 " or specify original clause this operation was decomposed from");
837 if (!getVar() || !getAccVar())
838 return emitError("must have both host and device pointers");
839 if (failed(checkVarAndVarType(*this)))
840 return failure();
841 if (failed(checkVarAndAccVar(*this)))
842 return failure();
843 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
844 acc::DataClauseModifier::always |
845 acc::DataClauseModifier::capture)))
846 return failure();
847 return success();
848}
849
850bool acc::CopyoutOp::isCopyoutZero() {
851 return getDataClause() == acc::DataClause::acc_copyout_zero ||
852 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
853}
854
855//===----------------------------------------------------------------------===//
856// DeleteOp
857//===----------------------------------------------------------------------===//
858LogicalResult acc::DeleteOp::verify() {
859 // Test for all clauses this operation can be decomposed from:
860 if (getDataClause() != acc::DataClause::acc_delete &&
861 getDataClause() != acc::DataClause::acc_create &&
862 getDataClause() != acc::DataClause::acc_create_zero &&
863 getDataClause() != acc::DataClause::acc_copyin &&
864 getDataClause() != acc::DataClause::acc_copyin_readonly &&
865 getDataClause() != acc::DataClause::acc_present &&
866 getDataClause() != acc::DataClause::acc_no_create &&
867 getDataClause() != acc::DataClause::acc_declare_device_resident &&
868 getDataClause() != acc::DataClause::acc_declare_link)
869 return emitError(
870 "data clause associated with delete operation must match its intent"
871 " or specify original clause this operation was decomposed from");
872 if (!getAccVar())
873 return emitError("must have device pointer");
874 // This op is the exit part of copyin and create - thus allow all modifiers
875 // allowed on either case.
876 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
877 acc::DataClauseModifier::readonly |
878 acc::DataClauseModifier::always |
879 acc::DataClauseModifier::capture)))
880 return failure();
881 return success();
882}
883
884//===----------------------------------------------------------------------===//
885// DetachOp
886//===----------------------------------------------------------------------===//
887LogicalResult acc::DetachOp::verify() {
888 // Test for all clauses this operation can be decomposed from:
889 if (getDataClause() != acc::DataClause::acc_detach &&
890 getDataClause() != acc::DataClause::acc_attach)
891 return emitError(
892 "data clause associated with detach operation must match its intent"
893 " or specify original clause this operation was decomposed from");
894 if (!getAccVar())
895 return emitError("must have device pointer");
896 if (failed(checkNoModifier(*this)))
897 return failure();
898 return success();
899}
900
901//===----------------------------------------------------------------------===//
902// HostOp
903//===----------------------------------------------------------------------===//
904LogicalResult acc::UpdateHostOp::verify() {
905 // Test for all clauses this operation can be decomposed from:
906 if (getDataClause() != acc::DataClause::acc_update_host &&
907 getDataClause() != acc::DataClause::acc_update_self)
908 return emitError(
909 "data clause associated with host operation must match its intent"
910 " or specify original clause this operation was decomposed from");
911 if (!getVar() || !getAccVar())
912 return emitError("must have both host and device pointers");
913 if (failed(checkVarAndVarType(*this)))
914 return failure();
915 if (failed(checkVarAndAccVar(*this)))
916 return failure();
917 if (failed(checkNoModifier(*this)))
918 return failure();
919 return success();
920}
921
922//===----------------------------------------------------------------------===//
923// DeviceOp
924//===----------------------------------------------------------------------===//
925LogicalResult acc::UpdateDeviceOp::verify() {
926 // Test for all clauses this operation can be decomposed from:
927 if (getDataClause() != acc::DataClause::acc_update_device)
928 return emitError(
929 "data clause associated with device operation must match its intent"
930 " or specify original clause this operation was decomposed from");
931 if (failed(checkVarAndVarType(*this)))
932 return failure();
933 if (failed(checkVarAndAccVar(*this)))
934 return failure();
935 if (failed(checkNoModifier(*this)))
936 return failure();
937 return success();
938}
939
940//===----------------------------------------------------------------------===//
941// UseDeviceOp
942//===----------------------------------------------------------------------===//
943LogicalResult acc::UseDeviceOp::verify() {
944 // Test for all clauses this operation can be decomposed from:
945 if (getDataClause() != acc::DataClause::acc_use_device)
946 return emitError(
947 "data clause associated with use_device operation must match its intent"
948 " or specify original clause this operation was decomposed from");
949 if (failed(checkVarAndVarType(*this)))
950 return failure();
951 if (failed(checkVarAndAccVar(*this)))
952 return failure();
953 if (failed(checkNoModifier(*this)))
954 return failure();
955 return success();
956}
957
958//===----------------------------------------------------------------------===//
959// CacheOp
960//===----------------------------------------------------------------------===//
961LogicalResult acc::CacheOp::verify() {
962 // Test for all clauses this operation can be decomposed from:
963 if (getDataClause() != acc::DataClause::acc_cache &&
964 getDataClause() != acc::DataClause::acc_cache_readonly)
965 return emitError(
966 "data clause associated with cache operation must match its intent"
967 " or specify original clause this operation was decomposed from");
968 if (failed(checkVarAndVarType(*this)))
969 return failure();
970 if (failed(checkVarAndAccVar(*this)))
971 return failure();
972 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly)))
973 return failure();
974 return success();
975}
976
977bool acc::CacheOp::isCacheReadonly() {
978 return getDataClause() == acc::DataClause::acc_cache_readonly ||
979 acc::bitEnumContainsAny(getModifiers(),
980 acc::DataClauseModifier::readonly);
981}
982
983template <typename StructureOp>
984static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
985 unsigned nRegions = 1) {
986
988 for (unsigned i = 0; i < nRegions; ++i)
989 regions.push_back(state.addRegion());
990
991 for (Region *region : regions)
992 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
993 return failure();
994
995 return success();
996}
997
999 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
1000}
1001
1002namespace {
1003/// Pattern to remove operation without region that have constant false `ifCond`
1004/// and remove the condition from the operation if the `ifCond` is a true
1005/// constant.
1006template <typename OpTy>
1007struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
1008 using OpRewritePattern<OpTy>::OpRewritePattern;
1009
1010 LogicalResult matchAndRewrite(OpTy op,
1011 PatternRewriter &rewriter) const override {
1012 // Early return if there is no condition.
1013 Value ifCond = op.getIfCond();
1014 if (!ifCond)
1015 return failure();
1016
1017 IntegerAttr constAttr;
1018 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1019 return failure();
1020 if (constAttr.getInt())
1021 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1022 else
1023 rewriter.eraseOp(op);
1024
1025 return success();
1026 }
1027};
1028
1029/// Replaces the given op with the contents of the given single-block region,
1030/// using the operands of the block terminator to replace operation results.
1031static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
1032 Region &region, ValueRange blockArgs = {}) {
1033 assert(region.hasOneBlock() && "expected single-block region");
1034 Block *block = &region.front();
1035 Operation *terminator = block->getTerminator();
1036 ValueRange results = terminator->getOperands();
1037 rewriter.inlineBlockBefore(block, op, blockArgs);
1038 rewriter.replaceOp(op, results);
1039 rewriter.eraseOp(terminator);
1040}
1041
1042/// Pattern to remove operation with region that have constant false `ifCond`
1043/// and remove the condition from the operation if the `ifCond` is constant
1044/// true.
1045template <typename OpTy>
1046struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
1047 using OpRewritePattern<OpTy>::OpRewritePattern;
1048
1049 LogicalResult matchAndRewrite(OpTy op,
1050 PatternRewriter &rewriter) const override {
1051 // Early return if there is no condition.
1052 Value ifCond = op.getIfCond();
1053 if (!ifCond)
1054 return failure();
1055
1056 IntegerAttr constAttr;
1057 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1058 return failure();
1059 if (constAttr.getInt())
1060 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1061 else
1062 replaceOpWithRegion(rewriter, op, op.getRegion());
1063
1064 return success();
1065 }
1066};
1067
1068/// Remove empty acc.kernel_environment operations. If the operation has wait
1069/// operands, create a acc.wait operation to preserve synchronization.
1070struct RemoveEmptyKernelEnvironment
1071 : public OpRewritePattern<acc::KernelEnvironmentOp> {
1072 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
1073
1074 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
1075 PatternRewriter &rewriter) const override {
1076 assert(op->getNumRegions() == 1 && "expected op to have one region");
1077
1078 Block &block = op.getRegion().front();
1079 if (!block.empty())
1080 return failure();
1081
1082 // Conservatively disable canonicalization of empty acc.kernel_environment
1083 // operations if the wait operands in the kernel_environment cannot be fully
1084 // represented by acc.wait operation.
1085
1086 // Disable canonicalization if device type is not the default
1087 if (auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
1088 for (auto attr : deviceTypeAttr) {
1089 if (auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
1090 if (dtAttr.getValue() != mlir::acc::DeviceType::None)
1091 return failure();
1092 }
1093 }
1094 }
1095
1096 // Disable canonicalization if any wait segment has a devnum
1097 if (auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
1098 for (auto attr : hasDevnumAttr) {
1099 if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
1100 if (boolAttr.getValue())
1101 return failure();
1102 }
1103 }
1104 }
1105
1106 // Disable canonicalization if there are multiple wait segments
1107 if (auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
1108 if (segmentsAttr.size() > 1)
1109 return failure();
1110 }
1111
1112 // Remove empty kernel environment.
1113 // Preserve synchronization by creating acc.wait operation if needed.
1114 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
1115 rewriter.replaceOpWithNewOp<acc::WaitOp>(op, op.getWaitOperands(),
1116 /*asyncOperand=*/Value(),
1117 /*waitDevnum=*/Value(),
1118 /*async=*/nullptr,
1119 /*ifCond=*/Value());
1120 else
1121 rewriter.eraseOp(op);
1122
1123 return success();
1124 }
1125};
1126
1127//===----------------------------------------------------------------------===//
1128// Recipe Region Helpers
1129//===----------------------------------------------------------------------===//
1130
1131/// Create and populate an init region for privatization recipes.
1132/// Returns success if the region is populated, failure otherwise.
1133/// Sets needsFree to indicate if the allocated memory requires deallocation.
1134static LogicalResult createInitRegion(OpBuilder &builder, Location loc,
1135 Region &initRegion, Type varType,
1136 StringRef varName, ValueRange bounds,
1137 bool &needsFree) {
1138 // Create init block with arguments: original value + bounds
1139 SmallVector<Type> argTypes{varType};
1140 SmallVector<Location> argLocs{loc};
1141 for (Value bound : bounds) {
1142 argTypes.push_back(bound.getType());
1143 argLocs.push_back(loc);
1144 }
1145
1146 Block *initBlock = builder.createBlock(&initRegion);
1147 initBlock->addArguments(argTypes, argLocs);
1148 builder.setInsertionPointToStart(initBlock);
1149
1150 Value privatizedValue;
1151
1152 // Get the block argument that represents the original variable
1153 Value blockArgVar = initBlock->getArgument(0);
1154
1155 // Generate init region body based on variable type
1156 if (isa<MappableType>(varType)) {
1157 auto mappableTy = cast<MappableType>(varType);
1158 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1159 privatizedValue = mappableTy.generatePrivateInit(
1160 builder, loc, typedVar, varName, bounds, {}, needsFree);
1161 if (!privatizedValue)
1162 return failure();
1163 } else {
1164 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1165 auto pointerLikeTy = cast<PointerLikeType>(varType);
1166 // Use PointerLikeType's allocation API with the block argument
1167 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1168 blockArgVar, needsFree);
1169 if (!privatizedValue)
1170 return failure();
1171 }
1172
1173 // Add yield operation to init block
1174 acc::YieldOp::create(builder, loc, privatizedValue);
1175
1176 return success();
1177}
1178
1179/// Create and populate a copy region for firstprivate recipes.
1180/// Returns success if the region is populated, failure otherwise.
1181/// TODO: Handle MappableType - it does not yet have a copy API.
1182static LogicalResult createCopyRegion(OpBuilder &builder, Location loc,
1183 Region &copyRegion, Type varType,
1184 ValueRange bounds) {
1185 // Create copy block with arguments: original value + privatized value +
1186 // bounds
1187 SmallVector<Type> copyArgTypes{varType, varType};
1188 SmallVector<Location> copyArgLocs{loc, loc};
1189 for (Value bound : bounds) {
1190 copyArgTypes.push_back(bound.getType());
1191 copyArgLocs.push_back(loc);
1192 }
1193
1194 Block *copyBlock = builder.createBlock(&copyRegion);
1195 copyBlock->addArguments(copyArgTypes, copyArgLocs);
1196 builder.setInsertionPointToStart(copyBlock);
1197
1198 bool isMappable = isa<MappableType>(varType);
1199 bool isPointerLike = isa<PointerLikeType>(varType);
1200 // TODO: Handle MappableType - it does not yet have a copy API.
1201 // Otherwise, for now just fallback to pointer-like behavior.
1202 if (isMappable && !isPointerLike)
1203 return failure();
1204
1205 // Generate copy region body based on variable type
1206 if (isPointerLike) {
1207 auto pointerLikeTy = cast<PointerLikeType>(varType);
1208 Value originalArg = copyBlock->getArgument(0);
1209 Value privatizedArg = copyBlock->getArgument(1);
1210
1211 // Generate copy operation using PointerLikeType interface
1212 if (!pointerLikeTy.genCopy(
1213 builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
1214 cast<TypedValue<PointerLikeType>>(originalArg), varType))
1215 return failure();
1216 }
1217
1218 // Add terminator to copy block
1219 acc::TerminatorOp::create(builder, loc);
1220
1221 return success();
1222}
1223
1224/// Create and populate a destroy region for privatization recipes.
1225/// Returns success if the region is populated, failure otherwise.
1226static LogicalResult createDestroyRegion(OpBuilder &builder, Location loc,
1227 Region &destroyRegion, Type varType,
1228 Value allocRes, ValueRange bounds) {
1229 // Create destroy block with arguments: original value + privatized value +
1230 // bounds
1231 SmallVector<Type> destroyArgTypes{varType, varType};
1232 SmallVector<Location> destroyArgLocs{loc, loc};
1233 for (Value bound : bounds) {
1234 destroyArgTypes.push_back(bound.getType());
1235 destroyArgLocs.push_back(loc);
1236 }
1237
1238 Block *destroyBlock = builder.createBlock(&destroyRegion);
1239 destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
1240 builder.setInsertionPointToStart(destroyBlock);
1241
1242 auto varToFree =
1243 cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
1244 if (isa<MappableType>(varType)) {
1245 auto mappableTy = cast<MappableType>(varType);
1246 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree))
1247 return failure();
1248 } else {
1249 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1250 auto pointerLikeTy = cast<PointerLikeType>(varType);
1251 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1252 return failure();
1253 }
1254
1255 acc::TerminatorOp::create(builder, loc);
1256 return success();
1257}
1258
1259} // namespace
1260
1261//===----------------------------------------------------------------------===//
1262// PrivateRecipeOp
1263//===----------------------------------------------------------------------===//
1264
1266 Operation *op, Region &region, StringRef regionType, StringRef regionName,
1267 Type type, bool verifyYield, bool optional = false) {
1268 if (optional && region.empty())
1269 return success();
1270
1271 if (region.empty())
1272 return op->emitOpError() << "expects non-empty " << regionName << " region";
1273 Block &firstBlock = region.front();
1274 if (firstBlock.getNumArguments() < 1 ||
1275 firstBlock.getArgument(0).getType() != type)
1276 return op->emitOpError() << "expects " << regionName
1277 << " region first "
1278 "argument of the "
1279 << regionType << " type";
1280
1281 if (verifyYield) {
1282 for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
1283 if (yieldOp.getOperands().size() != 1 ||
1284 yieldOp.getOperands().getTypes()[0] != type)
1285 return op->emitOpError() << "expects " << regionName
1286 << " region to "
1287 "yield a value of the "
1288 << regionType << " type";
1289 }
1290 }
1291 return success();
1292}
1293
1294LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1295 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1296 "privatization", "init", getType(),
1297 /*verifyYield=*/false)))
1298 return failure();
1300 *this, getDestroyRegion(), "privatization", "destroy", getType(),
1301 /*verifyYield=*/false, /*optional=*/true)))
1302 return failure();
1303 return success();
1304}
1305
1306std::optional<PrivateRecipeOp>
1307PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1308 StringRef recipeName, Type varType,
1309 StringRef varName, ValueRange bounds) {
1310 // First, validate that we can handle this variable type
1311 bool isMappable = isa<MappableType>(varType);
1312 bool isPointerLike = isa<PointerLikeType>(varType);
1313
1314 // Unsupported type
1315 if (!isMappable && !isPointerLike)
1316 return std::nullopt;
1317
1318 OpBuilder::InsertionGuard guard(builder);
1319
1320 // Create the recipe operation first so regions have proper parent context
1321 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1322
1323 // Populate the init region
1324 bool needsFree = false;
1325 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1326 varName, bounds, needsFree))) {
1327 recipe.erase();
1328 return std::nullopt;
1329 }
1330
1331 // Only create destroy region if the allocation needs deallocation
1332 if (needsFree) {
1333 // Extract the allocated value from the init block's yield operation
1334 auto yieldOp =
1335 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1336 Value allocRes = yieldOp.getOperand(0);
1337
1338 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1339 varType, allocRes, bounds))) {
1340 recipe.erase();
1341 return std::nullopt;
1342 }
1343 }
1344
1345 return recipe;
1346}
1347
1348//===----------------------------------------------------------------------===//
1349// FirstprivateRecipeOp
1350//===----------------------------------------------------------------------===//
1351
1352LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1353 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1354 "privatization", "init", getType(),
1355 /*verifyYield=*/false)))
1356 return failure();
1357
1358 if (getCopyRegion().empty())
1359 return emitOpError() << "expects non-empty copy region";
1360
1361 Block &firstBlock = getCopyRegion().front();
1362 if (firstBlock.getNumArguments() < 2 ||
1363 firstBlock.getArgument(0).getType() != getType())
1364 return emitOpError() << "expects copy region with two arguments of the "
1365 "privatization type";
1366
1367 if (getDestroyRegion().empty())
1368 return success();
1369
1370 if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
1371 "privatization", "destroy",
1372 getType(), /*verifyYield=*/false)))
1373 return failure();
1374
1375 return success();
1376}
1377
1378std::optional<FirstprivateRecipeOp>
1379FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1380 StringRef recipeName, Type varType,
1381 StringRef varName, ValueRange bounds) {
1382 // First, validate that we can handle this variable type
1383 bool isMappable = isa<MappableType>(varType);
1384 bool isPointerLike = isa<PointerLikeType>(varType);
1385
1386 // Unsupported type
1387 if (!isMappable && !isPointerLike)
1388 return std::nullopt;
1389
1390 OpBuilder::InsertionGuard guard(builder);
1391
1392 // Create the recipe operation first so regions have proper parent context
1393 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1394
1395 // Populate the init region
1396 bool needsFree = false;
1397 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1398 varName, bounds, needsFree))) {
1399 recipe.erase();
1400 return std::nullopt;
1401 }
1402
1403 // Populate the copy region
1404 if (failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1405 bounds))) {
1406 recipe.erase();
1407 return std::nullopt;
1408 }
1409
1410 // Only create destroy region if the allocation needs deallocation
1411 if (needsFree) {
1412 // Extract the allocated value from the init block's yield operation
1413 auto yieldOp =
1414 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1415 Value allocRes = yieldOp.getOperand(0);
1416
1417 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1418 varType, allocRes, bounds))) {
1419 recipe.erase();
1420 return std::nullopt;
1421 }
1422 }
1423
1424 return recipe;
1425}
1426
1427//===----------------------------------------------------------------------===//
1428// ReductionRecipeOp
1429//===----------------------------------------------------------------------===//
1430
1431LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1432 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
1433 "init", getType(),
1434 /*verifyYield=*/false)))
1435 return failure();
1436
1437 if (getCombinerRegion().empty())
1438 return emitOpError() << "expects non-empty combiner region";
1439
1440 Block &reductionBlock = getCombinerRegion().front();
1441 if (reductionBlock.getNumArguments() < 2 ||
1442 reductionBlock.getArgument(0).getType() != getType() ||
1443 reductionBlock.getArgument(1).getType() != getType())
1444 return emitOpError() << "expects combiner region with the first two "
1445 << "arguments of the reduction type";
1446
1447 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1448 if (yieldOp.getOperands().size() != 1 ||
1449 yieldOp.getOperands().getTypes()[0] != getType())
1450 return emitOpError() << "expects combiner region to yield a value "
1451 "of the reduction type";
1452 }
1453
1454 return success();
1455}
1456
1457//===----------------------------------------------------------------------===//
1458// Custom parser and printer verifier for private clause
1459//===----------------------------------------------------------------------===//
1460
1461static ParseResult parseSymOperandList(
1462 mlir::OpAsmParser &parser,
1464 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
1466 if (failed(parser.parseCommaSeparatedList([&]() {
1467 if (parser.parseAttribute(attributes.emplace_back()) ||
1468 parser.parseArrow() ||
1469 parser.parseOperand(operands.emplace_back()) ||
1470 parser.parseColonType(types.emplace_back()))
1471 return failure();
1472 return success();
1473 })))
1474 return failure();
1475 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1476 attributes.end());
1477 symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
1478 return success();
1479}
1480
1482 mlir::OperandRange operands,
1483 mlir::TypeRange types,
1484 std::optional<mlir::ArrayAttr> attributes) {
1485 llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
1486 p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
1487 << std::get<1>(it).getType();
1488 });
1489}
1490
1491//===----------------------------------------------------------------------===//
1492// ParallelOp
1493//===----------------------------------------------------------------------===//
1494
1495/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
1496template <typename Op>
1497static LogicalResult checkDataOperands(Op op,
1498 const mlir::ValueRange &operands) {
1499 for (mlir::Value operand : operands)
1500 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1501 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1502 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1503 operand.getDefiningOp()))
1504 return op.emitError(
1505 "expect data entry/exit operation or acc.getdeviceptr "
1506 "as defining op");
1507 return success();
1508}
1509
1510template <typename Op>
1511static LogicalResult
1512checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
1513 mlir::OperandRange operands, llvm::StringRef operandName,
1514 llvm::StringRef symbolName, bool checkOperandType = true) {
1515 if (!operands.empty()) {
1516 if (!attributes || attributes->size() != operands.size())
1517 return op->emitOpError()
1518 << "expected as many " << symbolName << " symbol reference as "
1519 << operandName << " operands";
1520 } else {
1521 if (attributes)
1522 return op->emitOpError()
1523 << "unexpected " << symbolName << " symbol reference";
1524 return success();
1525 }
1526
1528 for (auto args : llvm::zip(operands, *attributes)) {
1529 mlir::Value operand = std::get<0>(args);
1530
1531 if (!set.insert(operand).second)
1532 return op->emitOpError()
1533 << operandName << " operand appears more than once";
1534
1535 mlir::Type varType = operand.getType();
1536 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1537 auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
1538 if (!decl)
1539 return op->emitOpError()
1540 << "expected symbol reference " << symbolRef << " to point to a "
1541 << operandName << " declaration";
1542
1543 if (checkOperandType && decl.getType() && decl.getType() != varType)
1544 return op->emitOpError() << "expected " << operandName << " (" << varType
1545 << ") to be the same type as " << operandName
1546 << " declaration (" << decl.getType() << ")";
1547 }
1548
1549 return success();
1550}
1551
1552unsigned ParallelOp::getNumDataOperands() {
1553 return getReductionOperands().size() + getPrivateOperands().size() +
1554 getFirstprivateOperands().size() + getDataClauseOperands().size();
1555}
1556
1557Value ParallelOp::getDataOperand(unsigned i) {
1558 unsigned numOptional = getAsyncOperands().size();
1559 numOptional += getNumGangs().size();
1560 numOptional += getNumWorkers().size();
1561 numOptional += getVectorLength().size();
1562 numOptional += getIfCond() ? 1 : 0;
1563 numOptional += getSelfCond() ? 1 : 0;
1564 return getOperand(getWaitOperands().size() + numOptional + i);
1565}
1566
1567template <typename Op>
1568static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
1569 ArrayAttr deviceTypes,
1570 llvm::StringRef keyword) {
1571 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1572 return op.emitOpError() << keyword << " operands count must match "
1573 << keyword << " device_type count";
1574 return success();
1575}
1576
1577template <typename Op>
1579 Op op, OperandRange operands, DenseI32ArrayAttr segments,
1580 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1581 std::size_t numOperandsInSegments = 0;
1582 std::size_t nbOfSegments = 0;
1583
1584 if (segments) {
1585 for (auto segCount : segments.asArrayRef()) {
1586 if (maxInSegment != 0 && segCount > maxInSegment)
1587 return op.emitOpError() << keyword << " expects a maximum of "
1588 << maxInSegment << " values per segment";
1589 numOperandsInSegments += segCount;
1590 ++nbOfSegments;
1591 }
1592 }
1593
1594 if ((numOperandsInSegments != operands.size()) ||
1595 (!deviceTypes && !operands.empty()))
1596 return op.emitOpError()
1597 << keyword << " operand count does not match count in segments";
1598 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1599 return op.emitOpError()
1600 << keyword << " segment count does not match device_type count";
1601 return success();
1602}
1603
1604LogicalResult acc::ParallelOp::verify() {
1606 *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
1607 "privatizations", /*checkOperandType=*/false)))
1608 return failure();
1610 *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
1611 "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
1612 return failure();
1614 *this, getReductionRecipes(), getReductionOperands(), "reduction",
1615 "reductions", false)))
1616 return failure();
1617
1619 *this, getNumGangs(), getNumGangsSegmentsAttr(),
1620 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1621 return failure();
1622
1624 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1625 getWaitOperandsDeviceTypeAttr(), "wait")))
1626 return failure();
1627
1628 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1629 getNumWorkersDeviceTypeAttr(),
1630 "num_workers")))
1631 return failure();
1632
1633 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1634 getVectorLengthDeviceTypeAttr(),
1635 "vector_length")))
1636 return failure();
1637
1639 getAsyncOperandsDeviceTypeAttr(),
1640 "async")))
1641 return failure();
1642
1644 return failure();
1645
1646 return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
1647}
1648
1649static mlir::Value
1650getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
1652 mlir::acc::DeviceType deviceType) {
1653 if (!arrayAttr)
1654 return {};
1655 if (auto pos = findSegment(*arrayAttr, deviceType))
1656 return range[*pos];
1657 return {};
1658}
1659
1660bool acc::ParallelOp::hasAsyncOnly() {
1661 return hasAsyncOnly(mlir::acc::DeviceType::None);
1662}
1663
1664bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1665 return hasDeviceType(getAsyncOnly(), deviceType);
1666}
1667
1668mlir::Value acc::ParallelOp::getAsyncValue() {
1669 return getAsyncValue(mlir::acc::DeviceType::None);
1670}
1671
1672mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1674 getAsyncOperands(), deviceType);
1675}
1676
1677mlir::Value acc::ParallelOp::getNumWorkersValue() {
1678 return getNumWorkersValue(mlir::acc::DeviceType::None);
1679}
1680
1682acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1683 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1684 deviceType);
1685}
1686
1687mlir::Value acc::ParallelOp::getVectorLengthValue() {
1688 return getVectorLengthValue(mlir::acc::DeviceType::None);
1689}
1690
1692acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1693 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1694 getVectorLength(), deviceType);
1695}
1696
1697mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
1698 return getNumGangsValues(mlir::acc::DeviceType::None);
1699}
1700
1702ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1703 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1704 getNumGangsSegments(), deviceType);
1705}
1706
1707bool acc::ParallelOp::hasWaitOnly() {
1708 return hasWaitOnly(mlir::acc::DeviceType::None);
1709}
1710
1711bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1712 return hasDeviceType(getWaitOnly(), deviceType);
1713}
1714
1715mlir::Operation::operand_range ParallelOp::getWaitValues() {
1716 return getWaitValues(mlir::acc::DeviceType::None);
1717}
1718
1720ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1722 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1723 getHasWaitDevnum(), deviceType);
1724}
1725
1726mlir::Value ParallelOp::getWaitDevnum() {
1727 return getWaitDevnum(mlir::acc::DeviceType::None);
1728}
1729
1730mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1731 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1732 getWaitOperandsSegments(), getHasWaitDevnum(),
1733 deviceType);
1734}
1735
1736void ParallelOp::build(mlir::OpBuilder &odsBuilder,
1737 mlir::OperationState &odsState,
1738 mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
1739 mlir::ValueRange vectorLength,
1740 mlir::ValueRange asyncOperands,
1741 mlir::ValueRange waitOperands, mlir::Value ifCond,
1742 mlir::Value selfCond, mlir::ValueRange reductionOperands,
1743 mlir::ValueRange gangPrivateOperands,
1744 mlir::ValueRange gangFirstPrivateOperands,
1745 mlir::ValueRange dataClauseOperands) {
1746
1747 ParallelOp::build(
1748 odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
1749 /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
1750 /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
1751 /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
1752 /*numGangsDeviceType=*/nullptr, numWorkers,
1753 /*numWorkersDeviceType=*/nullptr, vectorLength,
1754 /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
1755 /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
1756 gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
1757 /*firstprivatizations=*/nullptr, dataClauseOperands,
1758 /*defaultAttr=*/nullptr, /*combined=*/nullptr);
1759}
1760
1761void acc::ParallelOp::addNumWorkersOperand(
1762 MLIRContext *context, mlir::Value newValue,
1763 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1764 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1765 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1766 getNumWorkersMutable()));
1767}
1768void acc::ParallelOp::addVectorLengthOperand(
1769 MLIRContext *context, mlir::Value newValue,
1770 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1771 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1772 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1773 getVectorLengthMutable()));
1774}
1775
1776void acc::ParallelOp::addAsyncOnly(
1777 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1778 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1779 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1780}
1781
1782void acc::ParallelOp::addAsyncOperand(
1783 MLIRContext *context, mlir::Value newValue,
1784 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1785 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1786 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1787 getAsyncOperandsMutable()));
1788}
1789
1790void acc::ParallelOp::addNumGangsOperands(
1791 MLIRContext *context, mlir::ValueRange newValues,
1792 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1794 if (getNumGangsSegments())
1795 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1796
1797 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1798 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1799 getNumGangsMutable(), segments));
1800
1801 setNumGangsSegments(segments);
1802}
1803void acc::ParallelOp::addWaitOnly(
1804 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1805 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1806 effectiveDeviceTypes));
1807}
1808void acc::ParallelOp::addWaitOperands(
1809 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
1810 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1811
1813 if (getWaitOperandsSegments())
1814 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1815
1816 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1817 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1818 getWaitOperandsMutable(), segments));
1819 setWaitOperandsSegments(segments);
1820
1822 if (getHasWaitDevnumAttr())
1823 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1824 hasDevnums.insert(
1825 hasDevnums.end(),
1826 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
1827 mlir::BoolAttr::get(context, hasDevnum));
1828 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
1829}
1830
1831void acc::ParallelOp::addPrivatization(MLIRContext *context,
1832 mlir::acc::PrivateOp op,
1833 mlir::acc::PrivateRecipeOp recipe) {
1834 getPrivateOperandsMutable().append(op.getResult());
1835
1837
1838 if (getPrivatizationRecipesAttr())
1839 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
1840
1841 recipes.push_back(
1842 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1843 setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1844}
1845
1846void acc::ParallelOp::addFirstPrivatization(
1847 MLIRContext *context, mlir::acc::FirstprivateOp op,
1848 mlir::acc::FirstprivateRecipeOp recipe) {
1849 getFirstprivateOperandsMutable().append(op.getResult());
1850
1852
1853 if (getFirstprivatizationRecipesAttr())
1854 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
1855
1856 recipes.push_back(
1857 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1858 setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1859}
1860
1861void acc::ParallelOp::addReduction(MLIRContext *context,
1862 mlir::acc::ReductionOp op,
1863 mlir::acc::ReductionRecipeOp recipe) {
1864 getReductionOperandsMutable().append(op.getResult());
1865
1867
1868 if (getReductionRecipesAttr())
1869 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
1870
1871 recipes.push_back(
1872 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1873 setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1874}
1875
1876static ParseResult parseNumGangs(
1877 mlir::OpAsmParser &parser,
1879 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1880 mlir::DenseI32ArrayAttr &segments) {
1883
1884 do {
1885 if (failed(parser.parseLBrace()))
1886 return failure();
1887
1888 int32_t crtOperandsSize = operands.size();
1889 if (failed(parser.parseCommaSeparatedList(
1891 if (parser.parseOperand(operands.emplace_back()) ||
1892 parser.parseColonType(types.emplace_back()))
1893 return failure();
1894 return success();
1895 })))
1896 return failure();
1897 seg.push_back(operands.size() - crtOperandsSize);
1898
1899 if (failed(parser.parseRBrace()))
1900 return failure();
1901
1902 if (succeeded(parser.parseOptionalLSquare())) {
1903 if (parser.parseAttribute(attributes.emplace_back()) ||
1904 parser.parseRSquare())
1905 return failure();
1906 } else {
1907 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1908 parser.getContext(), mlir::acc::DeviceType::None));
1909 }
1910 } while (succeeded(parser.parseOptionalComma()));
1911
1912 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1913 attributes.end());
1914 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1915 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1916
1917 return success();
1918}
1919
1921 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1922 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
1923 p << " [" << attr << "]";
1924}
1925
1927 mlir::OperandRange operands, mlir::TypeRange types,
1928 std::optional<mlir::ArrayAttr> deviceTypes,
1929 std::optional<mlir::DenseI32ArrayAttr> segments) {
1930 unsigned opIdx = 0;
1931 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1932 p << "{";
1933 llvm::interleaveComma(
1934 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1935 p << operands[opIdx] << " : " << operands[opIdx].getType();
1936 ++opIdx;
1937 });
1938 p << "}";
1939 printSingleDeviceType(p, it.value());
1940 });
1941}
1942
1944 mlir::OpAsmParser &parser,
1946 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1947 mlir::DenseI32ArrayAttr &segments) {
1950
1951 do {
1952 if (failed(parser.parseLBrace()))
1953 return failure();
1954
1955 int32_t crtOperandsSize = operands.size();
1956
1957 if (failed(parser.parseCommaSeparatedList(
1959 if (parser.parseOperand(operands.emplace_back()) ||
1960 parser.parseColonType(types.emplace_back()))
1961 return failure();
1962 return success();
1963 })))
1964 return failure();
1965
1966 seg.push_back(operands.size() - crtOperandsSize);
1967
1968 if (failed(parser.parseRBrace()))
1969 return failure();
1970
1971 if (succeeded(parser.parseOptionalLSquare())) {
1972 if (parser.parseAttribute(attributes.emplace_back()) ||
1973 parser.parseRSquare())
1974 return failure();
1975 } else {
1976 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1977 parser.getContext(), mlir::acc::DeviceType::None));
1978 }
1979 } while (succeeded(parser.parseOptionalComma()));
1980
1981 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1982 attributes.end());
1983 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1984 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1985
1986 return success();
1987}
1988
1991 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1992 std::optional<mlir::DenseI32ArrayAttr> segments) {
1993 unsigned opIdx = 0;
1994 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1995 p << "{";
1996 llvm::interleaveComma(
1997 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1998 p << operands[opIdx] << " : " << operands[opIdx].getType();
1999 ++opIdx;
2000 });
2001 p << "}";
2002 printSingleDeviceType(p, it.value());
2003 });
2004}
2005
2006static ParseResult parseWaitClause(
2007 mlir::OpAsmParser &parser,
2009 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2010 mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
2011 mlir::ArrayAttr &keywordOnly) {
2012 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
2014
2015 bool needCommaBeforeOperands = false;
2016
2017 // Keyword only
2018 if (failed(parser.parseOptionalLParen())) {
2019 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2020 parser.getContext(), mlir::acc::DeviceType::None));
2021 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2022 return success();
2023 }
2024
2025 // Parse keyword only attributes
2026 if (succeeded(parser.parseOptionalLSquare())) {
2027 if (failed(parser.parseCommaSeparatedList([&]() {
2028 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2029 return failure();
2030 return success();
2031 })))
2032 return failure();
2033 if (parser.parseRSquare())
2034 return failure();
2035 needCommaBeforeOperands = true;
2036 }
2037
2038 if (needCommaBeforeOperands && failed(parser.parseComma()))
2039 return failure();
2040
2041 do {
2042 if (failed(parser.parseLBrace()))
2043 return failure();
2044
2045 int32_t crtOperandsSize = operands.size();
2046
2047 if (succeeded(parser.parseOptionalKeyword("devnum"))) {
2048 if (failed(parser.parseColon()))
2049 return failure();
2050 devnum.push_back(BoolAttr::get(parser.getContext(), true));
2051 } else {
2052 devnum.push_back(BoolAttr::get(parser.getContext(), false));
2053 }
2054
2055 if (failed(parser.parseCommaSeparatedList(
2057 if (parser.parseOperand(operands.emplace_back()) ||
2058 parser.parseColonType(types.emplace_back()))
2059 return failure();
2060 return success();
2061 })))
2062 return failure();
2063
2064 seg.push_back(operands.size() - crtOperandsSize);
2065
2066 if (failed(parser.parseRBrace()))
2067 return failure();
2068
2069 if (succeeded(parser.parseOptionalLSquare())) {
2070 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2071 parser.parseRSquare())
2072 return failure();
2073 } else {
2074 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2075 parser.getContext(), mlir::acc::DeviceType::None));
2076 }
2077 } while (succeeded(parser.parseOptionalComma()));
2078
2079 if (failed(parser.parseRParen()))
2080 return failure();
2081
2082 deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2083 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2084 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2085 hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
2086
2087 return success();
2088}
2089
2090static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
2091 if (!hasDeviceTypeValues(attrs))
2092 return false;
2093 if (attrs->size() != 1)
2094 return false;
2095 if (auto deviceTypeAttr =
2096 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2097 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2098 return false;
2099}
2100
2102 mlir::OperandRange operands, mlir::TypeRange types,
2103 std::optional<mlir::ArrayAttr> deviceTypes,
2104 std::optional<mlir::DenseI32ArrayAttr> segments,
2105 std::optional<mlir::ArrayAttr> hasDevNum,
2106 std::optional<mlir::ArrayAttr> keywordOnly) {
2107
2108 if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
2109 return;
2110
2111 p << "(";
2112
2113 printDeviceTypes(p, keywordOnly);
2114 if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
2115 p << ", ";
2116
2117 if (hasDeviceTypeValues(deviceTypes)) {
2118 unsigned opIdx = 0;
2119 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2120 p << "{";
2121 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2122 if (boolAttr && boolAttr.getValue())
2123 p << "devnum: ";
2124 llvm::interleaveComma(
2125 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2126 p << operands[opIdx] << " : " << operands[opIdx].getType();
2127 ++opIdx;
2128 });
2129 p << "}";
2130 printSingleDeviceType(p, it.value());
2131 });
2132 }
2133
2134 p << ")";
2135}
2136
2137static ParseResult parseDeviceTypeOperands(
2138 mlir::OpAsmParser &parser,
2140 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
2142 if (failed(parser.parseCommaSeparatedList([&]() {
2143 if (parser.parseOperand(operands.emplace_back()) ||
2144 parser.parseColonType(types.emplace_back()))
2145 return failure();
2146 if (succeeded(parser.parseOptionalLSquare())) {
2147 if (parser.parseAttribute(attributes.emplace_back()) ||
2148 parser.parseRSquare())
2149 return failure();
2150 } else {
2151 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2152 parser.getContext(), mlir::acc::DeviceType::None));
2153 }
2154 return success();
2155 })))
2156 return failure();
2157 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2158 attributes.end());
2159 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2160 return success();
2161}
2162
2163static void
2165 mlir::OperandRange operands, mlir::TypeRange types,
2166 std::optional<mlir::ArrayAttr> deviceTypes) {
2167 if (!hasDeviceTypeValues(deviceTypes))
2168 return;
2169 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
2170 p << std::get<1>(it) << " : " << std::get<1>(it).getType();
2171 printSingleDeviceType(p, std::get<0>(it));
2172 });
2173}
2174
2176 mlir::OpAsmParser &parser,
2178 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2179 mlir::ArrayAttr &keywordOnlyDeviceType) {
2180
2181 llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
2182 bool needCommaBeforeOperands = false;
2183
2184 if (failed(parser.parseOptionalLParen())) {
2185 // Keyword only
2186 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2187 parser.getContext(), mlir::acc::DeviceType::None));
2188 keywordOnlyDeviceType =
2189 ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
2190 return success();
2191 }
2192
2193 // Parse keyword only attributes
2194 if (succeeded(parser.parseOptionalLSquare())) {
2195 // Parse keyword only attributes
2196 if (failed(parser.parseCommaSeparatedList([&]() {
2197 if (parser.parseAttribute(
2198 keywordOnlyDeviceTypeAttributes.emplace_back()))
2199 return failure();
2200 return success();
2201 })))
2202 return failure();
2203 if (parser.parseRSquare())
2204 return failure();
2205 needCommaBeforeOperands = true;
2206 }
2207
2208 if (needCommaBeforeOperands && failed(parser.parseComma()))
2209 return failure();
2210
2212 if (failed(parser.parseCommaSeparatedList([&]() {
2213 if (parser.parseOperand(operands.emplace_back()) ||
2214 parser.parseColonType(types.emplace_back()))
2215 return failure();
2216 if (succeeded(parser.parseOptionalLSquare())) {
2217 if (parser.parseAttribute(attributes.emplace_back()) ||
2218 parser.parseRSquare())
2219 return failure();
2220 } else {
2221 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2222 parser.getContext(), mlir::acc::DeviceType::None));
2223 }
2224 return success();
2225 })))
2226 return failure();
2227
2228 if (failed(parser.parseRParen()))
2229 return failure();
2230
2231 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2232 attributes.end());
2233 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2234 return success();
2235}
2236
2239 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2240 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2241
2242 if (operands.begin() == operands.end() &&
2243 hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
2244 return;
2245 }
2246
2247 p << "(";
2248 printDeviceTypes(p, keywordOnlyDeviceTypes);
2249 if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
2250 hasDeviceTypeValues(deviceTypes))
2251 p << ", ";
2252 printDeviceTypeOperands(p, op, operands, types, deviceTypes);
2253 p << ")";
2254}
2255
2257 mlir::OpAsmParser &parser,
2258 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2259 mlir::Type &operandType, mlir::UnitAttr &attr) {
2260 // Keyword only
2261 if (failed(parser.parseOptionalLParen())) {
2262 attr = mlir::UnitAttr::get(parser.getContext());
2263 return success();
2264 }
2265
2267 if (failed(parser.parseOperand(op)))
2268 return failure();
2269 operand = op;
2270 if (failed(parser.parseColon()))
2271 return failure();
2272 if (failed(parser.parseType(operandType)))
2273 return failure();
2274 if (failed(parser.parseRParen()))
2275 return failure();
2276
2277 return success();
2278}
2279
2281 mlir::Operation *op,
2282 std::optional<mlir::Value> operand,
2283 mlir::Type operandType,
2284 mlir::UnitAttr attr) {
2285 if (attr)
2286 return;
2287
2288 p << "(";
2289 p.printOperand(*operand);
2290 p << " : ";
2291 p.printType(operandType);
2292 p << ")";
2293}
2294
2296 mlir::OpAsmParser &parser,
2298 llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
2299 // Keyword only
2300 if (failed(parser.parseOptionalLParen())) {
2301 attr = mlir::UnitAttr::get(parser.getContext());
2302 return success();
2303 }
2304
2305 if (failed(parser.parseCommaSeparatedList([&]() {
2306 if (parser.parseOperand(operands.emplace_back()))
2307 return failure();
2308 return success();
2309 })))
2310 return failure();
2311 if (failed(parser.parseColon()))
2312 return failure();
2313 if (failed(parser.parseCommaSeparatedList([&]() {
2314 if (parser.parseType(types.emplace_back()))
2315 return failure();
2316 return success();
2317 })))
2318 return failure();
2319 if (failed(parser.parseRParen()))
2320 return failure();
2321
2322 return success();
2323}
2324
2326 mlir::Operation *op,
2327 mlir::OperandRange operands,
2328 mlir::TypeRange types,
2329 mlir::UnitAttr attr) {
2330 if (attr)
2331 return;
2332
2333 p << "(";
2334 llvm::interleaveComma(operands, p, [&](auto it) { p << it; });
2335 p << " : ";
2336 llvm::interleaveComma(types, p, [&](auto it) { p << it; });
2337 p << ")";
2338}
2339
2340static ParseResult
2342 mlir::acc::CombinedConstructsTypeAttr &attr) {
2343 if (succeeded(parser.parseOptionalKeyword("kernels"))) {
2344 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2345 parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2346 } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
2347 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2348 parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2349 } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
2350 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2351 parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2352 } else {
2353 parser.emitError(parser.getCurrentLocation(),
2354 "expected compute construct name");
2355 return failure();
2356 }
2357 return success();
2358}
2359
2360static void
2362 mlir::acc::CombinedConstructsTypeAttr attr) {
2363 if (attr) {
2364 switch (attr.getValue()) {
2365 case mlir::acc::CombinedConstructsType::KernelsLoop:
2366 p << "kernels";
2367 break;
2368 case mlir::acc::CombinedConstructsType::ParallelLoop:
2369 p << "parallel";
2370 break;
2371 case mlir::acc::CombinedConstructsType::SerialLoop:
2372 p << "serial";
2373 break;
2374 };
2375 }
2376}
2377
2378//===----------------------------------------------------------------------===//
2379// SerialOp
2380//===----------------------------------------------------------------------===//
2381
2382unsigned SerialOp::getNumDataOperands() {
2383 return getReductionOperands().size() + getPrivateOperands().size() +
2384 getFirstprivateOperands().size() + getDataClauseOperands().size();
2385}
2386
2387Value SerialOp::getDataOperand(unsigned i) {
2388 unsigned numOptional = getAsyncOperands().size();
2389 numOptional += getIfCond() ? 1 : 0;
2390 numOptional += getSelfCond() ? 1 : 0;
2391 return getOperand(getWaitOperands().size() + numOptional + i);
2392}
2393
2394bool acc::SerialOp::hasAsyncOnly() {
2395 return hasAsyncOnly(mlir::acc::DeviceType::None);
2396}
2397
2398bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2399 return hasDeviceType(getAsyncOnly(), deviceType);
2400}
2401
2402mlir::Value acc::SerialOp::getAsyncValue() {
2403 return getAsyncValue(mlir::acc::DeviceType::None);
2404}
2405
2406mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2408 getAsyncOperands(), deviceType);
2409}
2410
2411bool acc::SerialOp::hasWaitOnly() {
2412 return hasWaitOnly(mlir::acc::DeviceType::None);
2413}
2414
2415bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2416 return hasDeviceType(getWaitOnly(), deviceType);
2417}
2418
2419mlir::Operation::operand_range SerialOp::getWaitValues() {
2420 return getWaitValues(mlir::acc::DeviceType::None);
2421}
2422
2424SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2426 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2427 getHasWaitDevnum(), deviceType);
2428}
2429
2430mlir::Value SerialOp::getWaitDevnum() {
2431 return getWaitDevnum(mlir::acc::DeviceType::None);
2432}
2433
2434mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2435 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2436 getWaitOperandsSegments(), getHasWaitDevnum(),
2437 deviceType);
2438}
2439
2440LogicalResult acc::SerialOp::verify() {
2442 *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
2443 "privatizations", /*checkOperandType=*/false)))
2444 return failure();
2446 *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
2447 "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
2448 return failure();
2450 *this, getReductionRecipes(), getReductionOperands(), "reduction",
2451 "reductions", false)))
2452 return failure();
2453
2455 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2456 getWaitOperandsDeviceTypeAttr(), "wait")))
2457 return failure();
2458
2460 getAsyncOperandsDeviceTypeAttr(),
2461 "async")))
2462 return failure();
2463
2465 return failure();
2466
2467 return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
2468}
2469
2470void acc::SerialOp::addAsyncOnly(
2471 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2472 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2473 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2474}
2475
2476void acc::SerialOp::addAsyncOperand(
2477 MLIRContext *context, mlir::Value newValue,
2478 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2479 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2480 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2481 getAsyncOperandsMutable()));
2482}
2483
2484void acc::SerialOp::addWaitOnly(
2485 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2486 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2487 effectiveDeviceTypes));
2488}
2489void acc::SerialOp::addWaitOperands(
2490 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2491 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2492
2494 if (getWaitOperandsSegments())
2495 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2496
2497 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2498 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2499 getWaitOperandsMutable(), segments));
2500 setWaitOperandsSegments(segments);
2501
2503 if (getHasWaitDevnumAttr())
2504 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2505 hasDevnums.insert(
2506 hasDevnums.end(),
2507 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2508 mlir::BoolAttr::get(context, hasDevnum));
2509 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2510}
2511
2512void acc::SerialOp::addPrivatization(MLIRContext *context,
2513 mlir::acc::PrivateOp op,
2514 mlir::acc::PrivateRecipeOp recipe) {
2515 getPrivateOperandsMutable().append(op.getResult());
2516
2518
2519 if (getPrivatizationRecipesAttr())
2520 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
2521
2522 recipes.push_back(
2523 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2524 setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2525}
2526
2527void acc::SerialOp::addFirstPrivatization(
2528 MLIRContext *context, mlir::acc::FirstprivateOp op,
2529 mlir::acc::FirstprivateRecipeOp recipe) {
2530 getFirstprivateOperandsMutable().append(op.getResult());
2531
2533
2534 if (getFirstprivatizationRecipesAttr())
2535 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
2536
2537 recipes.push_back(
2538 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2539 setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2540}
2541
2542void acc::SerialOp::addReduction(MLIRContext *context,
2543 mlir::acc::ReductionOp op,
2544 mlir::acc::ReductionRecipeOp recipe) {
2545 getReductionOperandsMutable().append(op.getResult());
2546
2548
2549 if (getReductionRecipesAttr())
2550 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
2551
2552 recipes.push_back(
2553 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2554 setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2555}
2556
2557//===----------------------------------------------------------------------===//
2558// KernelsOp
2559//===----------------------------------------------------------------------===//
2560
2561unsigned KernelsOp::getNumDataOperands() {
2562 return getDataClauseOperands().size();
2563}
2564
2565Value KernelsOp::getDataOperand(unsigned i) {
2566 unsigned numOptional = getAsyncOperands().size();
2567 numOptional += getWaitOperands().size();
2568 numOptional += getNumGangs().size();
2569 numOptional += getNumWorkers().size();
2570 numOptional += getVectorLength().size();
2571 numOptional += getIfCond() ? 1 : 0;
2572 numOptional += getSelfCond() ? 1 : 0;
2573 return getOperand(numOptional + i);
2574}
2575
2576bool acc::KernelsOp::hasAsyncOnly() {
2577 return hasAsyncOnly(mlir::acc::DeviceType::None);
2578}
2579
2580bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2581 return hasDeviceType(getAsyncOnly(), deviceType);
2582}
2583
2584mlir::Value acc::KernelsOp::getAsyncValue() {
2585 return getAsyncValue(mlir::acc::DeviceType::None);
2586}
2587
2588mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2590 getAsyncOperands(), deviceType);
2591}
2592
2593mlir::Value acc::KernelsOp::getNumWorkersValue() {
2594 return getNumWorkersValue(mlir::acc::DeviceType::None);
2595}
2596
2598acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2599 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
2600 deviceType);
2601}
2602
2603mlir::Value acc::KernelsOp::getVectorLengthValue() {
2604 return getVectorLengthValue(mlir::acc::DeviceType::None);
2605}
2606
2608acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2609 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
2610 getVectorLength(), deviceType);
2611}
2612
2613mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
2614 return getNumGangsValues(mlir::acc::DeviceType::None);
2615}
2616
2618KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2619 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2620 getNumGangsSegments(), deviceType);
2621}
2622
2623bool acc::KernelsOp::hasWaitOnly() {
2624 return hasWaitOnly(mlir::acc::DeviceType::None);
2625}
2626
2627bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2628 return hasDeviceType(getWaitOnly(), deviceType);
2629}
2630
2631mlir::Operation::operand_range KernelsOp::getWaitValues() {
2632 return getWaitValues(mlir::acc::DeviceType::None);
2633}
2634
2636KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2638 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2639 getHasWaitDevnum(), deviceType);
2640}
2641
2642mlir::Value KernelsOp::getWaitDevnum() {
2643 return getWaitDevnum(mlir::acc::DeviceType::None);
2644}
2645
2646mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2647 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2648 getWaitOperandsSegments(), getHasWaitDevnum(),
2649 deviceType);
2650}
2651
2652LogicalResult acc::KernelsOp::verify() {
2654 *this, getNumGangs(), getNumGangsSegmentsAttr(),
2655 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
2656 return failure();
2657
2659 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2660 getWaitOperandsDeviceTypeAttr(), "wait")))
2661 return failure();
2662
2663 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
2664 getNumWorkersDeviceTypeAttr(),
2665 "num_workers")))
2666 return failure();
2667
2668 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
2669 getVectorLengthDeviceTypeAttr(),
2670 "vector_length")))
2671 return failure();
2672
2674 getAsyncOperandsDeviceTypeAttr(),
2675 "async")))
2676 return failure();
2677
2679 return failure();
2680
2681 return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
2682}
2683
2684void acc::KernelsOp::addNumWorkersOperand(
2685 MLIRContext *context, mlir::Value newValue,
2686 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2687 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2688 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2689 getNumWorkersMutable()));
2690}
2691
2692void acc::KernelsOp::addVectorLengthOperand(
2693 MLIRContext *context, mlir::Value newValue,
2694 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2695 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2696 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2697 getVectorLengthMutable()));
2698}
2699void acc::KernelsOp::addAsyncOnly(
2700 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2701 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2702 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2703}
2704
2705void acc::KernelsOp::addAsyncOperand(
2706 MLIRContext *context, mlir::Value newValue,
2707 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2708 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2709 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2710 getAsyncOperandsMutable()));
2711}
2712
2713void acc::KernelsOp::addNumGangsOperands(
2714 MLIRContext *context, mlir::ValueRange newValues,
2715 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2717 if (getNumGangsSegmentsAttr())
2718 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2719
2720 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2721 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2722 getNumGangsMutable(), segments));
2723
2724 setNumGangsSegments(segments);
2725}
2726
2727void acc::KernelsOp::addWaitOnly(
2728 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2729 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2730 effectiveDeviceTypes));
2731}
/// Append `newValues` as a wait-operand group of this `acc.kernels` op for
/// `effectiveDeviceTypes`, update the waitOperandsDeviceType and
/// waitOperandsSegments attributes, and append `hasDevnum` to the
/// hasWaitDevnum attribute max(effectiveDeviceTypes.size(), 1) times.
void acc::KernelsOp::addWaitOperands(
    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {

  // Seed `segments` with the existing per-group sizes (if any).
  if (getWaitOperandsSegments())
    llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));

  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
      getWaitOperandsMutable(), segments));
  setWaitOperandsSegments(segments);

  if (getHasWaitDevnumAttr())
    llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
  // One devnum flag per added group. An empty device-type list still adds one
  // flag, presumably matching a single default group added by the helper.
  hasDevnums.insert(
      hasDevnums.end(),
      std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
      mlir::BoolAttr::get(context, hasDevnum));
  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
}
2754
2755//===----------------------------------------------------------------------===//
2756// HostDataOp
2757//===----------------------------------------------------------------------===//
2758
2759LogicalResult acc::HostDataOp::verify() {
2760 if (getDataClauseOperands().empty())
2761 return emitError("at least one operand must appear on the host_data "
2762 "operation");
2763
2764 for (mlir::Value operand : getDataClauseOperands())
2765 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2766 return emitError("expect data entry operation as defining op");
2767 return success();
2768}
2769
2770void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2771 MLIRContext *context) {
2772 results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2773}
2774
2775//===----------------------------------------------------------------------===//
2776// KernelEnvironmentOp
2777//===----------------------------------------------------------------------===//
2778
2779void acc::KernelEnvironmentOp::getCanonicalizationPatterns(
2780 RewritePatternSet &results, MLIRContext *context) {
2781 results.add<RemoveEmptyKernelEnvironment>(context);
2782}
2783
2784//===----------------------------------------------------------------------===//
2785// LoopOp
2786//===----------------------------------------------------------------------===//
2787
/// Try to parse one `keyword = %operand : type` entry of a gang clause.
///
/// If `keyword` is not next in the input, nothing is consumed and success is
/// returned with the outputs untouched. Otherwise the operand and its type
/// are appended to the operand/type output lists, `gangArgType` is appended
/// to `attributes`, and both `needCommaBetweenValues` and `newValue` are set
/// to true. Returns failure if `=`, the operand, or the `: type` is missing.
static ParseResult parseGangValue(
    OpAsmParser &parser, llvm::StringRef keyword,
    llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
    bool &needCommaBetweenValues, bool &newValue) {
  if (succeeded(parser.parseOptionalKeyword(keyword))) {
    if (parser.parseEqual())
      return failure();
    if (parser.parseOperand(operands.emplace_back()) ||
        parser.parseColonType(types.emplace_back()))
      return failure();
    attributes.push_back(gangArgType);
    // Any further value in the same `{...}` group must be comma-separated.
    needCommaBetweenValues = true;
    newValue = true;
  }
  return success();
}
2806
/// Parse the gang clause of `acc.loop`.
///
/// Forms handled below:
///   - no `(`: a bare gang, recorded as a gang-only entry for device type
///     `none`;
///   - `(` [`[` device-type-attr-list `]`] then zero or more comma-separated
///     groups `{` kw `=` %v `:` type (`,` kw `=` %v `:` type)* `}`
///     [`[` device-type-attr `]`], then `)`. Here kw is the num/dim/static
///     gang keyword.
/// Each group contributes one entry to `deviceType` (defaulting to `none`
/// when no `[...]` follows it) and its operand count to `segments`. The
/// argument kind of every operand is appended to `gangArgType`.
///
/// NOTE(review): `needCommaBetweenValues` is never reset between groups. For
/// a second `{...}` group the inner loop therefore demands a leading comma
/// and breaks immediately, so `{num=%a : i32}, {num=%b : i32}` (the form
/// printGangClause emits) fails to parse. Resetting it to false right after
/// parseLBrace() looks like the intended behavior; confirm.
/// NOTE(review): the "expect at least one of num, dim or static values"
/// check tests the whole `gangOperands` list, not the current group, so an
/// empty `{}` after a non-empty group is accepted. Comparing against
/// `crtOperandsSize` looks intended.
static ParseResult parseGangClause(
    OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
    mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
    mlir::ArrayAttr &gangOnlyDeviceType) {
  llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
  llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
  llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
  bool needCommaBetweenValues = false;
  bool needCommaBeforeOperands = false;

  if (failed(parser.parseOptionalLParen())) {
    // Gang only keyword
    gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
        parser.getContext(), mlir::acc::DeviceType::None));
    gangOnlyDeviceType =
        ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
    return success();
  }

  // Parse gang only attributes
  if (succeeded(parser.parseOptionalLSquare())) {
    // Parse gang only attributes
    if (failed(parser.parseCommaSeparatedList([&]() {
          if (parser.parseAttribute(
                  gangOnlyDeviceTypeAttributes.emplace_back()))
            return failure();
          return success();
        })))
      return failure();
    if (parser.parseRSquare())
      return failure();
    needCommaBeforeOperands = true;
  }

  auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
                                                mlir::acc::GangArgType::Num);
  auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
                                                mlir::acc::GangArgType::Dim);
  auto argStatic = mlir::acc::GangArgTypeAttr::get(
      parser.getContext(), mlir::acc::GangArgType::Static);

  do {
    // After a gang-only `[...]` list, `continue` jumps to the loop condition,
    // which consumes the `,` separating it from the first `{...}` group (or
    // ends the loop if there is none).
    if (needCommaBeforeOperands) {
      needCommaBeforeOperands = false;
      continue;
    }

    if (failed(parser.parseLBrace()))
      return failure();

    // Operand count before this group; used to compute the group's segment.
    int32_t crtOperandsSize = gangOperands.size();
    // Parse `kw = %v : t` entries until no new value is found.
    while (true) {
      bool newValue = false;
      bool needValue = false;
      if (needCommaBetweenValues) {
        if (succeeded(parser.parseOptionalComma()))
          needValue = true; // expect a new value after comma.
        else
          break;
      }

      if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
                                gangOperands, gangOperandsType,
                                gangArgTypeAttributes, argNum,
                                needCommaBetweenValues, newValue)))
        return failure();
      if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
                                gangOperands, gangOperandsType,
                                gangArgTypeAttributes, argDim,
                                needCommaBetweenValues, newValue)))
        return failure();
      if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
                                gangOperands, gangOperandsType,
                                gangArgTypeAttributes, argStatic,
                                needCommaBetweenValues, newValue)))
        return failure();

      if (!newValue && needValue) {
        parser.emitError(parser.getCurrentLocation(),
                         "new value expected after comma");
        return failure();
      }

      if (!newValue)
        break;
    }

    if (gangOperands.empty())
      return parser.emitError(
          parser.getCurrentLocation(),
          "expect at least one of num, dim or static values");

    if (failed(parser.parseRBrace()))
      return failure();

    // Optional device type for this group; defaults to `none`.
    if (succeeded(parser.parseOptionalLSquare())) {
      if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
          parser.parseRSquare())
        return failure();
    } else {
      deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
          parser.getContext(), mlir::acc::DeviceType::None));
    }

    seg.push_back(gangOperands.size() - crtOperandsSize);

  } while (succeeded(parser.parseOptionalComma()));

  if (failed(parser.parseRParen()))
    return failure();

  llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
                                               gangArgTypeAttributes.end());
  gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);

      gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
  gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);

  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
  return success();
}
2933
    mlir::OperandRange operands, mlir::TypeRange types,
    std::optional<mlir::ArrayAttr> gangArgTypes,
    std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::DenseI32ArrayAttr> segments,
    std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
  // Print the gang clause as
  //   `(` gang-only-device-types [`, `] {kw=%v : t, ...}device_type, ... `)`
  // Nothing is printed when there are no operands and hasOnlyDeviceTypeNone()
  // holds for the gang-only list (the bare `gang` form). `types` is not used:
  // each operand's type is printed from the operand itself.

  if (operands.begin() == operands.end() &&
      hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
    return;
  }

  p << "(";

  printDeviceTypes(p, gangOnlyDeviceTypes);

  // Separator between the gang-only device types and the operand groups.
  if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
      hasDeviceTypeValues(deviceTypes))
    p << ", ";

  if (hasDeviceTypeValues(deviceTypes)) {
    // Running index into the flat `operands` / `gangArgTypes` lists. Group k
    // consumes (*segments)[k] consecutive entries.
    unsigned opIdx = 0;
    llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
      p << "{";
      // The inner lambda parameter shadows the outer `it`; only opIdx is used
      // inside it.
      llvm::interleaveComma(
          llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
            auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
                (*gangArgTypes)[opIdx]);
            if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
              p << LoopOp::getGangNumKeyword();
            else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
              p << LoopOp::getGangDimKeyword();
            else if (gangArgTypeAttr.getValue() ==
                     mlir::acc::GangArgType::Static)
              p << LoopOp::getGangStaticKeyword();
            p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
            ++opIdx;
          });
      p << "}";
      printSingleDeviceType(p, it.value());
    });
  }
  p << ")";
}
2978
    std::optional<mlir::ArrayAttr> segments,
    llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
  // Insert every device type of `segments` into the caller-provided set and
  // return true as soon as one is already present. The set is shared across
  // calls, so a device type seen in an earlier call also counts as a
  // duplicate. An absent attribute contributes nothing and returns false.
  if (!segments)
    return false;
  for (auto attr : *segments) {
    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
    if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
      return true;
  }
  return false;
}
2991
2992/// Check for duplicates in the DeviceType array attribute.
2993LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
2994 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2995 if (!deviceTypes)
2996 return success();
2997 for (auto attr : deviceTypes) {
2998 auto deviceTypeAttr =
2999 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3000 if (!deviceTypeAttr)
3001 return failure();
3002 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3003 return failure();
3004 }
3005 return success();
3006}
3007
/// Verify `acc.loop`.
///
/// Checks, in order:
///   - lower/upper bound and step counts agree;
///   - collapse/gang/worker/vector/tile device_type bookkeeping is consistent
///     and free of duplicates;
///   - auto, independent and seq do not overlap on any device type, and at
///     least one of them is given for device type `none`;
///   - seq is not combined with gang/worker/vector for the same device type;
///   - private/firstprivate/reduction recipes match their operands;
///   - the combined-construct attribute, if present, is a *Loop kind;
///   - the body is non-empty;
///   - a container-like loop holds enough nested loop-like ops for its
///     largest collapse value.
LogicalResult acc::LoopOp::verify() {
  if (getUpperbound().size() != getStep().size())
    return emitError() << "number of upperbounds expected to be the same as "
                          "number of steps";

  if (getUpperbound().size() != getLowerbound().size())
    return emitError() << "number of upperbounds expected to be the same as "
                          "number of lowerbounds";

  if (!getUpperbound().empty() && getInclusiveUpperbound() &&
      (getUpperbound().size() != getInclusiveUpperbound()->size()))
    return emitError() << "inclusiveUpperbound size is expected to be the same"
                       << " as upperbound size";

  // Check collapse
  // NOTE(review): "must be define" in the diagnostic below looks like a typo
  // for "must be defined"; fixing it requires updating tests that match it.
  if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
    return emitOpError() << "collapse device_type attr must be define when"
                         << " collapse attr is present";

  if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
      getCollapseAttr().getValue().size() !=
          getCollapseDeviceTypeAttr().getValue().size())
    return emitOpError() << "collapse attribute count must match collapse"
                         << " device_type count";
  if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
    return emitOpError()
           << "duplicate device_type found in collapseDeviceType attribute";

  // Check gang
  if (!getGangOperands().empty()) {
    if (!getGangOperandsArgType())
      return emitOpError() << "gangOperandsArgType attribute must be defined"
                           << " when gang operands are present";

    if (getGangOperands().size() !=
        getGangOperandsArgTypeAttr().getValue().size())
      return emitOpError() << "gangOperandsArgType attribute count must match"
                           << " gangOperands count";
  }
  if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
    return emitOpError() << "duplicate device_type found in gang attribute";

      *this, getGangOperands(), getGangOperandsSegmentsAttr(),
      getGangOperandsDeviceTypeAttr(), "gang")))
    return failure();

  // Check worker
  if (failed(checkDeviceTypes(getWorkerAttr())))
    return emitOpError() << "duplicate device_type found in worker attribute";
  if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
    return emitOpError() << "duplicate device_type found in "
                            "workerNumOperandsDeviceType attribute";
  if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
                                        getWorkerNumOperandsDeviceTypeAttr(),
                                        "worker")))
    return failure();

  // Check vector
  if (failed(checkDeviceTypes(getVectorAttr())))
    return emitOpError() << "duplicate device_type found in vector attribute";
  if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
    return emitOpError() << "duplicate device_type found in "
                            "vectorOperandsDeviceType attribute";
  if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
                                        getVectorOperandsDeviceTypeAttr(),
                                        "vector")))
    return failure();

      *this, getTileOperands(), getTileOperandsSegmentsAttr(),
      getTileOperandsDeviceTypeAttr(), "tile")))
    return failure();

  // auto, independent and seq attribute are mutually exclusive.
  // One set is threaded through all three calls, so a device type listed in
  // two of the attributes (or twice in one) is reported.
  llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
  if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
      hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
      hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
    return emitError() << "only one of auto, independent, seq can be present "
                          "at the same time";
  }

  // Check that at least one of auto, independent, or seq is present
  // for the device-independent default clauses.
  auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
    return attr.getValue() == mlir::acc::DeviceType::None;
  };
  bool hasDefaultSeq =
      getSeqAttr()
          ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
                         hasDeviceNone)
          : false;
  bool hasDefaultIndependent =
      getIndependentAttr()
          ? llvm::any_of(
                getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
                hasDeviceNone)
          : false;
  bool hasDefaultAuto =
      getAuto_Attr()
          ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
                         hasDeviceNone)
          : false;
  if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
    return emitError()
           << "at least one of auto, independent, seq must be present";
  }

  // Gang, worker and vector are incompatible with seq.
  // Checked separately for every device type listed in the seq attribute.
  if (getSeqAttr()) {
    for (auto attr : getSeqAttr()) {
      auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
      if (hasVector(deviceTypeAttr.getValue()) ||
          getVectorValue(deviceTypeAttr.getValue()) ||
          hasWorker(deviceTypeAttr.getValue()) ||
          getWorkerValue(deviceTypeAttr.getValue()) ||
          hasGang(deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Num,
                       deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Dim,
                       deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Static,
                       deviceTypeAttr.getValue()))
        return emitError() << "gang, worker or vector cannot appear with seq";
    }
  }

      *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
      "privatizations", false)))
    return failure();

      *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
      "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
    return failure();

      *this, getReductionRecipes(), getReductionOperands(), "reduction",
      "reductions", false)))
    return failure();

  if (getCombined().has_value() &&
      (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
       getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
       getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
    return emitError("unexpected combined constructs attribute");
  }

  // Check non-empty body().
  if (getRegion().empty())
    return emitError("expected non-empty body.");

  if (getUnstructured()) {
    if (!isContainerLike())
      return emitError(
          "unstructured acc.loop must not have induction variables");
  } else if (isContainerLike()) {
    // When it is container-like - it is expected to hold a loop-like operation.
    // Obtain the maximum collapse count - we use this to check that there
    // are enough loops contained.
    uint64_t collapseCount = getCollapseValue().value_or(1);
    if (getCollapseAttr()) {
      for (auto collapseEntry : getCollapseAttr()) {
        auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
        if (intAttr.getValue().getZExtValue() > collapseCount)
          collapseCount = intAttr.getValue().getZExtValue();
      }
    }

    // We want to check that we find enough loop-like operations inside.
    // A pre-order walk visits every op before the ops nested inside it, so
    // a perfectly nested chain of loops is seen outermost-first. Each loop
    // found must have the previously found loop (initially this op) as its
    // closest loop-like parent; anything else is a sibling.
    mlir::Operation *expectedParent = this->getOperation();
    bool foundSibling = false;
    getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
      if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
        // This effectively checks that we are not looking at a sibling loop.
        if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
            expectedParent) {
          foundSibling = true;
        }

        collapseCount--;
        expectedParent = op;
      }
      // We found enough contained loops.
      if (collapseCount == 0)
    });

    if (foundSibling)
      return emitError("found sibling loops inside container-like acc.loop");
    if (collapseCount != 0)
      return emitError("failed to find enough loop-like operations inside "
                       "container-like acc.loop");
  }

  return success();
}
3211
3212unsigned LoopOp::getNumDataOperands() {
3213 return getReductionOperands().size() + getPrivateOperands().size() +
3214 getFirstprivateOperands().size();
3215}
3216
3217Value LoopOp::getDataOperand(unsigned i) {
3218 unsigned numOptional =
3219 getLowerbound().size() + getUpperbound().size() + getStep().size();
3220 numOptional += getGangOperands().size();
3221 numOptional += getVectorOperands().size();
3222 numOptional += getWorkerNumOperands().size();
3223 numOptional += getTileOperands().size();
3224 numOptional += getCacheOperands().size();
3225 return getOperand(numOptional + i);
3226}
3227
3228bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
3229
3230bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3231 return hasDeviceType(getAuto_(), deviceType);
3232}
3233
3234bool LoopOp::hasIndependent() {
3235 return hasIndependent(mlir::acc::DeviceType::None);
3236}
3237
3238bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3239 return hasDeviceType(getIndependent(), deviceType);
3240}
3241
3242bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
3243
3244bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3245 return hasDeviceType(getSeq(), deviceType);
3246}
3247
3248mlir::Value LoopOp::getVectorValue() {
3249 return getVectorValue(mlir::acc::DeviceType::None);
3250}
3251
3252mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3253 return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
3254 getVectorOperands(), deviceType);
3255}
3256
3257bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
3258
3259bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3260 return hasDeviceType(getVector(), deviceType);
3261}
3262
3263mlir::Value LoopOp::getWorkerValue() {
3264 return getWorkerValue(mlir::acc::DeviceType::None);
3265}
3266
3267mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3268 return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
3269 getWorkerNumOperands(), deviceType);
3270}
3271
3272bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
3273
3274bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3275 return hasDeviceType(getWorker(), deviceType);
3276}
3277
3278mlir::Operation::operand_range LoopOp::getTileValues() {
3279 return getTileValues(mlir::acc::DeviceType::None);
3280}
3281
LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
  // Return the tile operand group recorded for `deviceType`, selected through
  // the tileOperandsDeviceType attribute and the tileOperandsSegments sizes.
  return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
                               getTileOperandsSegments(), deviceType);
}
3287
3288std::optional<int64_t> LoopOp::getCollapseValue() {
3289 return getCollapseValue(mlir::acc::DeviceType::None);
3290}
3291
3292std::optional<int64_t>
3293LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3294 if (!getCollapseAttr())
3295 return std::nullopt;
3296 if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3297 auto intAttr =
3298 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3299 return intAttr.getValue().getZExtValue();
3300 }
3301 return std::nullopt;
3302}
3303
3304mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3305 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3306}
3307
/// Return the first gang operand of kind `gangArgType` (num/dim/static) in
/// the group recorded for `deviceType`, or a null Value when there are no
/// gang operands, no group for `deviceType`, or no operand of that kind.
mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
                                 mlir::acc::DeviceType deviceType) {
  if (getGangOperands().empty())
    return {};
  if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
    // Count the operands in all groups that precede the matching one.
    int32_t nbOperandsBefore = 0;
    for (unsigned i = 0; i < *pos; ++i)
      nbOperandsBefore += (*getGangOperandsSegments())[i];
        getGangOperands()
            .drop_front(nbOperandsBefore)
            .take_front((*getGangOperandsSegments())[*pos]);

    // gangOperandsArgType is a flat list parallel to the operands, so the
    // arg-type index starts at the same offset as the group.
    int32_t argTypeIdx = nbOperandsBefore;
    for (auto value : values) {
      auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
          (*getGangOperandsArgType())[argTypeIdx]);
      if (gangArgTypeAttr.getValue() == gangArgType)
        return value;
      ++argTypeIdx;
    }
  }
  return {};
}
3332
3333bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
3334
3335bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3336 return hasDeviceType(getGang(), deviceType);
3337}
3338
3339llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
3340 return {&getRegion()};
3341}
3342
/// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
/// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
/// `(` ssa-id-and-type-list `)`
/// region
///
/// The whole `control(...) = (...) to (...) step (...)` prefix is optional.
/// Without the `control` keyword only the region is parsed, with no
/// induction variables. The lower bound, upper bound and step lists are each
/// parsed with a count equal to the number of induction variables.
ParseResult
    SmallVectorImpl<Type> &lowerboundType,
    SmallVectorImpl<Type> &upperboundType,
    SmallVectorImpl<Type> &stepType) {

  if (succeeded(
          parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
    if (parser.parseLParen() ||
        parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
                                 /*allowType=*/true) ||
        parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
        parser.parseOperandList(lowerbound, inductionVars.size(),
        parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
        parser.parseKeyword("to") || parser.parseLParen() ||
        parser.parseOperandList(upperbound, inductionVars.size(),
        parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
        parser.parseKeyword("step") || parser.parseLParen() ||
        parser.parseOperandList(step, inductionVars.size(),
        parser.parseColonTypeList(stepType) || parser.parseRParen())
      return failure();
  }
  // The parsed induction variables become the region's entry arguments.
  return parser.parseRegion(region, inductionVars);
}
3378
    ValueRange lowerbound, TypeRange lowerboundType,
    ValueRange upperbound, TypeRange upperboundType,
    ValueRange steps, TypeRange stepType) {
  // Print the optional `control(...) = (...) to (...) step (...)` prefix
  // (only when the region's entry block has arguments), then the region
  // without repeating its entry block arguments.
  // NOTE(review): `") " << " step ("` emits two spaces before `step`.
  // Changing it alters the printed IR that tests may match against.
  ValueRange regionArgs = region.front().getArguments();
  if (!regionArgs.empty()) {
    p << acc::LoopOp::getControlKeyword() << "(";
    llvm::interleaveComma(regionArgs, p,
                          [&p](Value v) { p << v << " : " << v.getType(); });
    p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
      << upperbound << " : " << upperboundType << ") " << " step (" << steps
      << " : " << stepType << ") ";
  }
  p.printRegion(region, /*printEntryBlockArgs=*/false);
}
3394
3395void acc::LoopOp::addSeq(MLIRContext *context,
3396 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3397 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3398 effectiveDeviceTypes));
3399}
3400
3401void acc::LoopOp::addIndependent(
3402 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3403 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3404 context, getIndependentAttr(), effectiveDeviceTypes));
3405}
3406
3407void acc::LoopOp::addAuto(MLIRContext *context,
3408 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3409 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3410 effectiveDeviceTypes));
3411}
3412
/// Append a collapse entry holding `value` (a 64-bit APInt) for every device
/// type in `effectiveDeviceTypes`, or a single entry for device type `none`
/// when that list is empty. Existing (collapse, collapseDeviceType) pairs are
/// preserved and the new pairs are appended after them.
/// NOTE(review): an existing entry for the same device type is not replaced,
/// so calling this twice for one device type produces a duplicate that
/// LoopOp::verify rejects ("duplicate device_type found in
/// collapseDeviceType attribute").
void acc::LoopOp::setCollapseForDeviceTypes(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
    llvm::APInt value) {

  // The two attributes are parallel arrays: both present or both absent.
  assert((getCollapseAttr() == nullptr) ==
         (getCollapseDeviceTypeAttr() == nullptr));
  assert(value.getBitWidth() == 64);

  if (getCollapseAttr()) {
    for (const auto &existing :
         llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
      newValues.push_back(std::get<0>(existing));
      newDeviceTypes.push_back(std::get<1>(existing));
    }
  }

  if (effectiveDeviceTypes.empty()) {
    // If the effective device-types list is empty, this is before there are any
    // being applied by device_type, so this should be added as a 'none'.
    newValues.push_back(
        mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
    newDeviceTypes.push_back(
        acc::DeviceTypeAttr::get(context, DeviceType::None));
  } else {
    for (DeviceType dt : effectiveDeviceTypes) {
      newValues.push_back(
          mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
      newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
    }
  }

  setCollapseAttr(ArrayAttr::get(context, newValues));
  setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
}
3449
/// Append `values` as tile operands for `effectiveDeviceTypes`. The
/// tileOperandsDeviceType attribute is replaced by the helper's result, and
/// tileOperandsSegments is rewritten from `segments` after the helper has
/// updated it.
void acc::LoopOp::setTileForDeviceTypes(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
    ValueRange values) {
  // Seed `segments` with the existing per-group sizes (if any).
  if (getTileOperandsSegments())
    llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));

  setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
      getTileOperandsMutable(), segments));

  setTileOperandsSegments(segments);
}
3463
3464void acc::LoopOp::addVectorOperand(
3465 MLIRContext *context, mlir::Value newValue,
3466 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3467 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3468 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3469 newValue, getVectorOperandsMutable()));
3470}
3471
3472void acc::LoopOp::addEmptyVector(
3473 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3474 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3475 effectiveDeviceTypes));
3476}
3477
3478void acc::LoopOp::addWorkerNumOperand(
3479 MLIRContext *context, mlir::Value newValue,
3480 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3481 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3482 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3483 newValue, getWorkerNumOperandsMutable()));
3484}
3485
3486void acc::LoopOp::addEmptyWorker(
3487 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3488 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3489 effectiveDeviceTypes));
3490}
3491
3492void acc::LoopOp::addEmptyGang(
3493 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3494 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3495 effectiveDeviceTypes));
3496}
3497
3498bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3499 auto hasDevice = [=](DeviceTypeAttr attr) -> bool {
3500 return attr.getValue() == dt;
3501 };
3502 auto testFromArr = [=](ArrayAttr arr) -> bool {
3503 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3504 };
3505
3506 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3507 return true;
3508 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3509 return true;
3510 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3511 return true;
3512
3513 return false;
3514}
3515
3516bool acc::LoopOp::hasDefaultGangWorkerVector() {
3517 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3518 hasGang() || getGangValue(GangArgType::Num) ||
3519 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3520}
3521
3522acc::LoopParMode
3523acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3524 if (hasSeq(deviceType))
3525 return LoopParMode::loop_seq;
3526 if (hasAuto(deviceType))
3527 return LoopParMode::loop_auto;
3528 if (hasIndependent(deviceType))
3529 return LoopParMode::loop_independent;
3530 if (hasSeq())
3531 return LoopParMode::loop_seq;
3532 if (hasAuto())
3533 return LoopParMode::loop_auto;
3534 assert(hasIndependent() &&
3535 "loop must have default auto, seq, or independent");
3536 return LoopParMode::loop_independent;
3537}
3538
/// Append `values` as gang operands for `effectiveDeviceTypes`.
///
/// The gangOperandsDeviceType and gangOperandsSegments attributes are
/// updated through addDeviceTypeAffectedOperandHelper. Then the contents of
/// `argTypes` (num/dim/static per value) are appended to gangOperandsArgType
/// once for each segment the helper added. LoopOp::verify requires the
/// arg-type count to equal the gang operand count, so `argTypes` presumably
/// has one entry per element of `values`. Confirm at call sites.
void acc::LoopOp::addGangOperands(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
  if (std::optional<ArrayRef<int32_t>> existingSegments =
          getGangOperandsSegments())
    llvm::copy(*existingSegments, std::back_inserter(segments));

  // Remember how many segments existed so we can tell how many were added.
  unsigned beforeCount = segments.size();

  setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
      getGangOperandsMutable(), segments));

  setGangOperandsSegments(segments);

  // This is a bit of extra work to make sure we update the 'types' correctly by
  // adding to the types collection the correct number of times. We could
  // potentially add something similar to the
  // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
  // excessive for a one-off case.
  unsigned numAdded = segments.size() - beforeCount;

  if (numAdded > 0) {
    if (getGangOperandsArgTypeAttr())
      llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));

    for (auto i : llvm::index_range(0u, numAdded)) {
      llvm::transform(argTypes, std::back_inserter(gangTypes),
                      [=](mlir::acc::GangArgType gangTy) {
                        return mlir::acc::GangArgTypeAttr::get(context, gangTy);
                      });
      // The index itself is not needed; this only silences the unused warning.
      (void)i;
    }

    setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
  }
}
3578
/// Add `op`'s result as a private operand of the loop and append a symbol
/// reference to `recipe` to the privatizationRecipes attribute. The two
/// lists are extended together so their positions stay aligned.
void acc::LoopOp::addPrivatization(MLIRContext *context,
                                   mlir::acc::PrivateOp op,
                                   mlir::acc::PrivateRecipeOp recipe) {
  getPrivateOperandsMutable().append(op.getResult());


  // Rebuild the recipe list: existing entries first, then the new recipe.
  if (getPrivatizationRecipesAttr())
    llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));

  recipes.push_back(
      mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
  setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
3593
/// Add `op`'s result as a firstprivate operand of the loop and append a
/// symbol reference to `recipe` to the firstprivatizationRecipes attribute.
/// The two lists are extended together so their positions stay aligned.
void acc::LoopOp::addFirstPrivatization(
    MLIRContext *context, mlir::acc::FirstprivateOp op,
    mlir::acc::FirstprivateRecipeOp recipe) {
  getFirstprivateOperandsMutable().append(op.getResult());


  // Rebuild the recipe list: existing entries first, then the new recipe.
  if (getFirstprivatizationRecipesAttr())
    llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));

  recipes.push_back(
      mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
  setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
3608
/// Add `op`'s result as a reduction operand of the loop and append a symbol
/// reference to `recipe` to the reductionRecipes attribute. The two lists
/// are extended together so their positions stay aligned.
void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
                               mlir::acc::ReductionRecipeOp recipe) {
  getReductionOperandsMutable().append(op.getResult());


  // Rebuild the recipe list: existing entries first, then the new recipe.
  if (getReductionRecipesAttr())
    llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));

  recipes.push_back(
      mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
  setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
3622
3623//===----------------------------------------------------------------------===//
3624// DataOp
3625//===----------------------------------------------------------------------===//
3626
/// Verifies the data construct.
///
/// - Rejects an op that has neither operands nor a `default` attribute.
/// - Requires every data clause operand to be produced by one of the data
///   entry/exit operations listed below or by GetDevicePtrOp; block
///   arguments are rejected explicitly.
LogicalResult acc::DataOp::verify() {
  // 2.6.5. Data Construct restriction
  // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
  // attach, or default clause must appear on a data construct.
  if (getOperands().empty() && !getDefaultAttr())
    return emitError("at least one operand or the default attribute "
                     "must appear on the data operation");

  // Block arguments have no defining op, so they are rejected before the
  // `isa` check on `getDefiningOp()`.
  for (mlir::Value operand : getDataClauseOperands())
    if (isa<BlockArgument>(operand) ||
        !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
                   acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
                   acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
            operand.getDefiningOp()))
      return emitError("expect data entry/exit operation or acc.getdeviceptr "
                       "as defining op");

    return failure();

  return success();
}
3649
3650unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
3651
3652Value DataOp::getDataOperand(unsigned i) {
3653 unsigned numOptional = getIfCond() ? 1 : 0;
3654 numOptional += getAsyncOperands().size() ? 1 : 0;
3655 numOptional += getWaitOperands().size();
3656 return getOperand(numOptional + i);
3657}
3658
3659bool acc::DataOp::hasAsyncOnly() {
3660 return hasAsyncOnly(mlir::acc::DeviceType::None);
3661}
3662
3663bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3664 return hasDeviceType(getAsyncOnly(), deviceType);
3665}
3666
3667mlir::Value DataOp::getAsyncValue() {
3668 return getAsyncValue(mlir::acc::DeviceType::None);
3669}
3670
/// Returns the async operand associated with `deviceType`, selected from the
/// async operands using the `asyncOperandsDeviceType` array.
mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
                                     getAsyncOperands(), deviceType);
}
3675
3676bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
3677
3678bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3679 return hasDeviceType(getWaitOnly(), deviceType);
3680}
3681
3682mlir::Operation::operand_range DataOp::getWaitValues() {
3683 return getWaitValues(mlir::acc::DeviceType::None);
3684}
3685
DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  // Selects the wait values recorded for `deviceType` from the wait operand
  // segments, using the wait device-type array and the per-segment
  // `hasWaitDevnum` flags.
      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
      getHasWaitDevnum(), deviceType);
}
3692
3693mlir::Value DataOp::getWaitDevnum() {
3694 return getWaitDevnum(mlir::acc::DeviceType::None);
3695}
3696
3697mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3698 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3699 getWaitOperandsSegments(), getHasWaitDevnum(),
3700 deviceType);
3701}
3702
3703void acc::DataOp::addAsyncOnly(
3704 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3705 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3706 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3707}
3708
3709void acc::DataOp::addAsyncOperand(
3710 MLIRContext *context, mlir::Value newValue,
3711 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3712 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3713 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3714 getAsyncOperandsMutable()));
3715}
3716
3717void acc::DataOp::addWaitOnly(MLIRContext *context,
3718 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3719 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3720 effectiveDeviceTypes));
3721}
3722
/// Appends the wait values `newValues` for `effectiveDeviceTypes`.
///
/// Updates three pieces of per-segment bookkeeping:
/// - the wait operand segment sizes (`segments` is passed to the helper,
///   presumably so it can extend it — confirm);
/// - the wait device-type array (the helper's result);
/// - the `hasWaitDevnum` flags, with one `hasDevnum` entry per affected
///   device type (at least one).
void acc::DataOp::addWaitOperands(
    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {

  if (getWaitOperandsSegments())
    llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));

  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
      getWaitOperandsMutable(), segments));
  setWaitOperandsSegments(segments);

  if (getHasWaitDevnumAttr())
    llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
  // One flag per added device type. An empty device type list still adds a
  // single entry, for the device_type-less clause.
  hasDevnums.insert(
      hasDevnums.end(),
      std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
      mlir::BoolAttr::get(context, hasDevnum));
  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
}
3745
3746//===----------------------------------------------------------------------===//
3747// ExitDataOp
3748//===----------------------------------------------------------------------===//
3749
3750LogicalResult acc::ExitDataOp::verify() {
3751 // 2.6.6. Data Exit Directive restriction
3752 // At least one copyout, delete, or detach clause must appear on an exit data
3753 // directive.
3754 if (getDataClauseOperands().empty())
3755 return emitError("at least one operand must be present in dataOperands on "
3756 "the exit data operation");
3757
3758 // The async attribute represent the async clause without value. Therefore the
3759 // attribute and operand cannot appear at the same time.
3760 if (getAsyncOperand() && getAsync())
3761 return emitError("async attribute cannot appear with asyncOperand");
3762
3763 // The wait attribute represent the wait clause without values. Therefore the
3764 // attribute and operands cannot appear at the same time.
3765 if (!getWaitOperands().empty() && getWait())
3766 return emitError("wait attribute cannot appear with waitOperands");
3767
3768 if (getWaitDevnum() && getWaitOperands().empty())
3769 return emitError("wait_devnum cannot appear without waitOperands");
3770
3771 return success();
3772}
3773
3774unsigned ExitDataOp::getNumDataOperands() {
3775 return getDataClauseOperands().size();
3776}
3777
3778Value ExitDataOp::getDataOperand(unsigned i) {
3779 unsigned numOptional = getIfCond() ? 1 : 0;
3780 numOptional += getAsyncOperand() ? 1 : 0;
3781 numOptional += getWaitDevnum() ? 1 : 0;
3782 return getOperand(getWaitOperands().size() + numOptional + i);
3783}
3784
3785void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3786 MLIRContext *context) {
3787 results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
3788}
3789
3790void ExitDataOp::addAsyncOnly(MLIRContext *context,
3791 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3792 assert(effectiveDeviceTypes.empty());
3793 assert(!getAsyncAttr());
3794 assert(!getAsyncOperand());
3795
3796 setAsyncAttr(mlir::UnitAttr::get(context));
3797}
3798
3799void ExitDataOp::addAsyncOperand(
3800 MLIRContext *context, mlir::Value newValue,
3801 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3802 assert(effectiveDeviceTypes.empty());
3803 assert(!getAsyncAttr());
3804 assert(!getAsyncOperand());
3805
3806 getAsyncOperandMutable().append(newValue);
3807}
3808
3809void ExitDataOp::addWaitOnly(MLIRContext *context,
3810 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3811 assert(effectiveDeviceTypes.empty());
3812 assert(!getWaitAttr());
3813 assert(getWaitOperands().empty());
3814 assert(!getWaitDevnum());
3815
3816 setWaitAttr(mlir::UnitAttr::get(context));
3817}
3818
3819void ExitDataOp::addWaitOperands(
3820 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3821 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3822 assert(effectiveDeviceTypes.empty());
3823 assert(!getWaitAttr());
3824 assert(getWaitOperands().empty());
3825 assert(!getWaitDevnum());
3826
3827 // if hasDevnum, the first value is the devnum. The 'rest' go into the
3828 // operands list.
3829 if (hasDevnum) {
3830 getWaitDevnumMutable().append(newValues.front());
3831 newValues = newValues.drop_front();
3832 }
3833
3834 getWaitOperandsMutable().append(newValues);
3835}
3836
3837//===----------------------------------------------------------------------===//
3838// EnterDataOp
3839//===----------------------------------------------------------------------===//
3840
3841LogicalResult acc::EnterDataOp::verify() {
3842 // 2.6.6. Data Enter Directive restriction
3843 // At least one copyin, create, or attach clause must appear on an enter data
3844 // directive.
3845 if (getDataClauseOperands().empty())
3846 return emitError("at least one operand must be present in dataOperands on "
3847 "the enter data operation");
3848
3849 // The async attribute represent the async clause without value. Therefore the
3850 // attribute and operand cannot appear at the same time.
3851 if (getAsyncOperand() && getAsync())
3852 return emitError("async attribute cannot appear with asyncOperand");
3853
3854 // The wait attribute represent the wait clause without values. Therefore the
3855 // attribute and operands cannot appear at the same time.
3856 if (!getWaitOperands().empty() && getWait())
3857 return emitError("wait attribute cannot appear with waitOperands");
3858
3859 if (getWaitDevnum() && getWaitOperands().empty())
3860 return emitError("wait_devnum cannot appear without waitOperands");
3861
3862 for (mlir::Value operand : getDataClauseOperands())
3863 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3864 operand.getDefiningOp()))
3865 return emitError("expect data entry operation as defining op");
3866
3867 return success();
3868}
3869
3870unsigned EnterDataOp::getNumDataOperands() {
3871 return getDataClauseOperands().size();
3872}
3873
3874Value EnterDataOp::getDataOperand(unsigned i) {
3875 unsigned numOptional = getIfCond() ? 1 : 0;
3876 numOptional += getAsyncOperand() ? 1 : 0;
3877 numOptional += getWaitDevnum() ? 1 : 0;
3878 return getOperand(getWaitOperands().size() + numOptional + i);
3879}
3880
3881void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3882 MLIRContext *context) {
3883 results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
3884}
3885
3886void EnterDataOp::addAsyncOnly(
3887 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3888 assert(effectiveDeviceTypes.empty());
3889 assert(!getAsyncAttr());
3890 assert(!getAsyncOperand());
3891
3892 setAsyncAttr(mlir::UnitAttr::get(context));
3893}
3894
3895void EnterDataOp::addAsyncOperand(
3896 MLIRContext *context, mlir::Value newValue,
3897 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3898 assert(effectiveDeviceTypes.empty());
3899 assert(!getAsyncAttr());
3900 assert(!getAsyncOperand());
3901
3902 getAsyncOperandMutable().append(newValue);
3903}
3904
3905void EnterDataOp::addWaitOnly(MLIRContext *context,
3906 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3907 assert(effectiveDeviceTypes.empty());
3908 assert(!getWaitAttr());
3909 assert(getWaitOperands().empty());
3910 assert(!getWaitDevnum());
3911
3912 setWaitAttr(mlir::UnitAttr::get(context));
3913}
3914
3915void EnterDataOp::addWaitOperands(
3916 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3917 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3918 assert(effectiveDeviceTypes.empty());
3919 assert(!getWaitAttr());
3920 assert(getWaitOperands().empty());
3921 assert(!getWaitDevnum());
3922
3923 // if hasDevnum, the first value is the devnum. The 'rest' go into the
3924 // operands list.
3925 if (hasDevnum) {
3926 getWaitDevnumMutable().append(newValues.front());
3927 newValues = newValues.drop_front();
3928 }
3929
3930 getWaitOperandsMutable().append(newValues);
3931}
3932
3933//===----------------------------------------------------------------------===//
3934// AtomicReadOp
3935//===----------------------------------------------------------------------===//
3936
3937LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
3938
3939//===----------------------------------------------------------------------===//
3940// AtomicWriteOp
3941//===----------------------------------------------------------------------===//
3942
3943LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
3944
3945//===----------------------------------------------------------------------===//
3946// AtomicUpdateOp
3947//===----------------------------------------------------------------------===//
3948
3949LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3950 PatternRewriter &rewriter) {
3951 if (op.isNoOp()) {
3952 rewriter.eraseOp(op);
3953 return success();
3954 }
3955
3956 if (Value writeVal = op.getWriteOpVal()) {
3957 rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
3958 op.getIfCond());
3959 return success();
3960 }
3961
3962 return failure();
3963}
3964
3965LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
3966
3967LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3968
3969//===----------------------------------------------------------------------===//
3970// AtomicCaptureOp
3971//===----------------------------------------------------------------------===//
3972
3973AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3974 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3975 return op;
3976 return dyn_cast<AtomicReadOp>(getSecondOp());
3977}
3978
3979AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3980 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3981 return op;
3982 return dyn_cast<AtomicWriteOp>(getSecondOp());
3983}
3984
3985AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3986 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3987 return op;
3988 return dyn_cast<AtomicUpdateOp>(getSecondOp());
3989}
3990
3991LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
3992
3993//===----------------------------------------------------------------------===//
3994// DeclareEnterOp
3995//===----------------------------------------------------------------------===//
3996
/// Shared operand verifier for DeclareEnterOp, DeclareExitOp and DeclareOp.
///
/// Fails if `operands` is empty while `requireAtLeastOneOperand` is true
/// (the default). Each operand must be produced by one of the declare data
/// entry operations listed below or by GetDevicePtrOp; block arguments are
/// rejected explicitly.
template <typename Op>
static LogicalResult
                     bool requireAtLeastOneOperand = true) {
  if (operands.empty() && requireAtLeastOneOperand)
    return emitError(
        op->getLoc(),
        "at least one operand must appear on the declare operation");

  for (mlir::Value operand : operands) {
    // Block arguments have no defining op, so they are rejected before the
    // `isa` check on `getDefiningOp()`.
    if (isa<BlockArgument>(operand) ||
        !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
                   acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
                   acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
            operand.getDefiningOp()))
      return op.emitError(
          "expect valid declare data entry operation or acc.getdeviceptr "
          "as defining op");

    // Internal invariants of the accepted defining ops. These are checked
    // only in assert-enabled builds; the `(void)` casts silence
    // unused-variable warnings otherwise.
    mlir::Value var{getVar(operand.getDefiningOp())};
    assert(var && "declare operands can only be data entry operations which "
                  "must have var");
    (void)var;
    std::optional<mlir::acc::DataClause> dataClauseOptional{
        getDataClause(operand.getDefiningOp())};
    assert(dataClauseOptional.has_value() &&
           "declare operands can only be data entry operations which must have "
           "dataClause");
    (void)dataClauseOptional;
  }

  return success();
}
4030
4031LogicalResult acc::DeclareEnterOp::verify() {
4032 return checkDeclareOperands(*this, this->getDataClauseOperands());
4033}
4034
4035//===----------------------------------------------------------------------===//
4036// DeclareExitOp
4037//===----------------------------------------------------------------------===//
4038
4039LogicalResult acc::DeclareExitOp::verify() {
4040 if (getToken())
4041 return checkDeclareOperands(*this, this->getDataClauseOperands(),
4042 /*requireAtLeastOneOperand=*/false);
4043 return checkDeclareOperands(*this, this->getDataClauseOperands());
4044}
4045
4046//===----------------------------------------------------------------------===//
4047// DeclareOp
4048//===----------------------------------------------------------------------===//
4049
4050LogicalResult acc::DeclareOp::verify() {
4051 return checkDeclareOperands(*this, this->getDataClauseOperands());
4052}
4053
4054//===----------------------------------------------------------------------===//
4055// RoutineOp
4056//===----------------------------------------------------------------------===//
4057
4058static unsigned getParallelismForDeviceType(acc::RoutineOp op,
4059 acc::DeviceType dtype) {
4060 unsigned parallelism = 0;
4061 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4062 parallelism += op.hasWorker(dtype) ? 1 : 0;
4063 parallelism += op.hasVector(dtype) ? 1 : 0;
4064 parallelism += op.hasSeq(dtype) ? 1 : 0;
4065 return parallelism;
4066}
4067
4068LogicalResult acc::RoutineOp::verify() {
4069 unsigned baseParallelism =
4070 getParallelismForDeviceType(*this, acc::DeviceType::None);
4071
4072 if (baseParallelism > 1)
4073 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4074 "be present at the same time";
4075
4076 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4077 ++dtypeInt) {
4078 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
4079 if (dtype == acc::DeviceType::None)
4080 continue;
4081 unsigned parallelism = getParallelismForDeviceType(*this, dtype);
4082
4083 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4084 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4085 "be present at the same time";
4086 }
4087
4088 return success();
4089}
4090
4091static ParseResult parseBindName(OpAsmParser &parser,
4092 mlir::ArrayAttr &bindIdName,
4093 mlir::ArrayAttr &bindStrName,
4094 mlir::ArrayAttr &deviceIdTypes,
4095 mlir::ArrayAttr &deviceStrTypes) {
4096 llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
4097 llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
4098 llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
4099 llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
4100
4101 if (failed(parser.parseCommaSeparatedList([&]() {
4102 mlir::Attribute newAttr;
4103 bool isSymbolRefAttr;
4104 auto parseResult = parser.parseAttribute(newAttr);
4105 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4106 bindIdNameAttrs.push_back(symbolRefAttr);
4107 isSymbolRefAttr = true;
4108 } else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4109 bindStrNameAttrs.push_back(stringAttr);
4110 isSymbolRefAttr = false;
4111 }
4112 if (parseResult)
4113 return failure();
4114 if (failed(parser.parseOptionalLSquare())) {
4115 if (isSymbolRefAttr) {
4116 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4117 parser.getContext(), mlir::acc::DeviceType::None));
4118 } else {
4119 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4120 parser.getContext(), mlir::acc::DeviceType::None));
4121 }
4122 } else {
4123 if (isSymbolRefAttr) {
4124 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4125 parser.parseRSquare())
4126 return failure();
4127 } else {
4128 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4129 parser.parseRSquare())
4130 return failure();
4131 }
4132 }
4133 return success();
4134 })))
4135 return failure();
4136
4137 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4138 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4139 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4140 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4141
4142 return success();
4143}
4144
                          std::optional<mlir::ArrayAttr> bindIdName,
                          std::optional<mlir::ArrayAttr> bindStrName,
                          std::optional<mlir::ArrayAttr> deviceIdTypes,
                          std::optional<mlir::ArrayAttr> deviceStrTypes) {
  // Prints the bind clause payload as `name [device_type]` pairs: all
  // symbol-reference names first, then all string names. The original
  // interleaving of the two kinds is not preserved, but parseBindName splits
  // them back into the same two arrays.
  // Create combined vectors for all bind names and device types

  // Append bindIdName and deviceIdTypes. The name array is dereferenced
  // whenever its device-type array has values, so the two are expected to be
  // set together.
  if (hasDeviceTypeValues(deviceIdTypes)) {
    allBindNames.append(bindIdName->begin(), bindIdName->end());
    allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
  }

  // Append bindStrName and deviceStrTypes
  if (hasDeviceTypeValues(deviceStrTypes)) {
    allBindNames.append(bindStrName->begin(), bindStrName->end());
    allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
  }

  // Print the combined sequence
  if (!allBindNames.empty())
    llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
                          [&](const auto &pair) {
                            p << std::get<0>(pair);
                            printSingleDeviceType(p, std::get<1>(pair));
                          });
}
4174
4175static ParseResult parseRoutineGangClause(OpAsmParser &parser,
4176 mlir::ArrayAttr &gang,
4177 mlir::ArrayAttr &gangDim,
4178 mlir::ArrayAttr &gangDimDeviceTypes) {
4179
4180 llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
4181 gangDimDeviceTypeAttrs;
4182 bool needCommaBeforeOperands = false;
4183
4184 // Gang keyword only
4185 if (failed(parser.parseOptionalLParen())) {
4186 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4187 parser.getContext(), mlir::acc::DeviceType::None));
4188 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4189 return success();
4190 }
4191
4192 // Parse keyword only attributes
4193 if (succeeded(parser.parseOptionalLSquare())) {
4194 if (failed(parser.parseCommaSeparatedList([&]() {
4195 if (parser.parseAttribute(gangAttrs.emplace_back()))
4196 return failure();
4197 return success();
4198 })))
4199 return failure();
4200 if (parser.parseRSquare())
4201 return failure();
4202 needCommaBeforeOperands = true;
4203 }
4204
4205 if (needCommaBeforeOperands && failed(parser.parseComma()))
4206 return failure();
4207
4208 if (failed(parser.parseCommaSeparatedList([&]() {
4209 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4210 parser.parseColon() ||
4211 parser.parseAttribute(gangDimAttrs.emplace_back()))
4212 return failure();
4213 if (succeeded(parser.parseOptionalLSquare())) {
4214 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4215 parser.parseRSquare())
4216 return failure();
4217 } else {
4218 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4219 parser.getContext(), mlir::acc::DeviceType::None));
4220 }
4221 return success();
4222 })))
4223 return failure();
4224
4225 if (failed(parser.parseRParen()))
4226 return failure();
4227
4228 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4229 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4230 gangDimDeviceTypes =
4231 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4232
4233 return success();
4234}
4235
                                   std::optional<mlir::ArrayAttr> gang,
                                   std::optional<mlir::ArrayAttr> gangDim,
                                   std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
  // Prints the `gang` clause payload. A lone `none` gang entry with no gang
  // dims prints nothing, leaving the bare `gang` keyword form.

  if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
      gang->size() == 1) {
    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
      return;
  }

  p << "(";

  // Presumably prints the `[dt, ...]` list when `gang` has values — see
  // printDeviceTypes.
  printDeviceTypes(p, gang);

  if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
    p << ", ";

  // `dim: <value> <device_type>` for each gang dim entry.
  if (hasDeviceTypeValues(gangDimDeviceTypes))
    llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
                          [&](const auto &pair) {
                            p << acc::RoutineOp::getGangDimKeyword() << ": ";
                            p << std::get<0>(pair);
                            printSingleDeviceType(p, std::get<1>(pair));
                          });

  p << ")";
}
4265
/// Parses an optional `([dt, ...])` suffix into `deviceTypes`. When no `(`
/// follows, the clause applies to device type `none` only.
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
                                            mlir::ArrayAttr &deviceTypes) {
  // Keyword only
  if (failed(parser.parseOptionalLParen())) {
    attributes.push_back(mlir::acc::DeviceTypeAttr::get(
        parser.getContext(), mlir::acc::DeviceType::None));
    deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
    return success();
  }

  // Parse device type attributes
  // NOTE(review): if `(` was consumed but no `[` follows, the closing `)` is
  // never parsed and an empty array is returned; confirm whether that input
  // should be rejected instead.
  if (succeeded(parser.parseOptionalLSquare())) {
    if (failed(parser.parseCommaSeparatedList([&]() {
          if (parser.parseAttribute(attributes.emplace_back()))
            return failure();
          return success();
        })))
      return failure();
    if (parser.parseRSquare() || parser.parseRParen())
      return failure();
  }
  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
  return success();
}
4291
/// Prints `deviceTypes` as `([dt, ...])`. Prints nothing when the array is
/// absent or empty, or when it holds only `none`; the parser reads that
/// keyword-only form back as `none`.
static void
                         std::optional<mlir::ArrayAttr> deviceTypes) {

  if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
    auto deviceTypeAttr =
        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
      return;
  }

  if (!hasDeviceTypeValues(deviceTypes))
    return;

  p << "([";
  llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
    auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
    p << dTypeAttr;
  });
  p << "])";
}
4313
4314bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
4315
4316bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4317 return hasDeviceType(getWorker(), deviceType);
4318}
4319
4320bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
4321
4322bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4323 return hasDeviceType(getVector(), deviceType);
4324}
4325
4326bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
4327
4328bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4329 return hasDeviceType(getSeq(), deviceType);
4330}
4331
4332std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4333RoutineOp::getBindNameValue() {
4334 return getBindNameValue(mlir::acc::DeviceType::None);
4335}
4336
4337std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4338RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4339 if (!hasDeviceTypeValues(getBindIdNameDeviceType()) &&
4340 !hasDeviceTypeValues(getBindStrNameDeviceType())) {
4341 return std::nullopt;
4342 }
4343
4344 if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
4345 auto attr = (*getBindIdName())[*pos];
4346 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4347 assert(symbolRefAttr && "expected SymbolRef");
4348 return symbolRefAttr;
4349 }
4350
4351 if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
4352 auto attr = (*getBindStrName())[*pos];
4353 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4354 assert(stringAttr && "expected String");
4355 return stringAttr;
4356 }
4357
4358 return std::nullopt;
4359}
4360
4361bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
4362
4363bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4364 return hasDeviceType(getGang(), deviceType);
4365}
4366
4367std::optional<int64_t> RoutineOp::getGangDimValue() {
4368 return getGangDimValue(mlir::acc::DeviceType::None);
4369}
4370
4371std::optional<int64_t>
4372RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4373 if (!hasDeviceTypeValues(getGangDimDeviceType()))
4374 return std::nullopt;
4375 if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
4376 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4377 return intAttr.getInt();
4378 }
4379 return std::nullopt;
4380}
4381
4382//===----------------------------------------------------------------------===//
4383// InitOp
4384//===----------------------------------------------------------------------===//
4385
4386LogicalResult acc::InitOp::verify() {
4387 Operation *currOp = *this;
4388 while ((currOp = currOp->getParentOp()))
4389 if (isComputeOperation(currOp))
4390 return emitOpError("cannot be nested in a compute operation");
4391 return success();
4392}
4393
/// Appends `deviceType` to this op's `device_types` array, keeping any
/// entries already present.
void acc::InitOp::addDeviceType(MLIRContext *context,
                                mlir::acc::DeviceType deviceType) {
  if (getDeviceTypesAttr())
    llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));

  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
  setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
}
4403
4404//===----------------------------------------------------------------------===//
4405// ShutdownOp
4406//===----------------------------------------------------------------------===//
4407
4408LogicalResult acc::ShutdownOp::verify() {
4409 Operation *currOp = *this;
4410 while ((currOp = currOp->getParentOp()))
4411 if (isComputeOperation(currOp))
4412 return emitOpError("cannot be nested in a compute operation");
4413 return success();
4414}
4415
/// Appends `deviceType` to this op's `device_types` array, keeping any
/// entries already present.
void acc::ShutdownOp::addDeviceType(MLIRContext *context,
                                    mlir::acc::DeviceType deviceType) {
  if (getDeviceTypesAttr())
    llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));

  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
  setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
}
4425
4426//===----------------------------------------------------------------------===//
4427// SetOp
4428//===----------------------------------------------------------------------===//
4429
4430LogicalResult acc::SetOp::verify() {
4431 Operation *currOp = *this;
4432 while ((currOp = currOp->getParentOp()))
4433 if (isComputeOperation(currOp))
4434 return emitOpError("cannot be nested in a compute operation");
4435 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4436 return emitOpError("at least one default_async, device_num, or device_type "
4437 "operand must appear");
4438 return success();
4439}
4440
4441//===----------------------------------------------------------------------===//
4442// UpdateOp
4443//===----------------------------------------------------------------------===//
4444
/// Verifies the update op.
///
/// Requires at least one data operand and runs the async and wait
/// device-type consistency checks. Every data operand must be produced by
/// UpdateDeviceOp, UpdateHostOp or GetDevicePtrOp.
LogicalResult acc::UpdateOp::verify() {
  // At least one of host or device should have a value.
  if (getDataClauseOperands().empty())
    return emitError("at least one value must be present in dataOperands");

                                 getAsyncOperandsDeviceTypeAttr(),
                                 "async")))
    return failure();

          *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
          getWaitOperandsDeviceTypeAttr(), "wait")))
    return failure();

    return failure();

  // NOTE(review): unlike DataOp::verify, block-argument operands are not
  // filtered out first, so `getDefiningOp()` can be null here and `isa` on a
  // null pointer asserts; consider `isa_and_nonnull`.
  for (mlir::Value operand : getDataClauseOperands())
    if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
            operand.getDefiningOp()))
      return emitError("expect data entry/exit operation or acc.getdeviceptr "
                       "as defining op");

  return success();
}
4471
4472unsigned UpdateOp::getNumDataOperands() {
4473 return getDataClauseOperands().size();
4474}
4475
4476Value UpdateOp::getDataOperand(unsigned i) {
4477 unsigned numOptional = getAsyncOperands().size();
4478 numOptional += getIfCond() ? 1 : 0;
4479 return getOperand(getWaitOperands().size() + numOptional + i);
4480}
4481
4482void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
4483 MLIRContext *context) {
4484 results.add<RemoveConstantIfCondition<UpdateOp>>(context);
4485}
4486
4487bool UpdateOp::hasAsyncOnly() {
4488 return hasAsyncOnly(mlir::acc::DeviceType::None);
4489}
4490
4491bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4492 return hasDeviceType(getAsyncOnly(), deviceType);
4493}
4494
4495mlir::Value UpdateOp::getAsyncValue() {
4496 return getAsyncValue(mlir::acc::DeviceType::None);
4497}
4498
/// Returns the async operand recorded for `deviceType`, or a null Value when
/// none is recorded.
mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
    return {};

  // Async operands are indexed in step with the async device-type array.
  if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
    return getAsyncOperands()[*pos];

  return {};
}
4508
4509bool UpdateOp::hasWaitOnly() {
4510 return hasWaitOnly(mlir::acc::DeviceType::None);
4511}
4512
4513bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4514 return hasDeviceType(getWaitOnly(), deviceType);
4515}
4516
4517mlir::Operation::operand_range UpdateOp::getWaitValues() {
4518 return getWaitValues(mlir::acc::DeviceType::None);
4519}
4520
UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  // Selects the wait values recorded for `deviceType` from the wait operand
  // segments, using the wait device-type array and the per-segment
  // `hasWaitDevnum` flags.
      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
      getHasWaitDevnum(), deviceType);
}
4527
4528mlir::Value UpdateOp::getWaitDevnum() {
4529 return getWaitDevnum(mlir::acc::DeviceType::None);
4530}
4531
4532mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4533 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
4534 getWaitOperandsSegments(), getHasWaitDevnum(),
4535 deviceType);
4536}
4537
4538void UpdateOp::addAsyncOnly(MLIRContext *context,
4539 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4540 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4541 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4542}
4543
4544void UpdateOp::addAsyncOperand(
4545 MLIRContext *context, mlir::Value newValue,
4546 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4547 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4548 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4549 getAsyncOperandsMutable()));
4550}
4551
4552void UpdateOp::addWaitOnly(MLIRContext *context,
4553 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4554 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4555 effectiveDeviceTypes));
4556}
4557
4558void UpdateOp::addWaitOperands(
4559 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4560 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4561
4563 if (getWaitOperandsSegments())
4564 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4565
4566 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4567 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4568 getWaitOperandsMutable(), segments));
4569 setWaitOperandsSegments(segments);
4570
4572 if (getHasWaitDevnumAttr())
4573 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4574 hasDevnums.insert(
4575 hasDevnums.end(),
4576 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
4577 mlir::BoolAttr::get(context, hasDevnum));
4578 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4579}
4580
4581//===----------------------------------------------------------------------===//
4582// WaitOp
4583//===----------------------------------------------------------------------===//
4584
4585LogicalResult acc::WaitOp::verify() {
4586 // The async attribute represent the async clause without value. Therefore the
4587 // attribute and operand cannot appear at the same time.
4588 if (getAsyncOperand() && getAsync())
4589 return emitError("async attribute cannot appear with asyncOperand");
4590
4591 if (getWaitDevnum() && getWaitOperands().empty())
4592 return emitError("wait_devnum cannot appear without waitOperands");
4593
4594 return success();
4595}
4596
4597#define GET_OP_CLASSES
4598#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
4599
4600#define GET_ATTRDEF_CLASSES
4601#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
4602
4603#define GET_TYPEDEF_CLASSES
4604#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
4605
4606//===----------------------------------------------------------------------===//
4607// acc dialect utilities
4608//===----------------------------------------------------------------------===//
4609
4612 auto varPtr{llvm::TypeSwitch<mlir::Operation *,
4614 accDataClauseOp)
4615 .Case<ACC_DATA_ENTRY_OPS>(
4616 [&](auto entry) { return entry.getVarPtr(); })
4617 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4618 [&](auto exit) { return exit.getVarPtr(); })
4619 .Default([&](mlir::Operation *) {
4621 })};
4622 return varPtr;
4623}
4624
4626 auto varPtr{
4628 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
4629 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4630 return varPtr;
4631}
4632
4634 auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
4635 .Case<ACC_DATA_ENTRY_OPS>(
4636 [&](auto entry) { return entry.getVarType(); })
4637 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4638 [&](auto exit) { return exit.getVarType(); })
4639 .Default([&](mlir::Operation *) { return mlir::Type(); })};
4640 return varType;
4641}
4642
4645 auto accPtr{llvm::TypeSwitch<mlir::Operation *,
4647 accDataClauseOp)
4648 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4649 [&](auto dataClause) { return dataClause.getAccPtr(); })
4650 .Default([&](mlir::Operation *) {
4652 })};
4653 return accPtr;
4654}
4655
4657 auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
4659 [&](auto dataClause) { return dataClause.getAccVar(); })
4660 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4661 return accPtr;
4662}
4663
4665 auto varPtrPtr{
4667 .Case<ACC_DATA_ENTRY_OPS>(
4668 [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
4669 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4670 return varPtrPtr;
4671}
4672
4677 accDataClauseOp)
4678 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4680 dataClause.getBounds().begin(), dataClause.getBounds().end());
4681 })
4682 .Default([&](mlir::Operation *) {
4684 })};
4685 return bounds;
4686}
4687
4691 accDataClauseOp)
4692 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4694 dataClause.getAsyncOperands().begin(),
4695 dataClause.getAsyncOperands().end());
4696 })
4697 .Default([&](mlir::Operation *) {
4699 });
4700}
4701
4702mlir::ArrayAttr
4705 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4706 return dataClause.getAsyncOperandsDeviceTypeAttr();
4707 })
4708 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4709}
4710
4711mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
4714 [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
4715 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4716}
4717
4718std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
4719 auto name{
4721 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
4722 .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
4723 return {};
4724 })};
4725 return name;
4726}
4727
4728std::optional<mlir::acc::DataClause>
4730 auto dataClause{
4732 accDataEntryOp)
4733 .Case<ACC_DATA_ENTRY_OPS>(
4734 [&](auto entry) { return entry.getDataClause(); })
4735 .Default([&](mlir::Operation *) { return std::nullopt; })};
4736 return dataClause;
4737}
4738
4740 auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
4741 .Case<ACC_DATA_ENTRY_OPS>(
4742 [&](auto entry) { return entry.getImplicit(); })
4743 .Default([&](mlir::Operation *) { return false; })};
4744 return implicit;
4745}
4746
4748 auto dataOperands{
4751 [&](auto entry) { return entry.getDataClauseOperands(); })
4752 .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
4753 return dataOperands;
4754}
4755
4758 auto dataOperands{
4761 [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
4762 .Default([&](mlir::Operation *) { return nullptr; })};
4763 return dataOperands;
4764}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition SCF.cpp:136
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
ArrayAttr()
if(!isCopyOut)
b getContext())
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition OpenACC.cpp:4236
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition OpenACC.cpp:984
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition OpenACC.cpp:2979
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition OpenACC.cpp:1568
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
Definition OpenACC.cpp:4091
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition OpenACC.cpp:2993
static bool isComputeOperation(Operation *op)
Definition OpenACC.cpp:998
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:408
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition OpenACC.cpp:2090
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition OpenACC.cpp:539
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:392
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition OpenACC.cpp:508
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition OpenACC.cpp:2101
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition OpenACC.cpp:2006
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition OpenACC.cpp:334
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:4293
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition OpenACC.cpp:2788
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition OpenACC.cpp:2341
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition OpenACC.cpp:3999
static LogicalResult checkVarAndAccVar(Op op)
Definition OpenACC.cpp:468
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2295
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:352
static LogicalResult checkVarAndVarType(Op op)
Definition OpenACC.cpp:450
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
Definition OpenACC.cpp:484
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition OpenACC.cpp:3348
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition OpenACC.cpp:1497
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:2137
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:1650
static LogicalResult checkNoModifier(Op op)
Definition OpenACC.cpp:476
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition OpenACC.cpp:517
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:363
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:376
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:1876
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition OpenACC.cpp:493
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition OpenACC.cpp:3379
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:4266
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition OpenACC.cpp:4175
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:1989
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:2164
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
Definition OpenACC.cpp:2280
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:1943
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
Definition OpenACC.cpp:1481
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2256
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition OpenACC.cpp:580
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition OpenACC.cpp:2807
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition OpenACC.cpp:1265
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
Definition OpenACC.cpp:2325
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition OpenACC.cpp:1920
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
Definition OpenACC.cpp:1512
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition OpenACC.cpp:2237
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:338
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition OpenACC.cpp:2934
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition OpenACC.cpp:2175
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition OpenACC.cpp:551
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition OpenACC.cpp:428
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition OpenACC.cpp:1578
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition OpenACC.cpp:4058
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:1926
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition OpenACC.cpp:2361
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
Definition OpenACC.cpp:1461
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
Definition OpenACC.cpp:4145
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition OpenACC.h:68
#define ACC_DATA_ENTRY_OPS
Definition OpenACC.h:45
#define ACC_DATA_EXIT_OPS
Definition OpenACC.h:53
false
Parses a map_entries map type from a string format back into its numeric value.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:148
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:160
Operation & front()
Definition Block.h:153
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgListType getArguments()
Definition Block.h:87
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
Definition Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperandRange operand_range
Definition Operation.h:371
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:120
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:4656
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:4625
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:4644
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition OpenACC.cpp:4729
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition OpenACC.cpp:4757
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition OpenACC.cpp:4674
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition OpenACC.cpp:4747
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition OpenACC.cpp:4718
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition OpenACC.cpp:4739
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition OpenACC.cpp:4689
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition OpenACC.cpp:4664
static constexpr StringLiteral getVarNameAttrName()
Definition OpenACC.h:180
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:4711
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtains the varType from a data clause operation which records the type of variable.
Definition OpenACC.cpp:4633
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:4611
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:4703
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.