MLIR 22.0.0git
OpenACC.cpp
Go to the documentation of this file.
1//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// =============================================================================
8
15#include "mlir/IR/Builders.h"
19#include "mlir/IR/IRMapping.h"
20#include "mlir/IR/Matchers.h"
22#include "mlir/IR/SymbolTable.h"
23#include "mlir/Support/LLVM.h"
25#include "llvm/ADT/SmallSet.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/LogicalResult.h"
28#include <variant>
29
30using namespace mlir;
31using namespace acc;
32
33#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
36#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
37#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
38
39namespace {
40
41static bool isScalarLikeType(Type type) {
42 return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
43}
44
45/// Helper function to attach the `VarName` attribute to an operation
46/// if a variable name is provided.
47static void attachVarNameAttr(Operation *op, OpBuilder &builder,
48 StringRef varName) {
49 if (!varName.empty()) {
50 auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName);
51 op->setAttr(acc::getVarNameAttrName(), varNameAttr);
52 }
53}
54
55template <typename T>
56struct MemRefPointerLikeModel
57 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
58 Type getElementType(Type pointer) const {
59 return cast<T>(pointer).getElementType();
60 }
61
62 mlir::acc::VariableTypeCategory
63 getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
64 Type varType) const {
65 if (auto mappableTy = dyn_cast<MappableType>(varType)) {
66 return mappableTy.getTypeCategory(varPtr);
67 }
68 auto memrefTy = cast<T>(pointer);
69 if (!memrefTy.hasRank()) {
70 // This memref is unranked - aka it could have any rank, including a
71 // rank of 0 which could mean scalar. For now, return uncategorized.
72 return mlir::acc::VariableTypeCategory::uncategorized;
73 }
74
75 if (memrefTy.getRank() == 0) {
76 if (isScalarLikeType(memrefTy.getElementType())) {
77 return mlir::acc::VariableTypeCategory::scalar;
78 }
79 // Zero-rank non-scalar - need further analysis to determine the type
80 // category. For now, return uncategorized.
81 return mlir::acc::VariableTypeCategory::uncategorized;
82 }
83
84 // It has a rank - must be an array.
85 assert(memrefTy.getRank() > 0 && "rank expected to be positive");
86 return mlir::acc::VariableTypeCategory::array;
87 }
88
89 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
90 StringRef varName, Type varType, Value originalVar,
91 bool &needsFree) const {
92 auto memrefTy = cast<MemRefType>(pointer);
93
94 // Check if this is a static memref (all dimensions are known) - if yes
95 // then we can generate an alloca operation.
96 if (memrefTy.hasStaticShape()) {
97 needsFree = false; // alloca doesn't need deallocation
98 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
99 attachVarNameAttr(allocaOp, builder, varName);
100 return allocaOp.getResult();
101 }
102
103 // For dynamic memrefs, extract sizes from the original variable if
104 // provided. Otherwise they cannot be handled.
105 if (originalVar && originalVar.getType() == memrefTy &&
106 memrefTy.hasRank()) {
107 SmallVector<Value> dynamicSizes;
108 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
109 if (memrefTy.isDynamicDim(i)) {
110 // Extract the size of dimension i from the original variable
111 auto indexValue = arith::ConstantIndexOp::create(builder, loc, i);
112 auto dimSize =
113 memref::DimOp::create(builder, loc, originalVar, indexValue);
114 dynamicSizes.push_back(dimSize);
115 }
116 // Note: We only add dynamic sizes to the dynamicSizes array
117 // Static dimensions are handled automatically by AllocOp
118 }
119 needsFree = true; // alloc needs deallocation
120 auto allocOp =
121 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
122 attachVarNameAttr(allocOp, builder, varName);
123 return allocOp.getResult();
124 }
125
126 // TODO: Unranked not yet supported.
127 return {};
128 }
129
130 bool genFree(Type pointer, OpBuilder &builder, Location loc,
131 TypedValue<PointerLikeType> varToFree, Value allocRes,
132 Type varType) const {
133 if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
134 // Use allocRes if provided to determine the allocation type
135 Value valueToInspect = allocRes ? allocRes : memrefValue;
136
137 // Walk through casts to find the original allocation
138 Value currentValue = valueToInspect;
139 Operation *originalAlloc = nullptr;
140
141 // Follow the chain of operations to find the original allocation
142 // even if a casted result is provided.
143 while (currentValue) {
144 if (auto *definingOp = currentValue.getDefiningOp()) {
145 // Check if this is an allocation operation
146 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
147 originalAlloc = definingOp;
148 break;
149 }
150
151 // Check if this is a cast operation we can look through
152 if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
153 currentValue = castOp.getSource();
154 continue;
155 }
156
157 // Check for other cast-like operations
158 if (auto reinterpretCastOp =
159 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
160 currentValue = reinterpretCastOp.getSource();
161 continue;
162 }
163
164 // If we can't look through this operation, stop
165 break;
166 }
167 // This is a block argument or similar - can't trace further.
168 break;
169 }
170
171 if (originalAlloc) {
172 if (isa<memref::AllocaOp>(originalAlloc)) {
173 // This is an alloca - no dealloc needed, but return true (success)
174 return true;
175 }
176 if (isa<memref::AllocOp>(originalAlloc)) {
177 // This is an alloc - generate dealloc on varToFree
178 memref::DeallocOp::create(builder, loc, memrefValue);
179 return true;
180 }
181 }
182 }
183
184 return false;
185 }
186
187 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
188 TypedValue<PointerLikeType> destination,
189 TypedValue<PointerLikeType> source, Type varType) const {
190 // Generate a copy operation between two memrefs
191 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
192 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
193
194 // As per memref documentation, source and destination must have same
195 // element type and shape in order to be compatible. We do not want to fail
196 // with an IR verification error - thus check that before generating the
197 // copy operation.
198 if (destMemref && srcMemref &&
199 destMemref.getType().getElementType() ==
200 srcMemref.getType().getElementType() &&
201 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
202 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
203 return true;
204 }
205
206 return false;
207 }
208
209 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
211 Type valueType) const {
212 // Load from a memref - only valid for scalar memrefs (rank 0).
213 // This is because the address computation for memrefs is part of the load
214 // (and not computed separately), but the API does not have arguments for
215 // indexing.
216 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
217 if (!memrefValue)
218 return {};
219
220 auto memrefTy = memrefValue.getType();
221
222 // Only load from scalar memrefs (rank 0)
223 if (memrefTy.getRank() != 0)
224 return {};
225
226 return memref::LoadOp::create(builder, loc, memrefValue);
227 }
228
229 bool genStore(Type pointer, OpBuilder &builder, Location loc,
230 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
231 // Store to a memref - only valid for scalar memrefs (rank 0)
232 // This is because the address computation for memrefs is part of the store
233 // (and not computed separately), but the API does not have arguments for
234 // indexing.
235 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
236 if (!memrefValue)
237 return false;
238
239 auto memrefTy = memrefValue.getType();
240
241 // Only store to scalar memrefs (rank 0)
242 if (memrefTy.getRank() != 0)
243 return false;
244
245 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
246 return true;
247 }
248};
249
250struct LLVMPointerPointerLikeModel
251 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
252 LLVM::LLVMPointerType> {
253 Type getElementType(Type pointer) const { return Type(); }
254
255 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
257 Type valueType) const {
258 // For LLVM pointers, we need the valueType to determine what to load
259 if (!valueType)
260 return {};
261
262 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
263 }
264
265 bool genStore(Type pointer, OpBuilder &builder, Location loc,
266 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
267 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
268 return true;
269 }
270};
271
272struct MemrefAddressOfGlobalModel
273 : public AddressOfGlobalOpInterface::ExternalModel<
274 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
275 SymbolRefAttr getSymbol(Operation *op) const {
276 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
277 return getGlobalOp.getNameAttr();
278 }
279};
280
281struct MemrefGlobalVariableModel
282 : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
283 memref::GlobalOp> {
284 bool isConstant(Operation *op) const {
285 auto globalOp = cast<memref::GlobalOp>(op);
286 return globalOp.getConstant();
287 }
288
289 Region *getInitRegion(Operation *op) const {
290 // GlobalOp uses attributes for initialization, not regions
291 return nullptr;
292 }
293};
294
295struct GPULaunchOffloadRegionModel
296 : public acc::OffloadRegionOpInterface::ExternalModel<
297 GPULaunchOffloadRegionModel, gpu::LaunchOp> {};
298
299/// Helper function for any of the times we need to modify an ArrayAttr based on
300/// a device type list. Returns a new ArrayAttr with all of the
301/// existingDeviceTypes, plus the effective new ones(or an added none if hte new
302/// list is empty).
303mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
304 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
305 llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
307 if (existingDeviceTypes)
308 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
309
310 if (newDeviceTypes.empty())
311 deviceTypes.push_back(
312 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
313
314 for (DeviceType dt : newDeviceTypes)
315 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
316
317 return mlir::ArrayAttr::get(context, deviceTypes);
318}
319
320/// Helper function for any of the times we need to add operands that are
321/// affected by a device type list. Returns a new ArrayAttr with all of the
322/// existingDeviceTypes, plus the effective new ones (or an added none, if the
323/// new list is empty). Additionally, adds the arguments to the argCollection
324/// the correct number of times. This will also update a 'segments' array, even
325/// if it won't be used.
326mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
327 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
328 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
329 mlir::MutableOperandRange argCollection,
330 llvm::SmallVector<int32_t> &segments) {
332 if (existingDeviceTypes)
333 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
334
335 if (newDeviceTypes.empty()) {
336 argCollection.append(arguments);
337 segments.push_back(arguments.size());
338 deviceTypes.push_back(
339 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
340 }
341
342 for (DeviceType dt : newDeviceTypes) {
343 argCollection.append(arguments);
344 segments.push_back(arguments.size());
345 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
346 }
347
348 return mlir::ArrayAttr::get(context, deviceTypes);
349}
350
351/// Overload for when the 'segments' aren't needed.
352mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
353 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
354 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
355 mlir::MutableOperandRange argCollection) {
357 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
358 newDeviceTypes, arguments,
359 argCollection, segments);
360}
361} // namespace
362
363//===----------------------------------------------------------------------===//
364// OpenACC operations
365//===----------------------------------------------------------------------===//
366
367void OpenACCDialect::initialize() {
368 addOperations<
369#define GET_OP_LIST
370#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
371 >();
372 addAttributes<
373#define GET_ATTRDEF_LIST
374#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
375 >();
376 addTypes<
377#define GET_TYPEDEF_LIST
378#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
379 >();
380
381 // By attaching interfaces here, we make the OpenACC dialect dependent on
382 // the other dialects. This is probably better than having dialects like LLVM
383 // and memref be dependent on OpenACC.
384 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
385 *getContext());
386 UnrankedMemRefType::attachInterface<
387 MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
388 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
389 *getContext());
390
391 // Attach operation interfaces
392 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
393 *getContext());
394 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
395 gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*getContext());
396}
397
398//===----------------------------------------------------------------------===//
399// RegionBranchOpInterface for acc.kernels / acc.parallel / acc.serial /
400// acc.kernel_environment / acc.data / acc.host_data / acc.loop
401//===----------------------------------------------------------------------===//
402
403/// Generic helper for single-region OpenACC ops that execute their body once
404/// and then return to the parent operation with their results (if any).
405static void
407 RegionBranchPoint point,
409 if (point.isParent()) {
410 regions.push_back(RegionSuccessor(&region));
411 return;
412 }
413
414 regions.push_back(RegionSuccessor(op, op->getResults()));
415}
416
417void KernelsOp::getSuccessorRegions(RegionBranchPoint point,
419 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
420 regions);
421}
422
423void ParallelOp::getSuccessorRegions(
425 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
426 regions);
427}
428
429void SerialOp::getSuccessorRegions(RegionBranchPoint point,
431 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
432 regions);
433}
434
435void KernelEnvironmentOp::getSuccessorRegions(
437 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
438 regions);
439}
440
441void DataOp::getSuccessorRegions(RegionBranchPoint point,
443 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
444 regions);
445}
446
447void HostDataOp::getSuccessorRegions(
449 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
450 regions);
451}
452
453void LoopOp::getSuccessorRegions(RegionBranchPoint point,
455 // Unstructured loops: the body may contain arbitrary CFG and early exits.
456 // At the RegionBranch level, only model entry into the body and exit to the
457 // parent; any backedges are represented inside the region CFG.
458 if (getUnstructured()) {
459 if (point.isParent()) {
460 regions.push_back(RegionSuccessor(&getRegion()));
461 return;
462 }
463 regions.push_back(RegionSuccessor(getOperation(), getResults()));
464 return;
465 }
466
467 // Structured loops: model a loop-shaped region graph similar to scf.for.
468 regions.push_back(RegionSuccessor(&getRegion()));
469 regions.push_back(RegionSuccessor(getOperation(), getResults()));
470}
471
472//===----------------------------------------------------------------------===//
473// RegionBranchTerminatorOpInterface
474//===----------------------------------------------------------------------===//
475
477TerminatorOp::getMutableSuccessorOperands(RegionSuccessor /*point*/) {
478 // `acc.terminator` does not forward operands.
479 return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
480}
481
482//===----------------------------------------------------------------------===//
483// device_type support helpers
484//===----------------------------------------------------------------------===//
485
486static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
487 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
488}
489
490static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
491 mlir::acc::DeviceType deviceType) {
492 if (!hasDeviceTypeValues(arrayAttr))
493 return false;
494
495 for (auto attr : *arrayAttr) {
496 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
497 if (deviceTypeAttr.getValue() == deviceType)
498 return true;
499 }
500
501 return false;
502}
503
505 std::optional<mlir::ArrayAttr> deviceTypes) {
506 if (!hasDeviceTypeValues(deviceTypes))
507 return;
508
509 p << "[";
510 llvm::interleaveComma(*deviceTypes, p,
511 [&](mlir::Attribute attr) { p << attr; });
512 p << "]";
513}
514
515static std::optional<unsigned> findSegment(ArrayAttr segments,
516 mlir::acc::DeviceType deviceType) {
517 unsigned segmentIdx = 0;
518 for (auto attr : segments) {
519 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
520 if (deviceTypeAttr.getValue() == deviceType)
521 return std::make_optional(segmentIdx);
522 ++segmentIdx;
523 }
524 return std::nullopt;
525}
526
528getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
530 std::optional<llvm::ArrayRef<int32_t>> segments,
531 mlir::acc::DeviceType deviceType) {
532 if (!arrayAttr)
533 return range.take_front(0);
534 if (auto pos = findSegment(*arrayAttr, deviceType)) {
535 int32_t nbOperandsBefore = 0;
536 for (unsigned i = 0; i < *pos; ++i)
537 nbOperandsBefore += (*segments)[i];
538 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
539 }
540 return range.take_front(0);
541}
542
543static mlir::Value
544getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
546 std::optional<llvm::ArrayRef<int32_t>> segments,
547 std::optional<mlir::ArrayAttr> hasWaitDevnum,
548 mlir::acc::DeviceType deviceType) {
549 if (!hasDeviceTypeValues(deviceTypeAttr))
550 return {};
551 if (auto pos = findSegment(*deviceTypeAttr, deviceType))
552 if (hasWaitDevnum->getValue()[*pos])
553 return getValuesFromSegments(deviceTypeAttr, operands, segments,
554 deviceType)
555 .front();
556 return {};
557}
558
560getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
562 std::optional<llvm::ArrayRef<int32_t>> segments,
563 std::optional<mlir::ArrayAttr> hasWaitDevnum,
564 mlir::acc::DeviceType deviceType) {
565 auto range =
566 getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
567 if (range.empty())
568 return range;
569 if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
570 if (hasWaitDevnum && *hasWaitDevnum) {
571 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
572 if (boolAttr.getValue())
573 return range.drop_front(1); // first value is devnum
574 }
575 }
576 return range;
577}
578
579template <typename Op>
580static LogicalResult checkWaitAndAsyncConflict(Op op) {
581 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
582 ++dtypeInt) {
583 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
584
585 // The asyncOnly attribute represent the async clause without value.
586 // Therefore the attribute and operand cannot appear at the same time.
587 if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
588 op.hasAsyncOnly(dtype))
589 return op.emitError(
590 "asyncOnly attribute cannot appear with asyncOperand");
591
592 // The wait attribute represent the wait clause without values. Therefore
593 // the attribute and operands cannot appear at the same time.
594 if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
595 op.hasWaitOnly(dtype))
596 return op.emitError("wait attribute cannot appear with waitOperands");
597 }
598 return success();
599}
600
601template <typename Op>
602static LogicalResult checkVarAndVarType(Op op) {
603 if (!op.getVar())
604 return op.emitError("must have var operand");
605
606 // A variable must have a type that is either pointer-like or mappable.
607 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
608 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
609 return op.emitError("var must be mappable or pointer-like");
610
611 // When it is a pointer-like type, the varType must capture the target type.
612 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
613 op.getVarType() == op.getVar().getType())
614 return op.emitError("varType must capture the element type of var");
615
616 return success();
617}
618
619template <typename Op>
620static LogicalResult checkVarAndAccVar(Op op) {
621 if (op.getVar().getType() != op.getAccVar().getType())
622 return op.emitError("input and output types must match");
623
624 return success();
625}
626
627template <typename Op>
628static LogicalResult checkNoModifier(Op op) {
629 if (op.getModifiers() != acc::DataClauseModifier::none)
630 return op.emitError("no data clause modifiers are allowed");
631 return success();
632}
633
634template <typename Op>
635static LogicalResult
636checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
637 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
638 return op.emitError(
639 "invalid data clause modifiers: " +
640 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
641
642 return success();
643}
644
645template <typename OpT, typename RecipeOpT>
646static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) {
647 // Mappable types do not need a recipe because it is possible to generate one
648 // from its API. Reject reductions though because no API is available for them
649 // at this time.
650 if (mlir::acc::isMappableType(op.getVar().getType()) &&
651 !std::is_same_v<OpT, acc::ReductionOp>)
652 return success();
653
654 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
655 if (!operandRecipe)
656 return op->emitOpError() << "recipe expected for " << operandName;
657
658 auto decl =
660 if (!decl)
661 return op->emitOpError()
662 << "expected symbol reference " << operandRecipe << " to point to a "
663 << operandName << " declaration";
664 return success();
665}
666
667static ParseResult parseVar(mlir::OpAsmParser &parser,
669 // Either `var` or `varPtr` keyword is required.
670 if (failed(parser.parseOptionalKeyword("varPtr"))) {
671 if (failed(parser.parseKeyword("var")))
672 return failure();
673 }
674 if (failed(parser.parseLParen()))
675 return failure();
676 if (failed(parser.parseOperand(var)))
677 return failure();
678
679 return success();
680}
681
683 mlir::Value var) {
684 if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
685 p << "varPtr(";
686 else
687 p << "var(";
688 p.printOperand(var);
689}
690
691static ParseResult parseAccVar(mlir::OpAsmParser &parser,
693 mlir::Type &accVarType) {
694 // Either `accVar` or `accPtr` keyword is required.
695 if (failed(parser.parseOptionalKeyword("accPtr"))) {
696 if (failed(parser.parseKeyword("accVar")))
697 return failure();
698 }
699 if (failed(parser.parseLParen()))
700 return failure();
701 if (failed(parser.parseOperand(var)))
702 return failure();
703 if (failed(parser.parseColon()))
704 return failure();
705 if (failed(parser.parseType(accVarType)))
706 return failure();
707 if (failed(parser.parseRParen()))
708 return failure();
709
710 return success();
711}
712
714 mlir::Value accVar, mlir::Type accVarType) {
715 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
716 p << "accPtr(";
717 else
718 p << "accVar(";
719 p.printOperand(accVar);
720 p << " : ";
721 p.printType(accVarType);
722 p << ")";
723}
724
725static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
726 mlir::Type &varPtrType,
727 mlir::TypeAttr &varTypeAttr) {
728 if (failed(parser.parseType(varPtrType)))
729 return failure();
730 if (failed(parser.parseRParen()))
731 return failure();
732
733 if (succeeded(parser.parseOptionalKeyword("varType"))) {
734 if (failed(parser.parseLParen()))
735 return failure();
736 mlir::Type varType;
737 if (failed(parser.parseType(varType)))
738 return failure();
739 varTypeAttr = mlir::TypeAttr::get(varType);
740 if (failed(parser.parseRParen()))
741 return failure();
742 } else {
743 // Set `varType` from the element type of the type of `varPtr`.
744 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
745 varTypeAttr = mlir::TypeAttr::get(
746 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
747 else
748 varTypeAttr = mlir::TypeAttr::get(varPtrType);
749 }
750
751 return success();
752}
753
755 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
756 p.printType(varPtrType);
757 p << ")";
758
759 // Print the `varType` only if it differs from the element type of
760 // `varPtr`'s type.
761 mlir::Type varType = varTypeAttr.getValue();
762 mlir::Type typeToCheckAgainst =
763 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
764 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
765 : varPtrType;
766 if (typeToCheckAgainst != varType) {
767 p << " varType(";
768 p.printType(varType);
769 p << ")";
770 }
771}
772
773static ParseResult parseRecipeSym(mlir::OpAsmParser &parser,
774 mlir::SymbolRefAttr &recipeAttr) {
775 if (failed(parser.parseAttribute(recipeAttr)))
776 return failure();
777 return success();
778}
779
781 mlir::SymbolRefAttr recipeAttr) {
782 p << recipeAttr;
783}
784
785//===----------------------------------------------------------------------===//
786// DataBoundsOp
787//===----------------------------------------------------------------------===//
788LogicalResult acc::DataBoundsOp::verify() {
789 auto extent = getExtent();
790 auto upperbound = getUpperbound();
791 if (!extent && !upperbound)
792 return emitError("expected extent or upperbound.");
793 return success();
794}
795
796//===----------------------------------------------------------------------===//
797// PrivateOp
798//===----------------------------------------------------------------------===//
799LogicalResult acc::PrivateOp::verify() {
800 if (getDataClause() != acc::DataClause::acc_private)
801 return emitError(
802 "data clause associated with private operation must match its intent");
803 if (failed(checkVarAndVarType(*this)))
804 return failure();
805 if (failed(checkNoModifier(*this)))
806 return failure();
807 if (failed(
809 return failure();
810 return success();
811}
812
813//===----------------------------------------------------------------------===//
814// FirstprivateOp
815//===----------------------------------------------------------------------===//
816LogicalResult acc::FirstprivateOp::verify() {
817 if (getDataClause() != acc::DataClause::acc_firstprivate)
818 return emitError("data clause associated with firstprivate operation must "
819 "match its intent");
820 if (failed(checkVarAndVarType(*this)))
821 return failure();
822 if (failed(checkNoModifier(*this)))
823 return failure();
825 *this, "firstprivate")))
826 return failure();
827 return success();
828}
829
830//===----------------------------------------------------------------------===//
831// FirstprivateMapInitialOp
832//===----------------------------------------------------------------------===//
833LogicalResult acc::FirstprivateMapInitialOp::verify() {
834 if (getDataClause() != acc::DataClause::acc_firstprivate)
835 return emitError("data clause associated with firstprivate operation must "
836 "match its intent");
837 if (failed(checkVarAndVarType(*this)))
838 return failure();
839 if (failed(checkNoModifier(*this)))
840 return failure();
841 return success();
842}
843
844//===----------------------------------------------------------------------===//
845// ReductionOp
846//===----------------------------------------------------------------------===//
847LogicalResult acc::ReductionOp::verify() {
848 if (getDataClause() != acc::DataClause::acc_reduction)
849 return emitError("data clause associated with reduction operation must "
850 "match its intent");
851 if (failed(checkVarAndVarType(*this)))
852 return failure();
853 if (failed(checkNoModifier(*this)))
854 return failure();
856 *this, "reduction")))
857 return failure();
858 return success();
859}
860
861//===----------------------------------------------------------------------===//
862// DevicePtrOp
863//===----------------------------------------------------------------------===//
864LogicalResult acc::DevicePtrOp::verify() {
865 if (getDataClause() != acc::DataClause::acc_deviceptr)
866 return emitError("data clause associated with deviceptr operation must "
867 "match its intent");
868 if (failed(checkVarAndVarType(*this)))
869 return failure();
870 if (failed(checkVarAndAccVar(*this)))
871 return failure();
872 if (failed(checkNoModifier(*this)))
873 return failure();
874 return success();
875}
876
877//===----------------------------------------------------------------------===//
878// PresentOp
879//===----------------------------------------------------------------------===//
880LogicalResult acc::PresentOp::verify() {
881 if (getDataClause() != acc::DataClause::acc_present)
882 return emitError(
883 "data clause associated with present operation must match its intent");
884 if (failed(checkVarAndVarType(*this)))
885 return failure();
886 if (failed(checkVarAndAccVar(*this)))
887 return failure();
888 if (failed(checkNoModifier(*this)))
889 return failure();
890 return success();
891}
892
893//===----------------------------------------------------------------------===//
894// CopyinOp
895//===----------------------------------------------------------------------===//
896LogicalResult acc::CopyinOp::verify() {
897 // Test for all clauses this operation can be decomposed from:
898 if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
899 getDataClause() != acc::DataClause::acc_copyin_readonly &&
900 getDataClause() != acc::DataClause::acc_copy &&
901 getDataClause() != acc::DataClause::acc_reduction)
902 return emitError(
903 "data clause associated with copyin operation must match its intent"
904 " or specify original clause this operation was decomposed from");
905 if (failed(checkVarAndVarType(*this)))
906 return failure();
907 if (failed(checkVarAndAccVar(*this)))
908 return failure();
909 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly |
910 acc::DataClauseModifier::always |
911 acc::DataClauseModifier::capture)))
912 return failure();
913 return success();
914}
915
916bool acc::CopyinOp::isCopyinReadonly() {
917 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
918 acc::bitEnumContainsAny(getModifiers(),
919 acc::DataClauseModifier::readonly);
920}
921
922//===----------------------------------------------------------------------===//
923// CreateOp
924//===----------------------------------------------------------------------===//
925LogicalResult acc::CreateOp::verify() {
926 // Test for all clauses this operation can be decomposed from:
927 if (getDataClause() != acc::DataClause::acc_create &&
928 getDataClause() != acc::DataClause::acc_create_zero &&
929 getDataClause() != acc::DataClause::acc_copyout &&
930 getDataClause() != acc::DataClause::acc_copyout_zero)
931 return emitError(
932 "data clause associated with create operation must match its intent"
933 " or specify original clause this operation was decomposed from");
934 if (failed(checkVarAndVarType(*this)))
935 return failure();
936 if (failed(checkVarAndAccVar(*this)))
937 return failure();
938 // this op is the entry part of copyout, so it also needs to allow all
939 // modifiers allowed on copyout.
940 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
941 acc::DataClauseModifier::always |
942 acc::DataClauseModifier::capture)))
943 return failure();
944 return success();
945}
946
947bool acc::CreateOp::isCreateZero() {
948 // The zero modifier is encoded in the data clause.
949 return getDataClause() == acc::DataClause::acc_create_zero ||
950 getDataClause() == acc::DataClause::acc_copyout_zero ||
951 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
952}
953
954//===----------------------------------------------------------------------===//
955// NoCreateOp
956//===----------------------------------------------------------------------===//
957LogicalResult acc::NoCreateOp::verify() {
958 if (getDataClause() != acc::DataClause::acc_no_create)
959 return emitError("data clause associated with no_create operation must "
960 "match its intent");
961 if (failed(checkVarAndVarType(*this)))
962 return failure();
963 if (failed(checkVarAndAccVar(*this)))
964 return failure();
965 if (failed(checkNoModifier(*this)))
966 return failure();
967 return success();
968}
969
970//===----------------------------------------------------------------------===//
971// AttachOp
972//===----------------------------------------------------------------------===//
973LogicalResult acc::AttachOp::verify() {
974 if (getDataClause() != acc::DataClause::acc_attach)
975 return emitError(
976 "data clause associated with attach operation must match its intent");
977 if (failed(checkVarAndVarType(*this)))
978 return failure();
979 if (failed(checkVarAndAccVar(*this)))
980 return failure();
981 if (failed(checkNoModifier(*this)))
982 return failure();
983 return success();
984}
985
986//===----------------------------------------------------------------------===//
987// DeclareDeviceResidentOp
988//===----------------------------------------------------------------------===//
989
990LogicalResult acc::DeclareDeviceResidentOp::verify() {
991 if (getDataClause() != acc::DataClause::acc_declare_device_resident)
992 return emitError("data clause associated with device_resident operation "
993 "must match its intent");
994 if (failed(checkVarAndVarType(*this)))
995 return failure();
996 if (failed(checkVarAndAccVar(*this)))
997 return failure();
998 if (failed(checkNoModifier(*this)))
999 return failure();
1000 return success();
1001}
1002
1003//===----------------------------------------------------------------------===//
1004// DeclareLinkOp
1005//===----------------------------------------------------------------------===//
1006
1007LogicalResult acc::DeclareLinkOp::verify() {
1008 if (getDataClause() != acc::DataClause::acc_declare_link)
1009 return emitError(
1010 "data clause associated with link operation must match its intent");
1011 if (failed(checkVarAndVarType(*this)))
1012 return failure();
1013 if (failed(checkVarAndAccVar(*this)))
1014 return failure();
1015 if (failed(checkNoModifier(*this)))
1016 return failure();
1017 return success();
1018}
1019
1020//===----------------------------------------------------------------------===//
1021// CopyoutOp
1022//===----------------------------------------------------------------------===//
1023LogicalResult acc::CopyoutOp::verify() {
1024 // Test for all clauses this operation can be decomposed from:
1025 if (getDataClause() != acc::DataClause::acc_copyout &&
1026 getDataClause() != acc::DataClause::acc_copyout_zero &&
1027 getDataClause() != acc::DataClause::acc_copy &&
1028 getDataClause() != acc::DataClause::acc_reduction)
1029 return emitError(
1030 "data clause associated with copyout operation must match its intent"
1031 " or specify original clause this operation was decomposed from");
1032 if (!getVar() || !getAccVar())
1033 return emitError("must have both host and device pointers");
1034 if (failed(checkVarAndVarType(*this)))
1035 return failure();
1036 if (failed(checkVarAndAccVar(*this)))
1037 return failure();
1038 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1039 acc::DataClauseModifier::always |
1040 acc::DataClauseModifier::capture)))
1041 return failure();
1042 return success();
1043}
1044
1045bool acc::CopyoutOp::isCopyoutZero() {
1046 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1047 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1048}
1049
1050//===----------------------------------------------------------------------===//
1051// DeleteOp
1052//===----------------------------------------------------------------------===//
1053LogicalResult acc::DeleteOp::verify() {
1054 // Test for all clauses this operation can be decomposed from:
1055 if (getDataClause() != acc::DataClause::acc_delete &&
1056 getDataClause() != acc::DataClause::acc_create &&
1057 getDataClause() != acc::DataClause::acc_create_zero &&
1058 getDataClause() != acc::DataClause::acc_copyin &&
1059 getDataClause() != acc::DataClause::acc_copyin_readonly &&
1060 getDataClause() != acc::DataClause::acc_present &&
1061 getDataClause() != acc::DataClause::acc_no_create &&
1062 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1063 getDataClause() != acc::DataClause::acc_declare_link)
1064 return emitError(
1065 "data clause associated with delete operation must match its intent"
1066 " or specify original clause this operation was decomposed from");
1067 if (!getAccVar())
1068 return emitError("must have device pointer");
1069 // This op is the exit part of copyin and create - thus allow all modifiers
1070 // allowed on either case.
1071 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1072 acc::DataClauseModifier::readonly |
1073 acc::DataClauseModifier::always |
1074 acc::DataClauseModifier::capture)))
1075 return failure();
1076 return success();
1077}
1078
1079//===----------------------------------------------------------------------===//
1080// DetachOp
1081//===----------------------------------------------------------------------===//
1082LogicalResult acc::DetachOp::verify() {
1083 // Test for all clauses this operation can be decomposed from:
1084 if (getDataClause() != acc::DataClause::acc_detach &&
1085 getDataClause() != acc::DataClause::acc_attach)
1086 return emitError(
1087 "data clause associated with detach operation must match its intent"
1088 " or specify original clause this operation was decomposed from");
1089 if (!getAccVar())
1090 return emitError("must have device pointer");
1091 if (failed(checkNoModifier(*this)))
1092 return failure();
1093 return success();
1094}
1095
1096//===----------------------------------------------------------------------===//
1097// HostOp
1098//===----------------------------------------------------------------------===//
1099LogicalResult acc::UpdateHostOp::verify() {
1100 // Test for all clauses this operation can be decomposed from:
1101 if (getDataClause() != acc::DataClause::acc_update_host &&
1102 getDataClause() != acc::DataClause::acc_update_self)
1103 return emitError(
1104 "data clause associated with host operation must match its intent"
1105 " or specify original clause this operation was decomposed from");
1106 if (!getVar() || !getAccVar())
1107 return emitError("must have both host and device pointers");
1108 if (failed(checkVarAndVarType(*this)))
1109 return failure();
1110 if (failed(checkVarAndAccVar(*this)))
1111 return failure();
1112 if (failed(checkNoModifier(*this)))
1113 return failure();
1114 return success();
1115}
1116
1117//===----------------------------------------------------------------------===//
1118// DeviceOp
1119//===----------------------------------------------------------------------===//
1120LogicalResult acc::UpdateDeviceOp::verify() {
1121 // Test for all clauses this operation can be decomposed from:
1122 if (getDataClause() != acc::DataClause::acc_update_device)
1123 return emitError(
1124 "data clause associated with device operation must match its intent"
1125 " or specify original clause this operation was decomposed from");
1126 if (failed(checkVarAndVarType(*this)))
1127 return failure();
1128 if (failed(checkVarAndAccVar(*this)))
1129 return failure();
1130 if (failed(checkNoModifier(*this)))
1131 return failure();
1132 return success();
1133}
1134
1135//===----------------------------------------------------------------------===//
1136// UseDeviceOp
1137//===----------------------------------------------------------------------===//
1138LogicalResult acc::UseDeviceOp::verify() {
1139 // Test for all clauses this operation can be decomposed from:
1140 if (getDataClause() != acc::DataClause::acc_use_device)
1141 return emitError(
1142 "data clause associated with use_device operation must match its intent"
1143 " or specify original clause this operation was decomposed from");
1144 if (failed(checkVarAndVarType(*this)))
1145 return failure();
1146 if (failed(checkVarAndAccVar(*this)))
1147 return failure();
1148 if (failed(checkNoModifier(*this)))
1149 return failure();
1150 return success();
1151}
1152
1153//===----------------------------------------------------------------------===//
1154// CacheOp
1155//===----------------------------------------------------------------------===//
1156LogicalResult acc::CacheOp::verify() {
1157 // Test for all clauses this operation can be decomposed from:
1158 if (getDataClause() != acc::DataClause::acc_cache &&
1159 getDataClause() != acc::DataClause::acc_cache_readonly)
1160 return emitError(
1161 "data clause associated with cache operation must match its intent"
1162 " or specify original clause this operation was decomposed from");
1163 if (failed(checkVarAndVarType(*this)))
1164 return failure();
1165 if (failed(checkVarAndAccVar(*this)))
1166 return failure();
1167 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly)))
1168 return failure();
1169 return success();
1170}
1171
1172bool acc::CacheOp::isCacheReadonly() {
1173 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1174 acc::bitEnumContainsAny(getModifiers(),
1175 acc::DataClauseModifier::readonly);
1176}
1177
1178template <typename StructureOp>
1179static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
1180 unsigned nRegions = 1) {
1181
1183 for (unsigned i = 0; i < nRegions; ++i)
1184 regions.push_back(state.addRegion());
1185
1186 for (Region *region : regions)
1187 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
1188 return failure();
1189
1190 return success();
1191}
1192
1194 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
1195}
1196
1197namespace {
1198/// Pattern to remove operation without region that have constant false `ifCond`
1199/// and remove the condition from the operation if the `ifCond` is a true
1200/// constant.
1201template <typename OpTy>
1202struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
1203 using OpRewritePattern<OpTy>::OpRewritePattern;
1204
1205 LogicalResult matchAndRewrite(OpTy op,
1206 PatternRewriter &rewriter) const override {
1207 // Early return if there is no condition.
1208 Value ifCond = op.getIfCond();
1209 if (!ifCond)
1210 return failure();
1211
1212 IntegerAttr constAttr;
1213 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1214 return failure();
1215 if (constAttr.getInt())
1216 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1217 else
1218 rewriter.eraseOp(op);
1219
1220 return success();
1221 }
1222};
1223
1224/// Replaces the given op with the contents of the given single-block region,
1225/// using the operands of the block terminator to replace operation results.
1226static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
1227 Region &region, ValueRange blockArgs = {}) {
1228 assert(region.hasOneBlock() && "expected single-block region");
1229 Block *block = &region.front();
1230 Operation *terminator = block->getTerminator();
1231 ValueRange results = terminator->getOperands();
1232 rewriter.inlineBlockBefore(block, op, blockArgs);
1233 rewriter.replaceOp(op, results);
1234 rewriter.eraseOp(terminator);
1235}
1236
1237/// Pattern to remove operation with region that have constant false `ifCond`
1238/// and remove the condition from the operation if the `ifCond` is constant
1239/// true.
1240template <typename OpTy>
1241struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
1242 using OpRewritePattern<OpTy>::OpRewritePattern;
1243
1244 LogicalResult matchAndRewrite(OpTy op,
1245 PatternRewriter &rewriter) const override {
1246 // Early return if there is no condition.
1247 Value ifCond = op.getIfCond();
1248 if (!ifCond)
1249 return failure();
1250
1251 IntegerAttr constAttr;
1252 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1253 return failure();
1254 if (constAttr.getInt())
1255 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1256 else
1257 replaceOpWithRegion(rewriter, op, op.getRegion());
1258
1259 return success();
1260 }
1261};
1262
1263/// Remove empty acc.kernel_environment operations. If the operation has wait
1264/// operands, create a acc.wait operation to preserve synchronization.
1265struct RemoveEmptyKernelEnvironment
1266 : public OpRewritePattern<acc::KernelEnvironmentOp> {
1267 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
1268
1269 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
1270 PatternRewriter &rewriter) const override {
1271 assert(op->getNumRegions() == 1 && "expected op to have one region");
1272
1273 Block &block = op.getRegion().front();
1274 if (!block.empty())
1275 return failure();
1276
1277 // Conservatively disable canonicalization of empty acc.kernel_environment
1278 // operations if the wait operands in the kernel_environment cannot be fully
1279 // represented by acc.wait operation.
1280
1281 // Disable canonicalization if device type is not the default
1282 if (auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
1283 for (auto attr : deviceTypeAttr) {
1284 if (auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
1285 if (dtAttr.getValue() != mlir::acc::DeviceType::None)
1286 return failure();
1287 }
1288 }
1289 }
1290
1291 // Disable canonicalization if any wait segment has a devnum
1292 if (auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
1293 for (auto attr : hasDevnumAttr) {
1294 if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
1295 if (boolAttr.getValue())
1296 return failure();
1297 }
1298 }
1299 }
1300
1301 // Disable canonicalization if there are multiple wait segments
1302 if (auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
1303 if (segmentsAttr.size() > 1)
1304 return failure();
1305 }
1306
1307 // Remove empty kernel environment.
1308 // Preserve synchronization by creating acc.wait operation if needed.
1309 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
1310 rewriter.replaceOpWithNewOp<acc::WaitOp>(op, op.getWaitOperands(),
1311 /*asyncOperand=*/Value(),
1312 /*waitDevnum=*/Value(),
1313 /*async=*/nullptr,
1314 /*ifCond=*/Value());
1315 else
1316 rewriter.eraseOp(op);
1317
1318 return success();
1319 }
1320};
1321
1322//===----------------------------------------------------------------------===//
1323// Recipe Region Helpers
1324//===----------------------------------------------------------------------===//
1325
1326/// Create and populate an init region for privatization recipes.
1327/// Returns success if the region is populated, failure otherwise.
1328/// Sets needsFree to indicate if the allocated memory requires deallocation.
1329static LogicalResult createInitRegion(OpBuilder &builder, Location loc,
1330 Region &initRegion, Type varType,
1331 StringRef varName, ValueRange bounds,
1332 bool &needsFree) {
1333 // Create init block with arguments: original value + bounds
1334 SmallVector<Type> argTypes{varType};
1335 SmallVector<Location> argLocs{loc};
1336 for (Value bound : bounds) {
1337 argTypes.push_back(bound.getType());
1338 argLocs.push_back(loc);
1339 }
1340
1341 Block *initBlock = builder.createBlock(&initRegion);
1342 initBlock->addArguments(argTypes, argLocs);
1343 builder.setInsertionPointToStart(initBlock);
1344
1345 Value privatizedValue;
1346
1347 // Get the block argument that represents the original variable
1348 Value blockArgVar = initBlock->getArgument(0);
1349
1350 // Generate init region body based on variable type
1351 if (isa<MappableType>(varType)) {
1352 auto mappableTy = cast<MappableType>(varType);
1353 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1354 privatizedValue = mappableTy.generatePrivateInit(
1355 builder, loc, typedVar, varName, bounds, {}, needsFree);
1356 if (!privatizedValue)
1357 return failure();
1358 } else {
1359 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1360 auto pointerLikeTy = cast<PointerLikeType>(varType);
1361 // Use PointerLikeType's allocation API with the block argument
1362 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1363 blockArgVar, needsFree);
1364 if (!privatizedValue)
1365 return failure();
1366 }
1367
1368 // Add yield operation to init block
1369 acc::YieldOp::create(builder, loc, privatizedValue);
1370
1371 return success();
1372}
1373
1374/// Create and populate a copy region for firstprivate recipes.
1375/// Returns success if the region is populated, failure otherwise.
1376/// TODO: Handle MappableType - it does not yet have a copy API.
1377static LogicalResult createCopyRegion(OpBuilder &builder, Location loc,
1378 Region &copyRegion, Type varType,
1379 ValueRange bounds) {
1380 // Create copy block with arguments: original value + privatized value +
1381 // bounds
1382 SmallVector<Type> copyArgTypes{varType, varType};
1383 SmallVector<Location> copyArgLocs{loc, loc};
1384 for (Value bound : bounds) {
1385 copyArgTypes.push_back(bound.getType());
1386 copyArgLocs.push_back(loc);
1387 }
1388
1389 Block *copyBlock = builder.createBlock(&copyRegion);
1390 copyBlock->addArguments(copyArgTypes, copyArgLocs);
1391 builder.setInsertionPointToStart(copyBlock);
1392
1393 bool isMappable = isa<MappableType>(varType);
1394 bool isPointerLike = isa<PointerLikeType>(varType);
1395 // TODO: Handle MappableType - it does not yet have a copy API.
1396 // Otherwise, for now just fallback to pointer-like behavior.
1397 if (isMappable && !isPointerLike)
1398 return failure();
1399
1400 // Generate copy region body based on variable type
1401 if (isPointerLike) {
1402 auto pointerLikeTy = cast<PointerLikeType>(varType);
1403 Value originalArg = copyBlock->getArgument(0);
1404 Value privatizedArg = copyBlock->getArgument(1);
1405
1406 // Generate copy operation using PointerLikeType interface
1407 if (!pointerLikeTy.genCopy(
1408 builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
1409 cast<TypedValue<PointerLikeType>>(originalArg), varType))
1410 return failure();
1411 }
1412
1413 // Add terminator to copy block
1414 acc::TerminatorOp::create(builder, loc);
1415
1416 return success();
1417}
1418
1419/// Create and populate a destroy region for privatization recipes.
1420/// Returns success if the region is populated, failure otherwise.
1421static LogicalResult createDestroyRegion(OpBuilder &builder, Location loc,
1422 Region &destroyRegion, Type varType,
1423 Value allocRes, ValueRange bounds) {
1424 // Create destroy block with arguments: original value + privatized value +
1425 // bounds
1426 SmallVector<Type> destroyArgTypes{varType, varType};
1427 SmallVector<Location> destroyArgLocs{loc, loc};
1428 for (Value bound : bounds) {
1429 destroyArgTypes.push_back(bound.getType());
1430 destroyArgLocs.push_back(loc);
1431 }
1432
1433 Block *destroyBlock = builder.createBlock(&destroyRegion);
1434 destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
1435 builder.setInsertionPointToStart(destroyBlock);
1436
1437 auto varToFree =
1438 cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
1439 if (isa<MappableType>(varType)) {
1440 auto mappableTy = cast<MappableType>(varType);
1441 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree))
1442 return failure();
1443 } else {
1444 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1445 auto pointerLikeTy = cast<PointerLikeType>(varType);
1446 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1447 return failure();
1448 }
1449
1450 acc::TerminatorOp::create(builder, loc);
1451 return success();
1452}
1453
1454} // namespace
1455
1456//===----------------------------------------------------------------------===//
1457// PrivateRecipeOp
1458//===----------------------------------------------------------------------===//
1459
1461 Operation *op, Region &region, StringRef regionType, StringRef regionName,
1462 Type type, bool verifyYield, bool optional = false) {
1463 if (optional && region.empty())
1464 return success();
1465
1466 if (region.empty())
1467 return op->emitOpError() << "expects non-empty " << regionName << " region";
1468 Block &firstBlock = region.front();
1469 if (firstBlock.getNumArguments() < 1 ||
1470 firstBlock.getArgument(0).getType() != type)
1471 return op->emitOpError() << "expects " << regionName
1472 << " region first "
1473 "argument of the "
1474 << regionType << " type";
1475
1476 if (verifyYield) {
1477 for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
1478 if (yieldOp.getOperands().size() != 1 ||
1479 yieldOp.getOperands().getTypes()[0] != type)
1480 return op->emitOpError() << "expects " << regionName
1481 << " region to "
1482 "yield a value of the "
1483 << regionType << " type";
1484 }
1485 }
1486 return success();
1487}
1488
1489LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1490 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1491 "privatization", "init", getType(),
1492 /*verifyYield=*/false)))
1493 return failure();
1495 *this, getDestroyRegion(), "privatization", "destroy", getType(),
1496 /*verifyYield=*/false, /*optional=*/true)))
1497 return failure();
1498 return success();
1499}
1500
1501std::optional<PrivateRecipeOp>
1502PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1503 StringRef recipeName, Type varType,
1504 StringRef varName, ValueRange bounds) {
1505 // First, validate that we can handle this variable type
1506 bool isMappable = isa<MappableType>(varType);
1507 bool isPointerLike = isa<PointerLikeType>(varType);
1508
1509 // Unsupported type
1510 if (!isMappable && !isPointerLike)
1511 return std::nullopt;
1512
1513 OpBuilder::InsertionGuard guard(builder);
1514
1515 // Create the recipe operation first so regions have proper parent context
1516 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1517
1518 // Populate the init region
1519 bool needsFree = false;
1520 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1521 varName, bounds, needsFree))) {
1522 recipe.erase();
1523 return std::nullopt;
1524 }
1525
1526 // Only create destroy region if the allocation needs deallocation
1527 if (needsFree) {
1528 // Extract the allocated value from the init block's yield operation
1529 auto yieldOp =
1530 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1531 Value allocRes = yieldOp.getOperand(0);
1532
1533 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1534 varType, allocRes, bounds))) {
1535 recipe.erase();
1536 return std::nullopt;
1537 }
1538 }
1539
1540 return recipe;
1541}
1542
1543std::optional<PrivateRecipeOp>
1544PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1545 StringRef recipeName,
1546 FirstprivateRecipeOp firstprivRecipe) {
1547 // Create the private.recipe op with the same type as the firstprivate.recipe.
1548 OpBuilder::InsertionGuard guard(builder);
1549 auto varType = firstprivRecipe.getType();
1550 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1551
1552 // Clone the init region
1553 IRMapping mapping;
1554 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1555
1556 // Clone destroy region if the firstprivate.recipe has one.
1557 if (!firstprivRecipe.getDestroyRegion().empty()) {
1558 IRMapping mapping;
1559 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1560 mapping);
1561 }
1562 return recipe;
1563}
1564
1565//===----------------------------------------------------------------------===//
1566// FirstprivateRecipeOp
1567//===----------------------------------------------------------------------===//
1568
1569LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1570 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1571 "privatization", "init", getType(),
1572 /*verifyYield=*/false)))
1573 return failure();
1574
1575 if (getCopyRegion().empty())
1576 return emitOpError() << "expects non-empty copy region";
1577
1578 Block &firstBlock = getCopyRegion().front();
1579 if (firstBlock.getNumArguments() < 2 ||
1580 firstBlock.getArgument(0).getType() != getType())
1581 return emitOpError() << "expects copy region with two arguments of the "
1582 "privatization type";
1583
1584 if (getDestroyRegion().empty())
1585 return success();
1586
1587 if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
1588 "privatization", "destroy",
1589 getType(), /*verifyYield=*/false)))
1590 return failure();
1591
1592 return success();
1593}
1594
1595std::optional<FirstprivateRecipeOp>
1596FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1597 StringRef recipeName, Type varType,
1598 StringRef varName, ValueRange bounds) {
1599 // First, validate that we can handle this variable type
1600 bool isMappable = isa<MappableType>(varType);
1601 bool isPointerLike = isa<PointerLikeType>(varType);
1602
1603 // Unsupported type
1604 if (!isMappable && !isPointerLike)
1605 return std::nullopt;
1606
1607 OpBuilder::InsertionGuard guard(builder);
1608
1609 // Create the recipe operation first so regions have proper parent context
1610 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1611
1612 // Populate the init region
1613 bool needsFree = false;
1614 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1615 varName, bounds, needsFree))) {
1616 recipe.erase();
1617 return std::nullopt;
1618 }
1619
1620 // Populate the copy region
1621 if (failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1622 bounds))) {
1623 recipe.erase();
1624 return std::nullopt;
1625 }
1626
1627 // Only create destroy region if the allocation needs deallocation
1628 if (needsFree) {
1629 // Extract the allocated value from the init block's yield operation
1630 auto yieldOp =
1631 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1632 Value allocRes = yieldOp.getOperand(0);
1633
1634 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1635 varType, allocRes, bounds))) {
1636 recipe.erase();
1637 return std::nullopt;
1638 }
1639 }
1640
1641 return recipe;
1642}
1643
1644//===----------------------------------------------------------------------===//
1645// ReductionRecipeOp
1646//===----------------------------------------------------------------------===//
1647
1648LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1649 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
1650 "init", getType(),
1651 /*verifyYield=*/false)))
1652 return failure();
1653
1654 if (getCombinerRegion().empty())
1655 return emitOpError() << "expects non-empty combiner region";
1656
1657 Block &reductionBlock = getCombinerRegion().front();
1658 if (reductionBlock.getNumArguments() < 2 ||
1659 reductionBlock.getArgument(0).getType() != getType() ||
1660 reductionBlock.getArgument(1).getType() != getType())
1661 return emitOpError() << "expects combiner region with the first two "
1662 << "arguments of the reduction type";
1663
1664 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1665 if (yieldOp.getOperands().size() != 1 ||
1666 yieldOp.getOperands().getTypes()[0] != getType())
1667 return emitOpError() << "expects combiner region to yield a value "
1668 "of the reduction type";
1669 }
1670
1671 return success();
1672}
1673
1674//===----------------------------------------------------------------------===//
1675// ParallelOp
1676//===----------------------------------------------------------------------===//
1677
1678/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
1679template <typename Op>
1680static LogicalResult checkDataOperands(Op op,
1681 const mlir::ValueRange &operands) {
1682 for (mlir::Value operand : operands)
1683 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1684 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1685 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1686 operand.getDefiningOp()))
1687 return op.emitError(
1688 "expect data entry/exit operation or acc.getdeviceptr "
1689 "as defining op");
1690 return success();
1691}
1692
1693template <typename OpT, typename RecipeOpT>
1694static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp,
1695 const mlir::ValueRange &operands,
1696 llvm::StringRef operandName) {
1698 for (mlir::Value operand : operands) {
1699 if (!mlir::isa<OpT>(operand.getDefiningOp()))
1700 return accConstructOp->emitOpError()
1701 << "expected " << operandName << " as defining op";
1702 if (!set.insert(operand).second)
1703 return accConstructOp->emitOpError()
1704 << operandName << " operand appears more than once";
1705 }
1706 return success();
1707}
1708
1709unsigned ParallelOp::getNumDataOperands() {
1710 return getReductionOperands().size() + getPrivateOperands().size() +
1711 getFirstprivateOperands().size() + getDataClauseOperands().size();
1712}
1713
1714Value ParallelOp::getDataOperand(unsigned i) {
1715 unsigned numOptional = getAsyncOperands().size();
1716 numOptional += getNumGangs().size();
1717 numOptional += getNumWorkers().size();
1718 numOptional += getVectorLength().size();
1719 numOptional += getIfCond() ? 1 : 0;
1720 numOptional += getSelfCond() ? 1 : 0;
1721 return getOperand(getWaitOperands().size() + numOptional + i);
1722}
1723
1724template <typename Op>
1725static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
1726 ArrayAttr deviceTypes,
1727 llvm::StringRef keyword) {
1728 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1729 return op.emitOpError() << keyword << " operands count must match "
1730 << keyword << " device_type count";
1731 return success();
1732}
1733
/// Verifies a segmented, device_type-keyed operand group (e.g. num_gangs,
/// wait). Each segment size must not exceed `maxInSegment` (0 disables that
/// limit). The segment sizes must sum to the operand count, and a device_type
/// array is required whenever operands are present. When the device_type
/// array is present, it must have one entry per segment. Returns failure
/// (after emitting an op error) on the first violated rule.
1734template <typename Op>
1736 Op op, OperandRange operands, DenseI32ArrayAttr segments,
1737 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1738 std::size_t numOperandsInSegments = 0;
1739 std::size_t nbOfSegments = 0;
1740
// Sum the segment sizes and count the segments, enforcing the per-segment cap.
1741 if (segments) {
1742 for (auto segCount : segments.asArrayRef()) {
1743 if (maxInSegment != 0 && segCount > maxInSegment)
1744 return op.emitOpError() << keyword << " expects a maximum of "
1745 << maxInSegment << " values per segment";
1746 numOperandsInSegments += segCount;
1747 ++nbOfSegments;
1748 }
1749 }
1750
// Segments must account for every operand; operands without a device_type
// array are reported with the same diagnostic.
1751 if ((numOperandsInSegments != operands.size()) ||
1752 (!deviceTypes && !operands.empty()))
1753 return op.emitOpError()
1754 << keyword << " operand count does not match count in segments";
1755 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1756 return op.emitOpError()
1757 << keyword << " segment count does not match device_type count";
1758 return success();
1759}
1760
/// Verifies acc.parallel. The checks run in this order:
///  - private, firstprivate and reduction operands via checkPrivateOperands
///    (each must be defined by the matching op and appear only once);
///  - num_gangs and wait operands against their segment and device_type
///    attributes (num_gangs also passes 3, presumably the per-segment maximum);
///  - num_workers and vector_length operand counts against their device_type
///    attributes;
///  - async operands against their device_type attribute;
///  - one further check whose condition is not visible in this copy;
///  - data clause operands via checkDataOperands.
/// Returns failure at the first failing check.
1761LogicalResult acc::ParallelOp::verify() {
1762 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
1763 mlir::acc::PrivateRecipeOp>(
1764 *this, getPrivateOperands(), "private")))
1765 return failure();
1766 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
1767 mlir::acc::FirstprivateRecipeOp>(
1768 *this, getFirstprivateOperands(), "firstprivate")))
1769 return failure();
1770 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
1771 mlir::acc::ReductionRecipeOp>(
1772 *this, getReductionOperands(), "reduction")))
1773 return failure();
1774
1776 *this, getNumGangs(), getNumGangsSegmentsAttr(),
1777 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1778 return failure();
1779
1781 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1782 getWaitOperandsDeviceTypeAttr(), "wait")))
1783 return failure();
1784
1785 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1786 getNumWorkersDeviceTypeAttr(),
1787 "num_workers")))
1788 return failure();
1789
1790 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1791 getVectorLengthDeviceTypeAttr(),
1792 "vector_length")))
1793 return failure();
1794
1796 getAsyncOperandsDeviceTypeAttr(),
1797 "async")))
1798 return failure();
1799
1801 return failure();
1802
1803 return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
1804}
1805
/// Looks up `deviceType` in `arrayAttr` with findSegment and returns the
/// operand of `range` at that position. Returns a null Value when the
/// attribute is absent or the device type is not listed.
1806static mlir::Value
1807getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
1809 mlir::acc::DeviceType deviceType) {
1810 if (!arrayAttr)
1811 return {};
1812 if (auto pos = findSegment(*arrayAttr, deviceType))
1813 return range[*pos];
1814 return {};
1815}
1816
1817bool acc::ParallelOp::hasAsyncOnly() {
1818 return hasAsyncOnly(mlir::acc::DeviceType::None);
1819}
1820
1821bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1822 return hasDeviceType(getAsyncOnly(), deviceType);
1823}
1824
1825mlir::Value acc::ParallelOp::getAsyncValue() {
1826 return getAsyncValue(mlir::acc::DeviceType::None);
1827}
1828
/// Returns the async operand recorded for `deviceType`. The head of the lookup
/// call is not visible in this copy; it presumably resolves the position via
/// the async device_type attribute (like getValueInDeviceTypeSegment) — confirm.
1829mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1831 getAsyncOperands(), deviceType);
1832}
1833
1834mlir::Value acc::ParallelOp::getNumWorkersValue() {
1835 return getNumWorkersValue(mlir::acc::DeviceType::None);
1836}
1837
/// Returns the num_workers operand recorded for `deviceType`, or a null Value
/// when none is recorded (lookup via getValueInDeviceTypeSegment).
1839acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1840 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1841 deviceType);
1842}
1843
1844mlir::Value acc::ParallelOp::getVectorLengthValue() {
1845 return getVectorLengthValue(mlir::acc::DeviceType::None);
1846}
1847
/// Returns the vector_length operand recorded for `deviceType`, or a null
/// Value when none is recorded (lookup via getValueInDeviceTypeSegment).
1849acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1850 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1851 getVectorLength(), deviceType);
1852}
1853
1854mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
1855 return getNumGangsValues(mlir::acc::DeviceType::None);
1856}
1857
/// Returns the num_gangs operands of the segment recorded for `deviceType`,
/// as selected by getValuesFromSegments from the device_type attribute and the
/// segment sizes.
1859ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1860 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1861 getNumGangsSegments(), deviceType);
1862}
1863
1864bool acc::ParallelOp::hasWaitOnly() {
1865 return hasWaitOnly(mlir::acc::DeviceType::None);
1866}
1867
1868bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1869 return hasDeviceType(getWaitOnly(), deviceType);
1870}
1871
1872mlir::Operation::operand_range ParallelOp::getWaitValues() {
1873 return getWaitValues(mlir::acc::DeviceType::None);
1874}
1875
/// Returns the wait operands recorded for `deviceType`. The callee (its name
/// is not visible in this copy) also receives the hasWaitDevnum flags,
/// presumably so a leading devnum operand can be excluded — confirm.
1877ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1879 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1880 getHasWaitDevnum(), deviceType);
1881}
1882
1883mlir::Value ParallelOp::getWaitDevnum() {
1884 return getWaitDevnum(mlir::acc::DeviceType::None);
1885}
1886
1887mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1888 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1889 getWaitOperandsSegments(), getHasWaitDevnum(),
1890 deviceType);
1891}
1892
1893void ParallelOp::build(mlir::OpBuilder &odsBuilder,
1894 mlir::OperationState &odsState,
1895 mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
1896 mlir::ValueRange vectorLength,
1897 mlir::ValueRange asyncOperands,
1898 mlir::ValueRange waitOperands, mlir::Value ifCond,
1899 mlir::Value selfCond, mlir::ValueRange reductionOperands,
1900 mlir::ValueRange gangPrivateOperands,
1901 mlir::ValueRange gangFirstPrivateOperands,
1902 mlir::ValueRange dataClauseOperands) {
1903 ParallelOp::build(
1904 odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
1905 /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
1906 /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
1907 /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
1908 /*numGangsDeviceType=*/nullptr, numWorkers,
1909 /*numWorkersDeviceType=*/nullptr, vectorLength,
1910 /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
1911 /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands,
1912 gangFirstPrivateOperands, dataClauseOperands,
1913 /*defaultAttr=*/nullptr, /*combined=*/nullptr);
1914}
1915
1916void acc::ParallelOp::addNumWorkersOperand(
1917 MLIRContext *context, mlir::Value newValue,
1918 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1919 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1920 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1921 getNumWorkersMutable()));
1922}
1923void acc::ParallelOp::addVectorLengthOperand(
1924 MLIRContext *context, mlir::Value newValue,
1925 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1926 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1927 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1928 getVectorLengthMutable()));
1929}
1930
1931void acc::ParallelOp::addAsyncOnly(
1932 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1933 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1934 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1935}
1936
1937void acc::ParallelOp::addAsyncOperand(
1938 MLIRContext *context, mlir::Value newValue,
1939 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1940 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1941 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1942 getAsyncOperandsMutable()));
1943}
1944
/// Adds `newValues` to the num_gangs operands for `effectiveDeviceTypes`. It
/// copies the existing segment sizes into `segments` and lets
/// addDeviceTypeAffectedOperandHelper append the operands and update
/// `segments`. It then stores the returned device_type attribute and the
/// updated segment sizes.
1945void acc::ParallelOp::addNumGangsOperands(
1946 MLIRContext *context, mlir::ValueRange newValues,
1947 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1949 if (getNumGangsSegments())
1950 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1951
1952 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1953 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1954 getNumGangsMutable(), segments));
1955
1956 setNumGangsSegments(segments);
1957}
1958void acc::ParallelOp::addWaitOnly(
1959 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1960 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1961 effectiveDeviceTypes));
1962}
/// Adds `newValues` as wait operands for `effectiveDeviceTypes`.
/// - Segment sizes and the device_type attribute are updated through
///   addDeviceTypeAffectedOperandHelper (same scheme as num_gangs).
/// - `hasDevnum` is then appended max(effectiveDeviceTypes.size(), 1) times to
///   the hasWaitDevnum array, presumably one flag per segment the helper
///   appended — confirm against the helper.
1963void acc::ParallelOp::addWaitOperands(
1964 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
1965 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1966
1968 if (getWaitOperandsSegments())
1969 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1970
1971 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1972 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1973 getWaitOperandsMutable(), segments));
1974 setWaitOperandsSegments(segments);
1975
// Extend the per-segment devnum flags: keep the existing ones, then append
// one flag per new device type (at least one).
1977 if (getHasWaitDevnumAttr())
1978 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1979 hasDevnums.insert(
1980 hasDevnums.end(),
1981 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
1982 mlir::BoolAttr::get(context, hasDevnum));
1983 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
1984}
1985
1986void acc::ParallelOp::addPrivatization(MLIRContext *context,
1987 mlir::acc::PrivateOp op,
1988 mlir::acc::PrivateRecipeOp recipe) {
1989 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
1990 getPrivateOperandsMutable().append(op.getResult());
1991}
1992
1993void acc::ParallelOp::addFirstPrivatization(
1994 MLIRContext *context, mlir::acc::FirstprivateOp op,
1995 mlir::acc::FirstprivateRecipeOp recipe) {
1996 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
1997 getFirstprivateOperandsMutable().append(op.getResult());
1998}
1999
2000void acc::ParallelOp::addReduction(MLIRContext *context,
2001 mlir::acc::ReductionOp op,
2002 mlir::acc::ReductionRecipeOp recipe) {
2003 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2004 getReductionOperandsMutable().append(op.getResult());
2005}
2006
/// Parses one or more comma-separated groups `{%v : type, ...}`, each
/// optionally followed by `[attr]`. A group without `[attr]` gets
/// DeviceType::None. On success, `deviceTypes` holds one entry per group and
/// `segments` holds each group's operand count.
2007static ParseResult parseNumGangs(
2008 mlir::OpAsmParser &parser,
2010 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2011 mlir::DenseI32ArrayAttr &segments) {
2014
// One iteration per `{...}` group.
2015 do {
2016 if (failed(parser.parseLBrace()))
2017 return failure();
2018
// Remember where this group starts so its size can be recorded as a segment.
2019 int32_t crtOperandsSize = operands.size();
2020 if (failed(parser.parseCommaSeparatedList(
2022 if (parser.parseOperand(operands.emplace_back()) ||
2023 parser.parseColonType(types.emplace_back()))
2024 return failure();
2025 return success();
2026 })))
2027 return failure();
2028 seg.push_back(operands.size() - crtOperandsSize);
2029
2030 if (failed(parser.parseRBrace()))
2031 return failure();
2032
// Optional `[attr]` after the group; default to DeviceType::None.
2033 if (succeeded(parser.parseOptionalLSquare())) {
2034 if (parser.parseAttribute(attributes.emplace_back()) ||
2035 parser.parseRSquare())
2036 return failure();
2037 } else {
2038 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2039 parser.getContext(), mlir::acc::DeviceType::None));
2040 }
2041 } while (succeeded(parser.parseOptionalComma()));
2042
2043 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2044 attributes.end());
2045 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2046 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2047
2048 return success();
2049}
2050
// Prints ` [attr]` unless `attr` is the DeviceType::None device type.
// NOTE(review): the dyn_cast result is used without a null check, so an `attr`
// that is not a DeviceTypeAttr would be dereferenced as null; either use
// cast<> (asserting the invariant) or check before calling getValue().
2052 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2053 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2054 p << " [" << attr << "]";
2055}
2056
// Prints each device_type segment as `{%v : type, ...}` followed by its device
// type (omitted for `none`), with segments separated by commas.
// - `opIdx` walks `operands` across all segments.
// - `deviceTypes` and `segments` are dereferenced unconditionally, so both are
//   assumed to be present.
// - `types` is not used; operand types are printed from the values themselves.
2058 mlir::OperandRange operands, mlir::TypeRange types,
2059 std::optional<mlir::ArrayAttr> deviceTypes,
2060 std::optional<mlir::DenseI32ArrayAttr> segments) {
2061 unsigned opIdx = 0;
2062 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2063 p << "{";
2064 llvm::interleaveComma(
2065 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2066 p << operands[opIdx] << " : " << operands[opIdx].getType();
2067 ++opIdx;
2068 });
2069 p << "}";
2070 printSingleDeviceType(p, it.value());
2071 });
2072}
2073
// Parses one or more comma-separated groups `{%v : type, ...}`, each
// optionally followed by `[attr]`. A group without `[attr]` gets
// DeviceType::None. On success, `deviceTypes` holds one entry per group and
// `segments` holds each group's operand count.
2075 mlir::OpAsmParser &parser,
2077 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2078 mlir::DenseI32ArrayAttr &segments) {
2081
2082 do {
2083 if (failed(parser.parseLBrace()))
2084 return failure();
2085
// Remember where this group starts so its size can be recorded as a segment.
2086 int32_t crtOperandsSize = operands.size();
2087
2088 if (failed(parser.parseCommaSeparatedList(
2090 if (parser.parseOperand(operands.emplace_back()) ||
2091 parser.parseColonType(types.emplace_back()))
2092 return failure();
2093 return success();
2094 })))
2095 return failure();
2096
2097 seg.push_back(operands.size() - crtOperandsSize);
2098
2099 if (failed(parser.parseRBrace()))
2100 return failure();
2101
// Optional `[attr]` after the group; default to DeviceType::None.
2102 if (succeeded(parser.parseOptionalLSquare())) {
2103 if (parser.parseAttribute(attributes.emplace_back()) ||
2104 parser.parseRSquare())
2105 return failure();
2106 } else {
2107 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2108 parser.getContext(), mlir::acc::DeviceType::None));
2109 }
2110 } while (succeeded(parser.parseOptionalComma()));
2111
2112 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2113 attributes.end());
2114 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2115 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2116
2117 return success();
2118}
2119
// Prints each device_type segment as `{%v : type, ...}` followed by its device
// type (omitted for `none`).
// - `deviceTypes` and `segments` are dereferenced unconditionally, so both are
//   assumed to be present.
// - `types` is not used; operand types come from the values.
2122 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2123 std::optional<mlir::DenseI32ArrayAttr> segments) {
2124 unsigned opIdx = 0;
2125 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2126 p << "{";
2127 llvm::interleaveComma(
2128 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2129 p << operands[opIdx] << " : " << operands[opIdx].getType();
2130 ++opIdx;
2131 });
2132 p << "}";
2133 printSingleDeviceType(p, it.value());
2134 });
2135}
2136
/// Parses the wait clause.
/// - Without `(`: the keyword stands alone and `keywordOnly` becomes [none].
/// - Otherwise the form is `(` [`[attr, ...]` `,`] `{[devnum:] %v : type, ...}`
///   [`[attr]`] {`,` group} `)`. It fills `deviceTypes` (one entry per group,
///   defaulting to none), `segments` (operands per group), `hasDevNum` (one
///   BoolAttr per group) and `keywordOnly` (the bracketed list, possibly
///   empty).
/// NOTE(review): after a `[...]` list this parser requires `,` and at least one
/// `{...}` group, while printWaitClause prints no separator when there are no
/// operand groups. If printDeviceTypes emits a bracketed list, a keyword-only
/// clause with device types may not round-trip — confirm.
2137static ParseResult parseWaitClause(
2138 mlir::OpAsmParser &parser,
2140 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2141 mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
2142 mlir::ArrayAttr &keywordOnly) {
2143 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
2145
2146 bool needCommaBeforeOperands = false;
2147
2148 // Keyword only
2149 if (failed(parser.parseOptionalLParen())) {
2150 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2151 parser.getContext(), mlir::acc::DeviceType::None));
2152 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2153 return success();
2154 }
2155
2156 // Parse keyword only attributes
2157 if (succeeded(parser.parseOptionalLSquare())) {
2158 if (failed(parser.parseCommaSeparatedList([&]() {
2159 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2160 return failure();
2161 return success();
2162 })))
2163 return failure();
2164 if (parser.parseRSquare())
2165 return failure();
2166 needCommaBeforeOperands = true;
2167 }
2168
2169 if (needCommaBeforeOperands && failed(parser.parseComma()))
2170 return failure();
2171
// One iteration per `{...}` operand group.
2172 do {
2173 if (failed(parser.parseLBrace()))
2174 return failure();
2175
2176 int32_t crtOperandsSize = operands.size();
2177
// A leading `devnum:` marks the group as carrying a devnum operand.
2178 if (succeeded(parser.parseOptionalKeyword("devnum"))) {
2179 if (failed(parser.parseColon()))
2180 return failure();
2181 devnum.push_back(BoolAttr::get(parser.getContext(), true));
2182 } else {
2183 devnum.push_back(BoolAttr::get(parser.getContext(), false));
2184 }
2185
2186 if (failed(parser.parseCommaSeparatedList(
2188 if (parser.parseOperand(operands.emplace_back()) ||
2189 parser.parseColonType(types.emplace_back()))
2190 return failure();
2191 return success();
2192 })))
2193 return failure();
2194
2195 seg.push_back(operands.size() - crtOperandsSize);
2196
2197 if (failed(parser.parseRBrace()))
2198 return failure();
2199
2200 if (succeeded(parser.parseOptionalLSquare())) {
2201 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2202 parser.parseRSquare())
2203 return failure();
2204 } else {
2205 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2206 parser.getContext(), mlir::acc::DeviceType::None));
2207 }
2208 } while (succeeded(parser.parseOptionalComma()));
2209
2210 if (failed(parser.parseRParen()))
2211 return failure();
2212
2213 deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2214 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2215 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2216 hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
2217
2218 return success();
2219}
2220
2221static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
2222 if (!hasDeviceTypeValues(attrs))
2223 return false;
2224 if (attrs->size() != 1)
2225 return false;
2226 if (auto deviceTypeAttr =
2227 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2228 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2229 return false;
2230}
2231
// Prints the wait clause.
// - Prints nothing when there are no operands and `keywordOnly` is just
//   [none].
// - Otherwise prints `(`, then the keyword-only device types, then `, ` when
//   both lists have values, then each operand group as
//   `{[devnum: ]%v : type, ...}` followed by its device type, then `)`.
// - `hasDevNum` and `segments` are dereferenced whenever `deviceTypes` has
//   values, so they are assumed present in that case.
2233 mlir::OperandRange operands, mlir::TypeRange types,
2234 std::optional<mlir::ArrayAttr> deviceTypes,
2235 std::optional<mlir::DenseI32ArrayAttr> segments,
2236 std::optional<mlir::ArrayAttr> hasDevNum,
2237 std::optional<mlir::ArrayAttr> keywordOnly) {
2238
2239 if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
2240 return;
2241
2242 p << "(";
2243
2244 printDeviceTypes(p, keywordOnly);
2245 if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
2246 p << ", ";
2247
2248 if (hasDeviceTypeValues(deviceTypes)) {
2249 unsigned opIdx = 0;
2250 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2251 p << "{";
2252 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2253 if (boolAttr && boolAttr.getValue())
2254 p << "devnum: ";
2255 llvm::interleaveComma(
2256 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2257 p << operands[opIdx] << " : " << operands[opIdx].getType();
2258 ++opIdx;
2259 });
2260 p << "}";
2261 printSingleDeviceType(p, it.value());
2262 });
2263 }
2264
2265 p << ")";
2266}
2267
/// Parses a comma-separated list of `%v : type` entries, each optionally
/// followed by `[attr]`. Entries without `[attr]` get DeviceType::None.
/// `deviceTypes` receives one entry per operand.
2268static ParseResult parseDeviceTypeOperands(
2269 mlir::OpAsmParser &parser,
2271 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
2273 if (failed(parser.parseCommaSeparatedList([&]() {
2274 if (parser.parseOperand(operands.emplace_back()) ||
2275 parser.parseColonType(types.emplace_back()))
2276 return failure();
2277 if (succeeded(parser.parseOptionalLSquare())) {
2278 if (parser.parseAttribute(attributes.emplace_back()) ||
2279 parser.parseRSquare())
2280 return failure();
2281 } else {
2282 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2283 parser.getContext(), mlir::acc::DeviceType::None));
2284 }
2285 return success();
2286 })))
2287 return failure();
2288 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2289 attributes.end());
2290 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2291 return success();
2292}
2293
/// Prints `%v : type` followed by its device type (omitted for `none`) for each
/// (device type, operand) pair. Prints nothing when `deviceTypes` has no
/// values. llvm::zip stops at the shorter of the two ranges.
2294static void
2296 mlir::OperandRange operands, mlir::TypeRange types,
2297 std::optional<mlir::ArrayAttr> deviceTypes) {
2298 if (!hasDeviceTypeValues(deviceTypes))
2299 return;
2300 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
2301 p << std::get<1>(it) << " : " << std::get<1>(it).getType();
2302 printSingleDeviceType(p, std::get<0>(it));
2303 });
2304}
2305
// Parses a clause that may appear with or without operands.
// - Without `(`: `keywordOnlyDeviceType` becomes [none].
// - Otherwise the form is `(` [`[attr, ...]` `,`] `%v : type [attr]?, ...`
//   `)`, and `deviceTypes` gets one entry per operand (defaulting to none).
// NOTE(review): in the parenthesized form the `[...]` list is collected into
// keywordOnlyDeviceTypeAttributes but never assigned to `keywordOnlyDeviceType`
// (unlike parseWaitClause, which stores `keywordOnly` before returning). Those
// keyword-only device types therefore appear to be dropped — confirm.
2307 mlir::OpAsmParser &parser,
2309 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2310 mlir::ArrayAttr &keywordOnlyDeviceType) {
2311
2312 llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
2313 bool needCommaBeforeOperands = false;
2314
2315 if (failed(parser.parseOptionalLParen())) {
2316 // Keyword only
2317 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2318 parser.getContext(), mlir::acc::DeviceType::None));
2319 keywordOnlyDeviceType =
2320 ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
2321 return success();
2322 }
2323
2324 // Parse keyword only attributes
2325 if (succeeded(parser.parseOptionalLSquare())) {
2326 // Parse keyword only attributes
2327 if (failed(parser.parseCommaSeparatedList([&]() {
2328 if (parser.parseAttribute(
2329 keywordOnlyDeviceTypeAttributes.emplace_back()))
2330 return failure();
2331 return success();
2332 })))
2333 return failure();
2334 if (parser.parseRSquare())
2335 return failure();
2336 needCommaBeforeOperands = true;
2337 }
2338
2339 if (needCommaBeforeOperands && failed(parser.parseComma()))
2340 return failure();
2341
// At least one `%v : type [attr]?` entry is required inside the parentheses.
2343 if (failed(parser.parseCommaSeparatedList([&]() {
2344 if (parser.parseOperand(operands.emplace_back()) ||
2345 parser.parseColonType(types.emplace_back()))
2346 return failure();
2347 if (succeeded(parser.parseOptionalLSquare())) {
2348 if (parser.parseAttribute(attributes.emplace_back()) ||
2349 parser.parseRSquare())
2350 return failure();
2351 } else {
2352 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2353 parser.getContext(), mlir::acc::DeviceType::None));
2354 }
2355 return success();
2356 })))
2357 return failure();
2358
2359 if (failed(parser.parseRParen()))
2360 return failure();
2361
2362 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2363 attributes.end());
2364 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2365 return success();
2366}
2367
// Prints nothing when there are no operands and the keyword-only device types
// are just [none]. Otherwise prints `(`, the keyword-only device types, `, `
// when both lists have values, the device-typed operands, and `)`.
2370 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2371 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2372
2373 if (operands.begin() == operands.end() &&
2374 hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
2375 return;
2376 }
2377
2378 p << "(";
2379 printDeviceTypes(p, keywordOnlyDeviceTypes);
2380 if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
2381 hasDeviceTypeValues(deviceTypes))
2382 p << ", ";
2383 printDeviceTypeOperands(p, op, operands, types, deviceTypes);
2384 p << ")";
2385}
2386
// Parses either a bare keyword, which sets `attr` to a UnitAttr and leaves
// `operand` unset, or a single parenthesized `(%v : type)`, which fills
// `operand` and `operandType`.
2388 mlir::OpAsmParser &parser,
2389 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2390 mlir::Type &operandType, mlir::UnitAttr &attr) {
2391 // Keyword only
2392 if (failed(parser.parseOptionalLParen())) {
2393 attr = mlir::UnitAttr::get(parser.getContext());
2394 return success();
2395 }
2396
2398 if (failed(parser.parseOperand(op)))
2399 return failure();
2400 operand = op;
2401 if (failed(parser.parseColon()))
2402 return failure();
2403 if (failed(parser.parseType(operandType)))
2404 return failure();
2405 if (failed(parser.parseRParen()))
2406 return failure();
2407
2408 return success();
2409}
2410
// Prints nothing when `attr` (keyword-only form) is set; otherwise prints
// `(%v : type)`. `operand` is dereferenced without a check, so it is assumed
// present whenever `attr` is unset.
2412 mlir::Operation *op,
2413 std::optional<mlir::Value> operand,
2414 mlir::Type operandType,
2415 mlir::UnitAttr attr) {
2416 if (attr)
2417 return;
2418
2419 p << "(";
2420 p.printOperand(*operand);
2421 p << " : ";
2422 p.printType(operandType);
2423 p << ")";
2424}
2425
// Parses either a bare keyword, which sets `attr` to a UnitAttr, or
// `(%a, %b, ... : t1, t2, ...)`. This function does not check that the operand
// and type lists have the same length; presumably that is left to operand
// resolution in the generated parser — confirm.
2427 mlir::OpAsmParser &parser,
2429 llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
2430 // Keyword only
2431 if (failed(parser.parseOptionalLParen())) {
2432 attr = mlir::UnitAttr::get(parser.getContext());
2433 return success();
2434 }
2435
2436 if (failed(parser.parseCommaSeparatedList([&]() {
2437 if (parser.parseOperand(operands.emplace_back()))
2438 return failure();
2439 return success();
2440 })))
2441 return failure();
2442 if (failed(parser.parseColon()))
2443 return failure();
2444 if (failed(parser.parseCommaSeparatedList([&]() {
2445 if (parser.parseType(types.emplace_back()))
2446 return failure();
2447 return success();
2448 })))
2449 return failure();
2450 if (failed(parser.parseRParen()))
2451 return failure();
2452
2453 return success();
2454}
2455
// Prints nothing when `attr` (keyword-only form) is set; otherwise prints
// `(%a, %b, ... : t1, t2, ...)` using the given `types`.
2457 mlir::Operation *op,
2458 mlir::OperandRange operands,
2459 mlir::TypeRange types,
2460 mlir::UnitAttr attr) {
2461 if (attr)
2462 return;
2463
2464 p << "(";
2465 llvm::interleaveComma(operands, p, [&](auto it) { p << it; });
2466 p << " : ";
2467 llvm::interleaveComma(types, p, [&](auto it) { p << it; });
2468 p << ")";
2469}
2470
/// Parses one of the keywords `kernels`, `parallel` or `serial` into the
/// matching CombinedConstructsType (KernelsLoop, ParallelLoop, SerialLoop).
/// Emits "expected compute construct name" and fails on anything else.
2471static ParseResult
2473 mlir::acc::CombinedConstructsTypeAttr &attr) {
2474 if (succeeded(parser.parseOptionalKeyword("kernels"))) {
2475 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2476 parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2477 } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
2478 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2479 parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2480 } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
2481 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2482 parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2483 } else {
2484 parser.emitError(parser.getCurrentLocation(),
2485 "expected compute construct name");
2486 return failure();
2487 }
2488 return success();
2489}
2490
/// Prints `kernels`, `parallel` or `serial` for the combined-construct kind.
/// Prints nothing when `attr` is null. (The `;` after the switch is a harmless
/// empty statement.)
2491static void
2493 mlir::acc::CombinedConstructsTypeAttr attr) {
2494 if (attr) {
2495 switch (attr.getValue()) {
2496 case mlir::acc::CombinedConstructsType::KernelsLoop:
2497 p << "kernels";
2498 break;
2499 case mlir::acc::CombinedConstructsType::ParallelLoop:
2500 p << "parallel";
2501 break;
2502 case mlir::acc::CombinedConstructsType::SerialLoop:
2503 p << "serial";
2504 break;
2505 };
2506 }
2507}
2508
2509//===----------------------------------------------------------------------===//
2510// SerialOp
2511//===----------------------------------------------------------------------===//
2512
2513unsigned SerialOp::getNumDataOperands() {
2514 return getReductionOperands().size() + getPrivateOperands().size() +
2515 getFirstprivateOperands().size() + getDataClauseOperands().size();
2516}
2517
2518Value SerialOp::getDataOperand(unsigned i) {
2519 unsigned numOptional = getAsyncOperands().size();
2520 numOptional += getIfCond() ? 1 : 0;
2521 numOptional += getSelfCond() ? 1 : 0;
2522 return getOperand(getWaitOperands().size() + numOptional + i);
2523}
2524
2525bool acc::SerialOp::hasAsyncOnly() {
2526 return hasAsyncOnly(mlir::acc::DeviceType::None);
2527}
2528
2529bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2530 return hasDeviceType(getAsyncOnly(), deviceType);
2531}
2532
2533mlir::Value acc::SerialOp::getAsyncValue() {
2534 return getAsyncValue(mlir::acc::DeviceType::None);
2535}
2536
/// Returns the async operand recorded for `deviceType`. The head of the lookup
/// call is not visible in this copy; it presumably resolves the position via
/// the async device_type attribute (like getValueInDeviceTypeSegment) — confirm.
2537mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2539 getAsyncOperands(), deviceType);
2540}
2541
2542bool acc::SerialOp::hasWaitOnly() {
2543 return hasWaitOnly(mlir::acc::DeviceType::None);
2544}
2545
2546bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2547 return hasDeviceType(getWaitOnly(), deviceType);
2548}
2549
2550mlir::Operation::operand_range SerialOp::getWaitValues() {
2551 return getWaitValues(mlir::acc::DeviceType::None);
2552}
2553
/// Returns the wait operands recorded for `deviceType`. The callee (its name
/// is not visible in this copy) also receives the hasWaitDevnum flags,
/// presumably so a leading devnum operand can be excluded — confirm.
2555SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2557 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2558 getHasWaitDevnum(), deviceType);
2559}
2560
2561mlir::Value SerialOp::getWaitDevnum() {
2562 return getWaitDevnum(mlir::acc::DeviceType::None);
2563}
2564
2565mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2566 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2567 getWaitOperandsSegments(), getHasWaitDevnum(),
2568 deviceType);
2569}
2570
/// Verifies acc.serial. The checks run in this order:
///  - private, firstprivate and reduction operands via checkPrivateOperands
///    (each must be defined by the matching op and appear only once);
///  - wait operands against their segment and device_type attributes;
///  - async operands against their device_type attribute;
///  - one further check whose condition is not visible in this copy;
///  - data clause operands via checkDataOperands.
/// Returns failure at the first failing check.
2571LogicalResult acc::SerialOp::verify() {
2572 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
2573 mlir::acc::PrivateRecipeOp>(
2574 *this, getPrivateOperands(), "private")))
2575 return failure();
2576 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
2577 mlir::acc::FirstprivateRecipeOp>(
2578 *this, getFirstprivateOperands(), "firstprivate")))
2579 return failure();
2580 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
2581 mlir::acc::ReductionRecipeOp>(
2582 *this, getReductionOperands(), "reduction")))
2583 return failure();
2584
2586 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2587 getWaitOperandsDeviceTypeAttr(), "wait")))
2588 return failure();
2589
2591 getAsyncOperandsDeviceTypeAttr(),
2592 "async")))
2593 return failure();
2594
2596 return failure();
2597
2598 return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
2599}
2600
2601void acc::SerialOp::addAsyncOnly(
2602 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2603 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2604 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2605}
2606
2607void acc::SerialOp::addAsyncOperand(
2608 MLIRContext *context, mlir::Value newValue,
2609 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2610 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2611 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2612 getAsyncOperandsMutable()));
2613}
2614
2615void acc::SerialOp::addWaitOnly(
2616 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2617 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2618 effectiveDeviceTypes));
2619}
/// Adds `newValues` as wait operands for `effectiveDeviceTypes`.
/// - Segment sizes and the device_type attribute are updated through
///   addDeviceTypeAffectedOperandHelper.
/// - `hasDevnum` is then appended max(effectiveDeviceTypes.size(), 1) times to
///   the hasWaitDevnum array, presumably one flag per segment the helper
///   appended — confirm against the helper.
2620void acc::SerialOp::addWaitOperands(
2621 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2622 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2623
2625 if (getWaitOperandsSegments())
2626 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2627
2628 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2629 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2630 getWaitOperandsMutable(), segments));
2631 setWaitOperandsSegments(segments);
2632
// Extend the per-segment devnum flags: keep the existing ones, then append
// one flag per new device type (at least one).
2634 if (getHasWaitDevnumAttr())
2635 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2636 hasDevnums.insert(
2637 hasDevnums.end(),
2638 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2639 mlir::BoolAttr::get(context, hasDevnum));
2640 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2641}
2642
2643void acc::SerialOp::addPrivatization(MLIRContext *context,
2644 mlir::acc::PrivateOp op,
2645 mlir::acc::PrivateRecipeOp recipe) {
2646 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2647 getPrivateOperandsMutable().append(op.getResult());
2648}
2649
2650void acc::SerialOp::addFirstPrivatization(
2651 MLIRContext *context, mlir::acc::FirstprivateOp op,
2652 mlir::acc::FirstprivateRecipeOp recipe) {
2653 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2654 getFirstprivateOperandsMutable().append(op.getResult());
2655}
2656
2657void acc::SerialOp::addReduction(MLIRContext *context,
2658 mlir::acc::ReductionOp op,
2659 mlir::acc::ReductionRecipeOp recipe) {
2660 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2661 getReductionOperandsMutable().append(op.getResult());
2662}
2663
2664//===----------------------------------------------------------------------===//
2665// KernelsOp
2666//===----------------------------------------------------------------------===//
2667
2668unsigned KernelsOp::getNumDataOperands() {
2669 return getDataClauseOperands().size();
2670}
2671
2672Value KernelsOp::getDataOperand(unsigned i) {
2673 unsigned numOptional = getAsyncOperands().size();
2674 numOptional += getWaitOperands().size();
2675 numOptional += getNumGangs().size();
2676 numOptional += getNumWorkers().size();
2677 numOptional += getVectorLength().size();
2678 numOptional += getIfCond() ? 1 : 0;
2679 numOptional += getSelfCond() ? 1 : 0;
2680 return getOperand(numOptional + i);
2681}
2682
2683bool acc::KernelsOp::hasAsyncOnly() {
2684 return hasAsyncOnly(mlir::acc::DeviceType::None);
2685}
2686
2687bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2688 return hasDeviceType(getAsyncOnly(), deviceType);
2689}
2690
2691mlir::Value acc::KernelsOp::getAsyncValue() {
2692 return getAsyncValue(mlir::acc::DeviceType::None);
2693}
2694
/// Returns the async operand recorded for `deviceType`. The head of the lookup
/// call is not visible in this copy; it presumably resolves the position via
/// the async device_type attribute (like getValueInDeviceTypeSegment) — confirm.
2695mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2697 getAsyncOperands(), deviceType);
2698}
2699
2700mlir::Value acc::KernelsOp::getNumWorkersValue() {
2701 return getNumWorkersValue(mlir::acc::DeviceType::None);
2702}
2703
/// Returns the num_workers operand recorded for `deviceType`, or a null Value
/// when none is recorded (lookup via getValueInDeviceTypeSegment).
2705acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2706 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
2707 deviceType);
2708}
2709
2710mlir::Value acc::KernelsOp::getVectorLengthValue() {
2711 return getVectorLengthValue(mlir::acc::DeviceType::None);
2712}
2713
/// Returns the vector_length operand recorded for `deviceType`, or a null
/// Value when none is recorded (lookup via getValueInDeviceTypeSegment).
2715acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2716 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
2717 getVectorLength(), deviceType);
2718}
2719
2720mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
2721 return getNumGangsValues(mlir::acc::DeviceType::None);
2722}
2723
/// Returns the num_gangs operands of the segment recorded for `deviceType`,
/// as selected by getValuesFromSegments from the device_type attribute and the
/// segment sizes.
2725KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2726 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2727 getNumGangsSegments(), deviceType);
2728}
2729
2730bool acc::KernelsOp::hasWaitOnly() {
2731 return hasWaitOnly(mlir::acc::DeviceType::None);
2732}
2733
2734bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2735 return hasDeviceType(getWaitOnly(), deviceType);
2736}
2737
2738mlir::Operation::operand_range KernelsOp::getWaitValues() {
2739 return getWaitValues(mlir::acc::DeviceType::None);
2740}
2741
2743KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2745 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2746 getHasWaitDevnum(), deviceType);
2747}
2748
2749mlir::Value KernelsOp::getWaitDevnum() {
2750 return getWaitDevnum(mlir::acc::DeviceType::None);
2751}
2752
2753mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2754 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2755 getWaitOperandsSegments(), getHasWaitDevnum(),
2756 deviceType);
2757}
2758
2759LogicalResult acc::KernelsOp::verify() {
2761 *this, getNumGangs(), getNumGangsSegmentsAttr(),
2762 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
2763 return failure();
2764
2766 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2767 getWaitOperandsDeviceTypeAttr(), "wait")))
2768 return failure();
2769
2770 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
2771 getNumWorkersDeviceTypeAttr(),
2772 "num_workers")))
2773 return failure();
2774
2775 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
2776 getVectorLengthDeviceTypeAttr(),
2777 "vector_length")))
2778 return failure();
2779
2781 getAsyncOperandsDeviceTypeAttr(),
2782 "async")))
2783 return failure();
2784
2786 return failure();
2787
2788 return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
2789}
2790
2791void acc::KernelsOp::addPrivatization(MLIRContext *context,
2792 mlir::acc::PrivateOp op,
2793 mlir::acc::PrivateRecipeOp recipe) {
2794 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2795 getPrivateOperandsMutable().append(op.getResult());
2796}
2797
2798void acc::KernelsOp::addFirstPrivatization(
2799 MLIRContext *context, mlir::acc::FirstprivateOp op,
2800 mlir::acc::FirstprivateRecipeOp recipe) {
2801 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2802 getFirstprivateOperandsMutable().append(op.getResult());
2803}
2804
2805void acc::KernelsOp::addReduction(MLIRContext *context,
2806 mlir::acc::ReductionOp op,
2807 mlir::acc::ReductionRecipeOp recipe) {
2808 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2809 getReductionOperandsMutable().append(op.getResult());
2810}
2811
2812void acc::KernelsOp::addNumWorkersOperand(
2813 MLIRContext *context, mlir::Value newValue,
2814 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2815 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2816 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2817 getNumWorkersMutable()));
2818}
2819
2820void acc::KernelsOp::addVectorLengthOperand(
2821 MLIRContext *context, mlir::Value newValue,
2822 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2823 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2824 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2825 getVectorLengthMutable()));
2826}
2827void acc::KernelsOp::addAsyncOnly(
2828 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2829 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2830 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2831}
2832
2833void acc::KernelsOp::addAsyncOperand(
2834 MLIRContext *context, mlir::Value newValue,
2835 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2836 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2837 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2838 getAsyncOperandsMutable()));
2839}
2840
2841void acc::KernelsOp::addNumGangsOperands(
2842 MLIRContext *context, mlir::ValueRange newValues,
2843 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2845 if (getNumGangsSegmentsAttr())
2846 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2847
2848 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2849 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2850 getNumGangsMutable(), segments));
2851
2852 setNumGangsSegments(segments);
2853}
2854
2855void acc::KernelsOp::addWaitOnly(
2856 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2857 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2858 effectiveDeviceTypes));
2859}
2860void acc::KernelsOp::addWaitOperands(
2861 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2862 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2863
2865 if (getWaitOperandsSegments())
2866 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2867
2868 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2869 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2870 getWaitOperandsMutable(), segments));
2871 setWaitOperandsSegments(segments);
2872
2874 if (getHasWaitDevnumAttr())
2875 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2876 hasDevnums.insert(
2877 hasDevnums.end(),
2878 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2879 mlir::BoolAttr::get(context, hasDevnum));
2880 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2881}
2882
2883//===----------------------------------------------------------------------===//
2884// HostDataOp
2885//===----------------------------------------------------------------------===//
2886
2887LogicalResult acc::HostDataOp::verify() {
2888 if (getDataClauseOperands().empty())
2889 return emitError("at least one operand must appear on the host_data "
2890 "operation");
2891
2892 for (mlir::Value operand : getDataClauseOperands())
2893 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2894 return emitError("expect data entry operation as defining op");
2895 return success();
2896}
2897
2898void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2899 MLIRContext *context) {
2900 results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2901}
2902
2903//===----------------------------------------------------------------------===//
2904// KernelEnvironmentOp
2905//===----------------------------------------------------------------------===//
2906
2907void acc::KernelEnvironmentOp::getCanonicalizationPatterns(
2908 RewritePatternSet &results, MLIRContext *context) {
2909 results.add<RemoveEmptyKernelEnvironment>(context);
2910}
2911
2912//===----------------------------------------------------------------------===//
2913// LoopOp
2914//===----------------------------------------------------------------------===//
2915
2916static ParseResult parseGangValue(
2917 OpAsmParser &parser, llvm::StringRef keyword,
2920 llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
2921 bool &needCommaBetweenValues, bool &newValue) {
2922 if (succeeded(parser.parseOptionalKeyword(keyword))) {
2923 if (parser.parseEqual())
2924 return failure();
2925 if (parser.parseOperand(operands.emplace_back()) ||
2926 parser.parseColonType(types.emplace_back()))
2927 return failure();
2928 attributes.push_back(gangArgType);
2929 needCommaBetweenValues = true;
2930 newValue = true;
2931 }
2932 return success();
2933}
2934
2935static ParseResult parseGangClause(
2936 OpAsmParser &parser,
2938 llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
2939 mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
2940 mlir::ArrayAttr &gangOnlyDeviceType) {
2941 llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
2942 llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
2943 llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
2945 bool needCommaBetweenValues = false;
2946 bool needCommaBeforeOperands = false;
2947
2948 if (failed(parser.parseOptionalLParen())) {
2949 // Gang only keyword
2950 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2951 parser.getContext(), mlir::acc::DeviceType::None));
2952 gangOnlyDeviceType =
2953 ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
2954 return success();
2955 }
2956
2957 // Parse gang only attributes
2958 if (succeeded(parser.parseOptionalLSquare())) {
2959 // Parse gang only attributes
2960 if (failed(parser.parseCommaSeparatedList([&]() {
2961 if (parser.parseAttribute(
2962 gangOnlyDeviceTypeAttributes.emplace_back()))
2963 return failure();
2964 return success();
2965 })))
2966 return failure();
2967 if (parser.parseRSquare())
2968 return failure();
2969 needCommaBeforeOperands = true;
2970 }
2971
2972 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2973 mlir::acc::GangArgType::Num);
2974 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2975 mlir::acc::GangArgType::Dim);
2976 auto argStatic = mlir::acc::GangArgTypeAttr::get(
2977 parser.getContext(), mlir::acc::GangArgType::Static);
2978
2979 do {
2980 if (needCommaBeforeOperands) {
2981 needCommaBeforeOperands = false;
2982 continue;
2983 }
2984
2985 if (failed(parser.parseLBrace()))
2986 return failure();
2987
2988 int32_t crtOperandsSize = gangOperands.size();
2989 while (true) {
2990 bool newValue = false;
2991 bool needValue = false;
2992 if (needCommaBetweenValues) {
2993 if (succeeded(parser.parseOptionalComma()))
2994 needValue = true; // expect a new value after comma.
2995 else
2996 break;
2997 }
2998
2999 if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
3000 gangOperands, gangOperandsType,
3001 gangArgTypeAttributes, argNum,
3002 needCommaBetweenValues, newValue)))
3003 return failure();
3004 if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
3005 gangOperands, gangOperandsType,
3006 gangArgTypeAttributes, argDim,
3007 needCommaBetweenValues, newValue)))
3008 return failure();
3009 if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
3010 gangOperands, gangOperandsType,
3011 gangArgTypeAttributes, argStatic,
3012 needCommaBetweenValues, newValue)))
3013 return failure();
3014
3015 if (!newValue && needValue) {
3016 parser.emitError(parser.getCurrentLocation(),
3017 "new value expected after comma");
3018 return failure();
3019 }
3020
3021 if (!newValue)
3022 break;
3023 }
3024
3025 if (gangOperands.empty())
3026 return parser.emitError(
3027 parser.getCurrentLocation(),
3028 "expect at least one of num, dim or static values");
3029
3030 if (failed(parser.parseRBrace()))
3031 return failure();
3032
3033 if (succeeded(parser.parseOptionalLSquare())) {
3034 if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
3035 parser.parseRSquare())
3036 return failure();
3037 } else {
3038 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3039 parser.getContext(), mlir::acc::DeviceType::None));
3040 }
3041
3042 seg.push_back(gangOperands.size() - crtOperandsSize);
3043
3044 } while (succeeded(parser.parseOptionalComma()));
3045
3046 if (failed(parser.parseRParen()))
3047 return failure();
3048
3049 llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
3050 gangArgTypeAttributes.end());
3051 gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
3052 deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
3053
3055 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3056 gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
3057
3058 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
3059 return success();
3060}
3061
3063 mlir::OperandRange operands, mlir::TypeRange types,
3064 std::optional<mlir::ArrayAttr> gangArgTypes,
3065 std::optional<mlir::ArrayAttr> deviceTypes,
3066 std::optional<mlir::DenseI32ArrayAttr> segments,
3067 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3068
3069 if (operands.begin() == operands.end() &&
3070 hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
3071 return;
3072 }
3073
3074 p << "(";
3075
3076 printDeviceTypes(p, gangOnlyDeviceTypes);
3077
3078 if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
3079 hasDeviceTypeValues(deviceTypes))
3080 p << ", ";
3081
3082 if (hasDeviceTypeValues(deviceTypes)) {
3083 unsigned opIdx = 0;
3084 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
3085 p << "{";
3086 llvm::interleaveComma(
3087 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
3088 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3089 (*gangArgTypes)[opIdx]);
3090 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3091 p << LoopOp::getGangNumKeyword();
3092 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3093 p << LoopOp::getGangDimKeyword();
3094 else if (gangArgTypeAttr.getValue() ==
3095 mlir::acc::GangArgType::Static)
3096 p << LoopOp::getGangStaticKeyword();
3097 p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
3098 ++opIdx;
3099 });
3100 p << "}";
3101 printSingleDeviceType(p, it.value());
3102 });
3103 }
3104 p << ")";
3105}
3106
3108 std::optional<mlir::ArrayAttr> segments,
3109 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3110 if (!segments)
3111 return false;
3112 for (auto attr : *segments) {
3113 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3114 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3115 return true;
3116 }
3117 return false;
3118}
3119
3120/// Check for duplicates in the DeviceType array attribute.
3121/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found.
3122static std::optional<mlir::acc::DeviceType>
3123checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
3124 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3125 if (!deviceTypes)
3126 return std::nullopt;
3127 for (auto attr : deviceTypes) {
3128 auto deviceTypeAttr =
3129 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3130 if (!deviceTypeAttr)
3131 return mlir::acc::DeviceType::None;
3132 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3133 return deviceTypeAttr.getValue();
3134 }
3135 return std::nullopt;
3136}
3137
3138LogicalResult acc::LoopOp::verify() {
3139 if (getUpperbound().size() != getStep().size())
3140 return emitError() << "number of upperbounds expected to be the same as "
3141 "number of steps";
3142
3143 if (getUpperbound().size() != getLowerbound().size())
3144 return emitError() << "number of upperbounds expected to be the same as "
3145 "number of lowerbounds";
3146
3147 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3148 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3149 return emitError() << "inclusiveUpperbound size is expected to be the same"
3150 << " as upperbound size";
3151
3152 // Check collapse
3153 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3154 return emitOpError() << "collapse device_type attr must be define when"
3155 << " collapse attr is present";
3156
3157 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3158 getCollapseAttr().getValue().size() !=
3159 getCollapseDeviceTypeAttr().getValue().size())
3160 return emitOpError() << "collapse attribute count must match collapse"
3161 << " device_type count";
3162 if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr()))
3163 return emitOpError() << "duplicate device_type `"
3164 << acc::stringifyDeviceType(*duplicateDeviceType)
3165 << "` found in collapseDeviceType attribute";
3166
3167 // Check gang
3168 if (!getGangOperands().empty()) {
3169 if (!getGangOperandsArgType())
3170 return emitOpError() << "gangOperandsArgType attribute must be defined"
3171 << " when gang operands are present";
3172
3173 if (getGangOperands().size() !=
3174 getGangOperandsArgTypeAttr().getValue().size())
3175 return emitOpError() << "gangOperandsArgType attribute count must match"
3176 << " gangOperands count";
3177 }
3178 if (getGangAttr()) {
3179 if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr()))
3180 return emitOpError() << "duplicate device_type `"
3181 << acc::stringifyDeviceType(*duplicateDeviceType)
3182 << "` found in gang attribute";
3183 }
3184
3186 *this, getGangOperands(), getGangOperandsSegmentsAttr(),
3187 getGangOperandsDeviceTypeAttr(), "gang")))
3188 return failure();
3189
3190 // Check worker
3191 if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr()))
3192 return emitOpError() << "duplicate device_type `"
3193 << acc::stringifyDeviceType(*duplicateDeviceType)
3194 << "` found in worker attribute";
3195 if (auto duplicateDeviceType =
3196 checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))
3197 return emitOpError() << "duplicate device_type `"
3198 << acc::stringifyDeviceType(*duplicateDeviceType)
3199 << "` found in workerNumOperandsDeviceType attribute";
3200 if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
3201 getWorkerNumOperandsDeviceTypeAttr(),
3202 "worker")))
3203 return failure();
3204
3205 // Check vector
3206 if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr()))
3207 return emitOpError() << "duplicate device_type `"
3208 << acc::stringifyDeviceType(*duplicateDeviceType)
3209 << "` found in vector attribute";
3210 if (auto duplicateDeviceType =
3211 checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))
3212 return emitOpError() << "duplicate device_type `"
3213 << acc::stringifyDeviceType(*duplicateDeviceType)
3214 << "` found in vectorOperandsDeviceType attribute";
3215 if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
3216 getVectorOperandsDeviceTypeAttr(),
3217 "vector")))
3218 return failure();
3219
3221 *this, getTileOperands(), getTileOperandsSegmentsAttr(),
3222 getTileOperandsDeviceTypeAttr(), "tile")))
3223 return failure();
3224
3225 // auto, independent and seq attribute are mutually exclusive.
3226 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3227 if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
3228 hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
3229 hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
3230 return emitError() << "only one of auto, independent, seq can be present "
3231 "at the same time";
3232 }
3233
3234 // Check that at least one of auto, independent, or seq is present
3235 // for the device-independent default clauses.
3236 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
3237 return attr.getValue() == mlir::acc::DeviceType::None;
3238 };
3239 bool hasDefaultSeq =
3240 getSeqAttr()
3241 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3242 hasDeviceNone)
3243 : false;
3244 bool hasDefaultIndependent =
3245 getIndependentAttr()
3246 ? llvm::any_of(
3247 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3248 hasDeviceNone)
3249 : false;
3250 bool hasDefaultAuto =
3251 getAuto_Attr()
3252 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3253 hasDeviceNone)
3254 : false;
3255 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3256 return emitError()
3257 << "at least one of auto, independent, seq must be present";
3258 }
3259
3260 // Gang, worker and vector are incompatible with seq.
3261 if (getSeqAttr()) {
3262 for (auto attr : getSeqAttr()) {
3263 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3264 if (hasVector(deviceTypeAttr.getValue()) ||
3265 getVectorValue(deviceTypeAttr.getValue()) ||
3266 hasWorker(deviceTypeAttr.getValue()) ||
3267 getWorkerValue(deviceTypeAttr.getValue()) ||
3268 hasGang(deviceTypeAttr.getValue()) ||
3269 getGangValue(mlir::acc::GangArgType::Num,
3270 deviceTypeAttr.getValue()) ||
3271 getGangValue(mlir::acc::GangArgType::Dim,
3272 deviceTypeAttr.getValue()) ||
3273 getGangValue(mlir::acc::GangArgType::Static,
3274 deviceTypeAttr.getValue()))
3275 return emitError() << "gang, worker or vector cannot appear with seq";
3276 }
3277 }
3278
3279 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
3280 mlir::acc::PrivateRecipeOp>(
3281 *this, getPrivateOperands(), "private")))
3282 return failure();
3283
3284 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
3285 mlir::acc::FirstprivateRecipeOp>(
3286 *this, getFirstprivateOperands(), "firstprivate")))
3287 return failure();
3288
3289 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
3290 mlir::acc::ReductionRecipeOp>(
3291 *this, getReductionOperands(), "reduction")))
3292 return failure();
3293
3294 if (getCombined().has_value() &&
3295 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3296 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3297 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3298 return emitError("unexpected combined constructs attribute");
3299 }
3300
3301 // Check non-empty body().
3302 if (getRegion().empty())
3303 return emitError("expected non-empty body.");
3304
3305 if (getUnstructured()) {
3306 if (!isContainerLike())
3307 return emitError(
3308 "unstructured acc.loop must not have induction variables");
3309 } else if (isContainerLike()) {
3310 // When it is container-like - it is expected to hold a loop-like operation.
3311 // Obtain the maximum collapse count - we use this to check that there
3312 // are enough loops contained.
3313 uint64_t collapseCount = getCollapseValue().value_or(1);
3314 if (getCollapseAttr()) {
3315 for (auto collapseEntry : getCollapseAttr()) {
3316 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3317 if (intAttr.getValue().getZExtValue() > collapseCount)
3318 collapseCount = intAttr.getValue().getZExtValue();
3319 }
3320 }
3321
3322 // We want to check that we find enough loop-like operations inside.
3323 // PreOrder walk allows us to walk in a breadth-first manner at each nesting
3324 // level.
3325 mlir::Operation *expectedParent = this->getOperation();
3326 bool foundSibling = false;
3327 getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
3328 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3329 // This effectively checks that we are not looking at a sibling loop.
3330 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3331 expectedParent) {
3332 foundSibling = true;
3334 }
3335
3336 collapseCount--;
3337 expectedParent = op;
3338 }
3339 // We found enough contained loops.
3340 if (collapseCount == 0)
3343 });
3344
3345 if (foundSibling)
3346 return emitError("found sibling loops inside container-like acc.loop");
3347 if (collapseCount != 0)
3348 return emitError("failed to find enough loop-like operations inside "
3349 "container-like acc.loop");
3350 }
3351
3352 return success();
3353}
3354
3355unsigned LoopOp::getNumDataOperands() {
3356 return getReductionOperands().size() + getPrivateOperands().size() +
3357 getFirstprivateOperands().size();
3358}
3359
3360Value LoopOp::getDataOperand(unsigned i) {
3361 unsigned numOptional =
3362 getLowerbound().size() + getUpperbound().size() + getStep().size();
3363 numOptional += getGangOperands().size();
3364 numOptional += getVectorOperands().size();
3365 numOptional += getWorkerNumOperands().size();
3366 numOptional += getTileOperands().size();
3367 numOptional += getCacheOperands().size();
3368 return getOperand(numOptional + i);
3369}
3370
3371bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
3372
3373bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3374 return hasDeviceType(getAuto_(), deviceType);
3375}
3376
3377bool LoopOp::hasIndependent() {
3378 return hasIndependent(mlir::acc::DeviceType::None);
3379}
3380
3381bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3382 return hasDeviceType(getIndependent(), deviceType);
3383}
3384
3385bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
3386
3387bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3388 return hasDeviceType(getSeq(), deviceType);
3389}
3390
3391mlir::Value LoopOp::getVectorValue() {
3392 return getVectorValue(mlir::acc::DeviceType::None);
3393}
3394
3395mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3396 return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
3397 getVectorOperands(), deviceType);
3398}
3399
3400bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
3401
3402bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3403 return hasDeviceType(getVector(), deviceType);
3404}
3405
3406mlir::Value LoopOp::getWorkerValue() {
3407 return getWorkerValue(mlir::acc::DeviceType::None);
3408}
3409
3410mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3411 return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
3412 getWorkerNumOperands(), deviceType);
3413}
3414
3415bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
3416
3417bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3418 return hasDeviceType(getWorker(), deviceType);
3419}
3420
3421mlir::Operation::operand_range LoopOp::getTileValues() {
3422 return getTileValues(mlir::acc::DeviceType::None);
3423}
3424
3426LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3427 return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
3428 getTileOperandsSegments(), deviceType);
3429}
3430
3431std::optional<int64_t> LoopOp::getCollapseValue() {
3432 return getCollapseValue(mlir::acc::DeviceType::None);
3433}
3434
3435std::optional<int64_t>
3436LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3437 if (!getCollapseAttr())
3438 return std::nullopt;
3439 if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3440 auto intAttr =
3441 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3442 return intAttr.getValue().getZExtValue();
3443 }
3444 return std::nullopt;
3445}
3446
3447mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3448 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3449}
3450
3451mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3452 mlir::acc::DeviceType deviceType) {
3453 if (getGangOperands().empty())
3454 return {};
3455 if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
3456 int32_t nbOperandsBefore = 0;
3457 for (unsigned i = 0; i < *pos; ++i)
3458 nbOperandsBefore += (*getGangOperandsSegments())[i];
3460 getGangOperands()
3461 .drop_front(nbOperandsBefore)
3462 .take_front((*getGangOperandsSegments())[*pos]);
3463
3464 int32_t argTypeIdx = nbOperandsBefore;
3465 for (auto value : values) {
3466 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3467 (*getGangOperandsArgType())[argTypeIdx]);
3468 if (gangArgTypeAttr.getValue() == gangArgType)
3469 return value;
3470 ++argTypeIdx;
3471 }
3472 }
3473 return {};
3474}
3475
3476bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
3477
3478bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3479 return hasDeviceType(getGang(), deviceType);
3480}
3481
3482llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
3483 return {&getRegion()};
3484}
3485
3486/// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
3487/// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
3488/// `(` ssa-id-and-type-list `)`
3489/// region
3490ParseResult
3493 SmallVectorImpl<Type> &lowerboundType,
3495 SmallVectorImpl<Type> &upperboundType,
3497 SmallVectorImpl<Type> &stepType) {
3498
3500 if (succeeded(
3501 parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
3502 if (parser.parseLParen() ||
3503 parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
3504 /*allowType=*/true) ||
3505 parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
3506 parser.parseOperandList(lowerbound, inductionVars.size(),
3508 parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
3509 parser.parseKeyword("to") || parser.parseLParen() ||
3510 parser.parseOperandList(upperbound, inductionVars.size(),
3512 parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
3513 parser.parseKeyword("step") || parser.parseLParen() ||
3514 parser.parseOperandList(step, inductionVars.size(),
3516 parser.parseColonTypeList(stepType) || parser.parseRParen())
3517 return failure();
3518 }
3519 return parser.parseRegion(region, inductionVars);
3520}
3521
3523 ValueRange lowerbound, TypeRange lowerboundType,
3524 ValueRange upperbound, TypeRange upperboundType,
3525 ValueRange steps, TypeRange stepType) {
3526 ValueRange regionArgs = region.front().getArguments();
3527 if (!regionArgs.empty()) {
3528 p << acc::LoopOp::getControlKeyword() << "(";
3529 llvm::interleaveComma(regionArgs, p,
3530 [&p](Value v) { p << v << " : " << v.getType(); });
3531 p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
3532 << upperbound << " : " << upperboundType << ") " << " step (" << steps
3533 << " : " << stepType << ") ";
3534 }
3535 p.printRegion(region, /*printEntryBlockArgs=*/false);
3536}
3537
3538void acc::LoopOp::addSeq(MLIRContext *context,
3539 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3540 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3541 effectiveDeviceTypes));
3542}
3543
3544void acc::LoopOp::addIndependent(
3545 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3546 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3547 context, getIndependentAttr(), effectiveDeviceTypes));
3548}
3549
3550void acc::LoopOp::addAuto(MLIRContext *context,
3551 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3552 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3553 effectiveDeviceTypes));
3554}
3555
3556void acc::LoopOp::setCollapseForDeviceTypes(
3557 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3558 llvm::APInt value) {
3561
3562 assert((getCollapseAttr() == nullptr) ==
3563 (getCollapseDeviceTypeAttr() == nullptr));
3564 assert(value.getBitWidth() == 64);
3565
3566 if (getCollapseAttr()) {
3567 for (const auto &existing :
3568 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3569 newValues.push_back(std::get<0>(existing));
3570 newDeviceTypes.push_back(std::get<1>(existing));
3571 }
3572 }
3573
3574 if (effectiveDeviceTypes.empty()) {
3575 // If the effective device-types list is empty, this is before there are any
3576 // being applied by device_type, so this should be added as a 'none'.
3577 newValues.push_back(
3578 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3579 newDeviceTypes.push_back(
3580 acc::DeviceTypeAttr::get(context, DeviceType::None));
3581 } else {
3582 for (DeviceType dt : effectiveDeviceTypes) {
3583 newValues.push_back(
3584 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3585 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3586 }
3587 }
3588
3589 setCollapseAttr(ArrayAttr::get(context, newValues));
3590 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3591}
3592
3593void acc::LoopOp::setTileForDeviceTypes(
3594 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3595 ValueRange values) {
3597 if (getTileOperandsSegments())
3598 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3599
3600 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3601 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3602 getTileOperandsMutable(), segments));
3603
3604 setTileOperandsSegments(segments);
3605}
3606
3607void acc::LoopOp::addVectorOperand(
3608 MLIRContext *context, mlir::Value newValue,
3609 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3610 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3611 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3612 newValue, getVectorOperandsMutable()));
3613}
3614
3615void acc::LoopOp::addEmptyVector(
3616 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3617 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3618 effectiveDeviceTypes));
3619}
3620
3621void acc::LoopOp::addWorkerNumOperand(
3622 MLIRContext *context, mlir::Value newValue,
3623 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3624 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3625 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3626 newValue, getWorkerNumOperandsMutable()));
3627}
3628
3629void acc::LoopOp::addEmptyWorker(
3630 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3631 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3632 effectiveDeviceTypes));
3633}
3634
3635void acc::LoopOp::addEmptyGang(
3636 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3637 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3638 effectiveDeviceTypes));
3639}
3640
3641bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3642 auto hasDevice = [=](DeviceTypeAttr attr) -> bool {
3643 return attr.getValue() == dt;
3644 };
3645 auto testFromArr = [=](ArrayAttr arr) -> bool {
3646 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3647 };
3648
3649 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3650 return true;
3651 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3652 return true;
3653 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3654 return true;
3655
3656 return false;
3657}
3658
3659bool acc::LoopOp::hasDefaultGangWorkerVector() {
3660 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3661 hasGang() || getGangValue(GangArgType::Num) ||
3662 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3663}
3664
3665acc::LoopParMode
3666acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3667 if (hasSeq(deviceType))
3668 return LoopParMode::loop_seq;
3669 if (hasAuto(deviceType))
3670 return LoopParMode::loop_auto;
3671 if (hasIndependent(deviceType))
3672 return LoopParMode::loop_independent;
3673 if (hasSeq())
3674 return LoopParMode::loop_seq;
3675 if (hasAuto())
3676 return LoopParMode::loop_auto;
3677 assert(hasIndependent() &&
3678 "loop must have default auto, seq, or independent");
3679 return LoopParMode::loop_independent;
3680}
3681
/// Adds `values` as gang operands for `effectiveDeviceTypes` and keeps the
/// parallel gang-operand bookkeeping in sync:
///  - segment sizes and the per-segment device-type attribute (via
///    addDeviceTypeAffectedOperandHelper), and
///  - the gang argument-type attribute, which receives one copy of
///    `argTypes` per newly added segment.
3682void acc::LoopOp::addGangOperands(
3683 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
// NOTE(review): listing lines 3684-3685 (the remaining parameters, i.e.
// `argTypes` and `values`, and the declaration of `segments`) are missing
// from this view — confirm upstream.
3686 if (std::optional<ArrayRef<int32_t>> existingSegments =
3687 getGangOperandsSegments())
3688 llvm::copy(*existingSegments, std::back_inserter(segments));
3689
// Remember how many segments existed so we can tell how many were added.
3690 unsigned beforeCount = segments.size();
3691
3692 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3693 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3694 getGangOperandsMutable(), segments));
3695
3696 setGangOperandsSegments(segments);
3697
3698 // This is a bit of extra work to make sure we update the 'types' correctly by
3699 // adding to the types collection the correct number of times. We could
3700 // potentially add something similar to the
3701 // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
3702 // excessive for a one-off case.
3703 unsigned numAdded = segments.size() - beforeCount;
3704
3705 if (numAdded > 0) {
// NOTE(review): the declaration of `gangTypes` (listing line 3706) is missing
// from this view.
3707 if (getGangOperandsArgTypeAttr())
3708 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3709
// Each new segment gets one GangArgTypeAttr per entry of `argTypes`.
3710 for (auto i : llvm::index_range(0u, numAdded)) {
3711 llvm::transform(argTypes, std::back_inserter(gangTypes),
3712 [=](mlir::acc::GangArgType gangTy) {
3713 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3714 });
3715 (void)i;
3716 }
3717
3718 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3719 }
3720}
3721
3722void acc::LoopOp::addPrivatization(MLIRContext *context,
3723 mlir::acc::PrivateOp op,
3724 mlir::acc::PrivateRecipeOp recipe) {
3725 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3726 getPrivateOperandsMutable().append(op.getResult());
3727}
3728
3729void acc::LoopOp::addFirstPrivatization(
3730 MLIRContext *context, mlir::acc::FirstprivateOp op,
3731 mlir::acc::FirstprivateRecipeOp recipe) {
3732 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3733 getFirstprivateOperandsMutable().append(op.getResult());
3734}
3735
3736void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
3737 mlir::acc::ReductionRecipeOp recipe) {
3738 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3739 getReductionOperandsMutable().append(op.getResult());
3740}
3741
3742//===----------------------------------------------------------------------===//
3743// DataOp
3744//===----------------------------------------------------------------------===//
3745
/// Verifies an acc.data operation: it needs at least one operand or a
/// default attribute, and every data clause operand must be produced by one
/// of the listed data entry/exit ops (block arguments are rejected).
3746LogicalResult acc::DataOp::verify() {
3747 // 2.6.5. Data Construct restriction
3748 // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
3749 // attach, or default clause must appear on a data construct.
// NOTE(review): this tests all operands (including if/async/wait), not only
// getDataClauseOperands(), so e.g. an if condition alone satisfies it —
// confirm this is intended.
3750 if (getOperands().empty() && !getDefaultAttr())
3751 return emitError("at least one operand or the default attribute "
3752 "must appear on the data operation");
3753
// Block arguments have no defining op, so reject them before isa<>.
3754 for (mlir::Value operand : getDataClauseOperands())
3755 if (isa<BlockArgument>(operand) ||
3756 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3757 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3758 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3759 operand.getDefiningOp()))
3760 return emitError("expect data entry/exit operation or acc.getdeviceptr "
3761 "as defining op");
3762
// NOTE(review): listing line 3763 (the condition guarding the next
// `return failure();`) is missing from this view; as shown, verification
// would always fail — confirm against upstream.
3764 return failure();
3765
3766 return success();
3767}
3768
3769unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
3770
3771Value DataOp::getDataOperand(unsigned i) {
3772 unsigned numOptional = getIfCond() ? 1 : 0;
3773 numOptional += getAsyncOperands().size() ? 1 : 0;
3774 numOptional += getWaitOperands().size();
3775 return getOperand(numOptional + i);
3776}
3777
3778bool acc::DataOp::hasAsyncOnly() {
3779 return hasAsyncOnly(mlir::acc::DeviceType::None);
3780}
3781
3782bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3783 return hasDeviceType(getAsyncOnly(), deviceType);
3784}
3785
3786mlir::Value DataOp::getAsyncValue() {
3787 return getAsyncValue(mlir::acc::DeviceType::None);
3788}
3789
/// Looks up the async operand associated with `deviceType` among
/// getAsyncOperands().
3790mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
// NOTE(review): listing line 3791 (the start of the return expression) is
// missing from this view.
3792 getAsyncOperands(), deviceType);
3793}
3794
3795bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
3796
3797bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3798 return hasDeviceType(getWaitOnly(), deviceType);
3799}
3800
3801mlir::Operation::operand_range DataOp::getWaitValues() {
3802 return getWaitValues(mlir::acc::DeviceType::None);
3803}
3804
3806DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
// Collects the wait values for `deviceType` from the wait operands, their
// device types and segment sizes, and the hasWaitDevnum flags.
// NOTE(review): listing lines 3805 (return type) and 3807 (start of the
// return expression) are missing from this view.
3808 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3809 getHasWaitDevnum(), deviceType);
3810}
3811
3812mlir::Value DataOp::getWaitDevnum() {
3813 return getWaitDevnum(mlir::acc::DeviceType::None);
3814}
3815
3816mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3817 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3818 getWaitOperandsSegments(), getHasWaitDevnum(),
3819 deviceType);
3820}
3821
3822void acc::DataOp::addAsyncOnly(
3823 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3824 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3825 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3826}
3827
3828void acc::DataOp::addAsyncOperand(
3829 MLIRContext *context, mlir::Value newValue,
3830 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3831 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3832 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3833 getAsyncOperandsMutable()));
3834}
3835
3836void acc::DataOp::addWaitOnly(MLIRContext *context,
3837 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3838 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3839 effectiveDeviceTypes));
3840}
3841
/// Adds `newValues` as a wait operand group for `effectiveDeviceTypes`,
/// updating the wait segment sizes and device types (via
/// addDeviceTypeAffectedOperandHelper) and appending `hasDevnum` to the
/// hasWaitDevnum flags once per added group.
3842void acc::DataOp::addWaitOperands(
3843 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3844 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3845
// NOTE(review): the declaration of `segments` (listing line 3846) is missing
// from this view.
3847 if (getWaitOperandsSegments())
3848 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3849
3850 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3851 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3852 getWaitOperandsMutable(), segments));
3853 setWaitOperandsSegments(segments);
3854
// NOTE(review): the declaration of `hasDevnums` (listing line 3855) is
// missing from this view.
3856 if (getHasWaitDevnumAttr())
3857 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
// One flag per added group; an empty device-type list still adds one group
// (presumably the single 'none' group created by the helper — confirm).
3858 hasDevnums.insert(
3859 hasDevnums.end(),
3860 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
3861 mlir::BoolAttr::get(context, hasDevnum));
3862 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3863}
3864
3865//===----------------------------------------------------------------------===//
3866// ExitDataOp
3867//===----------------------------------------------------------------------===//
3868
3869LogicalResult acc::ExitDataOp::verify() {
3870 // 2.6.6. Data Exit Directive restriction
3871 // At least one copyout, delete, or detach clause must appear on an exit data
3872 // directive.
3873 if (getDataClauseOperands().empty())
3874 return emitError("at least one operand must be present in dataOperands on "
3875 "the exit data operation");
3876
3877 // The async attribute represent the async clause without value. Therefore the
3878 // attribute and operand cannot appear at the same time.
3879 if (getAsyncOperand() && getAsync())
3880 return emitError("async attribute cannot appear with asyncOperand");
3881
3882 // The wait attribute represent the wait clause without values. Therefore the
3883 // attribute and operands cannot appear at the same time.
3884 if (!getWaitOperands().empty() && getWait())
3885 return emitError("wait attribute cannot appear with waitOperands");
3886
3887 if (getWaitDevnum() && getWaitOperands().empty())
3888 return emitError("wait_devnum cannot appear without waitOperands");
3889
3890 return success();
3891}
3892
3893unsigned ExitDataOp::getNumDataOperands() {
3894 return getDataClauseOperands().size();
3895}
3896
3897Value ExitDataOp::getDataOperand(unsigned i) {
3898 unsigned numOptional = getIfCond() ? 1 : 0;
3899 numOptional += getAsyncOperand() ? 1 : 0;
3900 numOptional += getWaitDevnum() ? 1 : 0;
3901 return getOperand(getWaitOperands().size() + numOptional + i);
3902}
3903
3904void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3905 MLIRContext *context) {
3906 results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
3907}
3908
3909void ExitDataOp::addAsyncOnly(MLIRContext *context,
3910 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3911 assert(effectiveDeviceTypes.empty());
3912 assert(!getAsyncAttr());
3913 assert(!getAsyncOperand());
3914
3915 setAsyncAttr(mlir::UnitAttr::get(context));
3916}
3917
3918void ExitDataOp::addAsyncOperand(
3919 MLIRContext *context, mlir::Value newValue,
3920 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3921 assert(effectiveDeviceTypes.empty());
3922 assert(!getAsyncAttr());
3923 assert(!getAsyncOperand());
3924
3925 getAsyncOperandMutable().append(newValue);
3926}
3927
3928void ExitDataOp::addWaitOnly(MLIRContext *context,
3929 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3930 assert(effectiveDeviceTypes.empty());
3931 assert(!getWaitAttr());
3932 assert(getWaitOperands().empty());
3933 assert(!getWaitDevnum());
3934
3935 setWaitAttr(mlir::UnitAttr::get(context));
3936}
3937
3938void ExitDataOp::addWaitOperands(
3939 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3940 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3941 assert(effectiveDeviceTypes.empty());
3942 assert(!getWaitAttr());
3943 assert(getWaitOperands().empty());
3944 assert(!getWaitDevnum());
3945
3946 // if hasDevnum, the first value is the devnum. The 'rest' go into the
3947 // operands list.
3948 if (hasDevnum) {
3949 getWaitDevnumMutable().append(newValues.front());
3950 newValues = newValues.drop_front();
3951 }
3952
3953 getWaitOperandsMutable().append(newValues);
3954}
3955
3956//===----------------------------------------------------------------------===//
3957// EnterDataOp
3958//===----------------------------------------------------------------------===//
3959
3960LogicalResult acc::EnterDataOp::verify() {
3961 // 2.6.6. Data Enter Directive restriction
3962 // At least one copyin, create, or attach clause must appear on an enter data
3963 // directive.
3964 if (getDataClauseOperands().empty())
3965 return emitError("at least one operand must be present in dataOperands on "
3966 "the enter data operation");
3967
3968 // The async attribute represent the async clause without value. Therefore the
3969 // attribute and operand cannot appear at the same time.
3970 if (getAsyncOperand() && getAsync())
3971 return emitError("async attribute cannot appear with asyncOperand");
3972
3973 // The wait attribute represent the wait clause without values. Therefore the
3974 // attribute and operands cannot appear at the same time.
3975 if (!getWaitOperands().empty() && getWait())
3976 return emitError("wait attribute cannot appear with waitOperands");
3977
3978 if (getWaitDevnum() && getWaitOperands().empty())
3979 return emitError("wait_devnum cannot appear without waitOperands");
3980
3981 for (mlir::Value operand : getDataClauseOperands())
3982 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3983 operand.getDefiningOp()))
3984 return emitError("expect data entry operation as defining op");
3985
3986 return success();
3987}
3988
3989unsigned EnterDataOp::getNumDataOperands() {
3990 return getDataClauseOperands().size();
3991}
3992
3993Value EnterDataOp::getDataOperand(unsigned i) {
3994 unsigned numOptional = getIfCond() ? 1 : 0;
3995 numOptional += getAsyncOperand() ? 1 : 0;
3996 numOptional += getWaitDevnum() ? 1 : 0;
3997 return getOperand(getWaitOperands().size() + numOptional + i);
3998}
3999
4000void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
4001 MLIRContext *context) {
4002 results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
4003}
4004
4005void EnterDataOp::addAsyncOnly(
4006 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4007 assert(effectiveDeviceTypes.empty());
4008 assert(!getAsyncAttr());
4009 assert(!getAsyncOperand());
4010
4011 setAsyncAttr(mlir::UnitAttr::get(context));
4012}
4013
4014void EnterDataOp::addAsyncOperand(
4015 MLIRContext *context, mlir::Value newValue,
4016 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4017 assert(effectiveDeviceTypes.empty());
4018 assert(!getAsyncAttr());
4019 assert(!getAsyncOperand());
4020
4021 getAsyncOperandMutable().append(newValue);
4022}
4023
4024void EnterDataOp::addWaitOnly(MLIRContext *context,
4025 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4026 assert(effectiveDeviceTypes.empty());
4027 assert(!getWaitAttr());
4028 assert(getWaitOperands().empty());
4029 assert(!getWaitDevnum());
4030
4031 setWaitAttr(mlir::UnitAttr::get(context));
4032}
4033
4034void EnterDataOp::addWaitOperands(
4035 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4036 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4037 assert(effectiveDeviceTypes.empty());
4038 assert(!getWaitAttr());
4039 assert(getWaitOperands().empty());
4040 assert(!getWaitDevnum());
4041
4042 // if hasDevnum, the first value is the devnum. The 'rest' go into the
4043 // operands list.
4044 if (hasDevnum) {
4045 getWaitDevnumMutable().append(newValues.front());
4046 newValues = newValues.drop_front();
4047 }
4048
4049 getWaitOperandsMutable().append(newValues);
4050}
4051
4052//===----------------------------------------------------------------------===//
4053// AtomicReadOp
4054//===----------------------------------------------------------------------===//
4055
4056LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
4057
4058//===----------------------------------------------------------------------===//
4059// AtomicWriteOp
4060//===----------------------------------------------------------------------===//
4061
4062LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
4063
4064//===----------------------------------------------------------------------===//
4065// AtomicUpdateOp
4066//===----------------------------------------------------------------------===//
4067
4068LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4069 PatternRewriter &rewriter) {
4070 if (op.isNoOp()) {
4071 rewriter.eraseOp(op);
4072 return success();
4073 }
4074
4075 if (Value writeVal = op.getWriteOpVal()) {
4076 rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
4077 op.getIfCond());
4078 return success();
4079 }
4080
4081 return failure();
4082}
4083
4084LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
4085
4086LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4087
4088//===----------------------------------------------------------------------===//
4089// AtomicCaptureOp
4090//===----------------------------------------------------------------------===//
4091
4092AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4093 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4094 return op;
4095 return dyn_cast<AtomicReadOp>(getSecondOp());
4096}
4097
4098AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4099 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4100 return op;
4101 return dyn_cast<AtomicWriteOp>(getSecondOp());
4102}
4103
4104AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4105 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4106 return op;
4107 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4108}
4109
4110LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
4111
4112//===----------------------------------------------------------------------===//
4113// DeclareEnterOp
4114//===----------------------------------------------------------------------===//
4115
/// Shared verifier for the declare operations.
///
/// Optionally requires at least one operand. Every operand must be produced
/// by one of the accepted declare data-entry ops (block arguments are
/// rejected). In assert builds it also checks that each defining op exposes
/// a var and a data clause.
4116template <typename Op>
4117static LogicalResult
// NOTE(review): listing line 4118 (function name and leading parameters,
// presumably `checkDeclareOperands(Op &op, ... operands,`) is missing here.
4119 bool requireAtLeastOneOperand = true) {
4120 if (operands.empty() && requireAtLeastOneOperand)
4121 return emitError(
4122 op->getLoc(),
4123 "at least one operand must appear on the declare operation");
4124
4125 for (mlir::Value operand : operands) {
// Block arguments have no defining op, so reject them before isa<>.
4126 if (isa<BlockArgument>(operand) ||
4127 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4128 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4129 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4130 operand.getDefiningOp()))
4131 return op.emitError(
4132 "expect valid declare data entry operation or acc.getdeviceptr "
4133 "as defining op");
4134
// The following are assert-only sanity checks; the (void) casts keep the
// locals "used" when asserts are compiled out.
4135 mlir::Value var{getVar(operand.getDefiningOp())};
4136 assert(var && "declare operands can only be data entry operations which "
4137 "must have var");
4138 (void)var;
4139 std::optional<mlir::acc::DataClause> dataClauseOptional{
4140 getDataClause(operand.getDefiningOp())};
4141 assert(dataClauseOptional.has_value() &&
4142 "declare operands can only be data entry operations which must have "
4143 "dataClause");
4144 (void)dataClauseOptional;
4145 }
4146
4147 return success();
4148}
4149
4150LogicalResult acc::DeclareEnterOp::verify() {
4151 return checkDeclareOperands(*this, this->getDataClauseOperands());
4152}
4153
4154//===----------------------------------------------------------------------===//
4155// DeclareExitOp
4156//===----------------------------------------------------------------------===//
4157
4158LogicalResult acc::DeclareExitOp::verify() {
4159 if (getToken())
4160 return checkDeclareOperands(*this, this->getDataClauseOperands(),
4161 /*requireAtLeastOneOperand=*/false);
4162 return checkDeclareOperands(*this, this->getDataClauseOperands());
4163}
4164
4165//===----------------------------------------------------------------------===//
4166// DeclareOp
4167//===----------------------------------------------------------------------===//
4168
4169LogicalResult acc::DeclareOp::verify() {
4170 return checkDeclareOperands(*this, this->getDataClauseOperands());
4171}
4172
4173//===----------------------------------------------------------------------===//
4174// RoutineOp
4175//===----------------------------------------------------------------------===//
4176
4177static unsigned getParallelismForDeviceType(acc::RoutineOp op,
4178 acc::DeviceType dtype) {
4179 unsigned parallelism = 0;
4180 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4181 parallelism += op.hasWorker(dtype) ? 1 : 0;
4182 parallelism += op.hasVector(dtype) ? 1 : 0;
4183 parallelism += op.hasSeq(dtype) ? 1 : 0;
4184 return parallelism;
4185}
4186
4187LogicalResult acc::RoutineOp::verify() {
4188 unsigned baseParallelism =
4189 getParallelismForDeviceType(*this, acc::DeviceType::None);
4190
4191 if (baseParallelism > 1)
4192 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4193 "be present at the same time";
4194
4195 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4196 ++dtypeInt) {
4197 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
4198 if (dtype == acc::DeviceType::None)
4199 continue;
4200 unsigned parallelism = getParallelismForDeviceType(*this, dtype);
4201
4202 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4203 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4204 "be present at the same time for device_type `"
4205 << acc::stringifyDeviceType(dtype) << "`";
4206 }
4207
4208 return success();
4209}
4210
4211static ParseResult parseBindName(OpAsmParser &parser,
4212 mlir::ArrayAttr &bindIdName,
4213 mlir::ArrayAttr &bindStrName,
4214 mlir::ArrayAttr &deviceIdTypes,
4215 mlir::ArrayAttr &deviceStrTypes) {
4216 llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
4217 llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
4218 llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
4219 llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
4220
4221 if (failed(parser.parseCommaSeparatedList([&]() {
4222 mlir::Attribute newAttr;
4223 bool isSymbolRefAttr;
4224 auto parseResult = parser.parseAttribute(newAttr);
4225 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4226 bindIdNameAttrs.push_back(symbolRefAttr);
4227 isSymbolRefAttr = true;
4228 } else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4229 bindStrNameAttrs.push_back(stringAttr);
4230 isSymbolRefAttr = false;
4231 }
4232 if (parseResult)
4233 return failure();
4234 if (failed(parser.parseOptionalLSquare())) {
4235 if (isSymbolRefAttr) {
4236 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4237 parser.getContext(), mlir::acc::DeviceType::None));
4238 } else {
4239 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4240 parser.getContext(), mlir::acc::DeviceType::None));
4241 }
4242 } else {
4243 if (isSymbolRefAttr) {
4244 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4245 parser.parseRSquare())
4246 return failure();
4247 } else {
4248 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4249 parser.parseRSquare())
4250 return failure();
4251 }
4252 }
4253 return success();
4254 })))
4255 return failure();
4256
4257 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4258 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4259 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4260 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4261
4262 return success();
4263}
4264
4266 std::optional<mlir::ArrayAttr> bindIdName,
4267 std::optional<mlir::ArrayAttr> bindStrName,
4268 std::optional<mlir::ArrayAttr> deviceIdTypes,
4269 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4270 // Create combined vectors for all bind names and device types
4273
4274 // Append bindIdName and deviceIdTypes
4275 if (hasDeviceTypeValues(deviceIdTypes)) {
4276 allBindNames.append(bindIdName->begin(), bindIdName->end());
4277 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4278 }
4279
4280 // Append bindStrName and deviceStrTypes
4281 if (hasDeviceTypeValues(deviceStrTypes)) {
4282 allBindNames.append(bindStrName->begin(), bindStrName->end());
4283 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4284 }
4285
4286 // Print the combined sequence
4287 if (!allBindNames.empty())
4288 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4289 [&](const auto &pair) {
4290 p << std::get<0>(pair);
4291 printSingleDeviceType(p, std::get<1>(pair));
4292 });
4293}
4294
4295static ParseResult parseRoutineGangClause(OpAsmParser &parser,
4296 mlir::ArrayAttr &gang,
4297 mlir::ArrayAttr &gangDim,
4298 mlir::ArrayAttr &gangDimDeviceTypes) {
4299
4300 llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
4301 gangDimDeviceTypeAttrs;
4302 bool needCommaBeforeOperands = false;
4303
4304 // Gang keyword only
4305 if (failed(parser.parseOptionalLParen())) {
4306 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4307 parser.getContext(), mlir::acc::DeviceType::None));
4308 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4309 return success();
4310 }
4311
4312 // Parse keyword only attributes
4313 if (succeeded(parser.parseOptionalLSquare())) {
4314 if (failed(parser.parseCommaSeparatedList([&]() {
4315 if (parser.parseAttribute(gangAttrs.emplace_back()))
4316 return failure();
4317 return success();
4318 })))
4319 return failure();
4320 if (parser.parseRSquare())
4321 return failure();
4322 needCommaBeforeOperands = true;
4323 }
4324
4325 if (needCommaBeforeOperands && failed(parser.parseComma()))
4326 return failure();
4327
4328 if (failed(parser.parseCommaSeparatedList([&]() {
4329 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4330 parser.parseColon() ||
4331 parser.parseAttribute(gangDimAttrs.emplace_back()))
4332 return failure();
4333 if (succeeded(parser.parseOptionalLSquare())) {
4334 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4335 parser.parseRSquare())
4336 return failure();
4337 } else {
4338 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4339 parser.getContext(), mlir::acc::DeviceType::None));
4340 }
4341 return success();
4342 })))
4343 return failure();
4344
4345 if (failed(parser.parseRParen()))
4346 return failure();
4347
4348 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4349 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4350 gangDimDeviceTypes =
4351 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4352
4353 return success();
4354}
4355
4357 std::optional<mlir::ArrayAttr> gang,
4358 std::optional<mlir::ArrayAttr> gangDim,
4359 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4360
4361 if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
4362 gang->size() == 1) {
4363 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4364 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4365 return;
4366 }
4367
4368 p << "(";
4369
4370 printDeviceTypes(p, gang);
4371
4372 if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
4373 p << ", ";
4374
4375 if (hasDeviceTypeValues(gangDimDeviceTypes))
4376 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4377 [&](const auto &pair) {
4378 p << acc::RoutineOp::getGangDimKeyword() << ": ";
4379 p << std::get<0>(pair);
4380 printSingleDeviceType(p, std::get<1>(pair));
4381 });
4382
4383 p << ")";
4384}
4385
/// Parses the optional `([dt, ...])` suffix of a keyword-style clause into
/// `deviceTypes`. Without a `(`, the clause applies to DeviceType::None.
4386static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
4387 mlir::ArrayAttr &deviceTypes) {
// NOTE(review): the declaration of `attributes` (listing line 4388) is
// missing from this view; presumably SmallVector<mlir::Attribute>.
4389 // Keyword only
4390 if (failed(parser.parseOptionalLParen())) {
4391 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4392 parser.getContext(), mlir::acc::DeviceType::None));
4393 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4394 return success();
4395 }
4396
4397 // Parse device type attributes
// NOTE(review): if `(` is not followed by `[`, the closing `)` is not
// consumed here and an empty array is returned — confirm intended.
4398 if (succeeded(parser.parseOptionalLSquare())) {
4399 if (failed(parser.parseCommaSeparatedList([&]() {
4400 if (parser.parseAttribute(attributes.emplace_back()))
4401 return failure();
4402 return success();
4403 })))
4404 return failure();
4405 if (parser.parseRSquare() || parser.parseRParen())
4406 return failure();
4407 }
4408 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4409 return success();
4410}
4411
/// Prints `deviceTypes` as `([dt, ...])`. Nothing is printed when there are
/// no device-type values or when the only entry is DeviceType::None (the
/// keyword-only form).
4412static void
// NOTE(review): listing line 4413 (function name and first parameter) is
// missing from this view.
4414 std::optional<mlir::ArrayAttr> deviceTypes) {
4415
4416 if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
// NOTE(review): the dyn_cast result is used without a null check; cast<>
// would state the DeviceTypeAttr invariant explicitly.
4417 auto deviceTypeAttr =
4418 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4419 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4420 return;
4421 }
4422
4423 if (!hasDeviceTypeValues(deviceTypes))
4424 return;
4425
4426 p << "([";
4427 llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
4428 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4429 p << dTypeAttr;
4430 });
4431 p << "])";
4432}
4433
4434bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
4435
4436bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4437 return hasDeviceType(getWorker(), deviceType);
4438}
4439
4440bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
4441
4442bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4443 return hasDeviceType(getVector(), deviceType);
4444}
4445
4446bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
4447
4448bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4449 return hasDeviceType(getSeq(), deviceType);
4450}
4451
4452std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4453RoutineOp::getBindNameValue() {
4454 return getBindNameValue(mlir::acc::DeviceType::None);
4455}
4456
4457std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4458RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4459 if (!hasDeviceTypeValues(getBindIdNameDeviceType()) &&
4460 !hasDeviceTypeValues(getBindStrNameDeviceType())) {
4461 return std::nullopt;
4462 }
4463
4464 if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
4465 auto attr = (*getBindIdName())[*pos];
4466 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4467 assert(symbolRefAttr && "expected SymbolRef");
4468 return symbolRefAttr;
4469 }
4470
4471 if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
4472 auto attr = (*getBindStrName())[*pos];
4473 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4474 assert(stringAttr && "expected String");
4475 return stringAttr;
4476 }
4477
4478 return std::nullopt;
4479}
4480
4481bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
4482
4483bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4484 return hasDeviceType(getGang(), deviceType);
4485}
4486
4487std::optional<int64_t> RoutineOp::getGangDimValue() {
4488 return getGangDimValue(mlir::acc::DeviceType::None);
4489}
4490
4491std::optional<int64_t>
4492RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4493 if (!hasDeviceTypeValues(getGangDimDeviceType()))
4494 return std::nullopt;
4495 if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
4496 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4497 return intAttr.getInt();
4498 }
4499 return std::nullopt;
4500}
4501
4502void RoutineOp::addSeq(MLIRContext *context,
4503 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4504 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4505 effectiveDeviceTypes));
4506}
4507
4508void RoutineOp::addVector(MLIRContext *context,
4509 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4510 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4511 effectiveDeviceTypes));
4512}
4513
4514void RoutineOp::addWorker(MLIRContext *context,
4515 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4516 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4517 effectiveDeviceTypes));
4518}
4519
4520void RoutineOp::addGang(MLIRContext *context,
4521 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4522 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4523 effectiveDeviceTypes));
4524}
4525
4526void RoutineOp::addGang(MLIRContext *context,
4527 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4528 uint64_t val) {
4531
4532 if (getGangDimAttr())
4533 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4534 if (getGangDimDeviceTypeAttr())
4535 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4536
4537 assert(dimValues.size() == deviceTypes.size());
4538
4539 if (effectiveDeviceTypes.empty()) {
4540 dimValues.push_back(
4541 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4542 deviceTypes.push_back(
4543 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4544 } else {
4545 for (DeviceType dt : effectiveDeviceTypes) {
4546 dimValues.push_back(
4547 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4548 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4549 }
4550 }
4551 assert(dimValues.size() == deviceTypes.size());
4552
4553 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4554 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4555}
4556
4557void RoutineOp::addBindStrName(MLIRContext *context,
4558 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4559 mlir::StringAttr val) {
4560 unsigned before = getBindStrNameDeviceTypeAttr()
4561 ? getBindStrNameDeviceTypeAttr().size()
4562 : 0;
4563
4564 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4565 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4566 unsigned after = getBindStrNameDeviceTypeAttr().size();
4567
4569 if (getBindStrNameAttr())
4570 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4571 for (unsigned i = 0; i < after - before; ++i)
4572 vals.push_back(val);
4573
4574 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4575}
4576
4577void RoutineOp::addBindIDName(MLIRContext *context,
4578 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4579 mlir::SymbolRefAttr val) {
4580 unsigned before =
4581 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4582
4583 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4584 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4585 unsigned after = getBindIdNameDeviceTypeAttr().size();
4586
4588 if (getBindIdNameAttr())
4589 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4590 for (unsigned i = 0; i < after - before; ++i)
4591 vals.push_back(val);
4592
4593 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4594}
4595
4596//===----------------------------------------------------------------------===//
4597// InitOp
4598//===----------------------------------------------------------------------===//
4599
4600LogicalResult acc::InitOp::verify() {
4601 Operation *currOp = *this;
4602 while ((currOp = currOp->getParentOp()))
4603 if (isComputeOperation(currOp))
4604 return emitOpError("cannot be nested in a compute operation");
4605 return success();
4606}
4607
4608void acc::InitOp::addDeviceType(MLIRContext *context,
4609 mlir::acc::DeviceType deviceType) {
4611 if (getDeviceTypesAttr())
4612 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4613
4614 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4615 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4616}
4617
4618//===----------------------------------------------------------------------===//
4619// ShutdownOp
4620//===----------------------------------------------------------------------===//
4621
4622LogicalResult acc::ShutdownOp::verify() {
4623 Operation *currOp = *this;
4624 while ((currOp = currOp->getParentOp()))
4625 if (isComputeOperation(currOp))
4626 return emitOpError("cannot be nested in a compute operation");
4627 return success();
4628}
4629
4630void acc::ShutdownOp::addDeviceType(MLIRContext *context,
4631 mlir::acc::DeviceType deviceType) {
4633 if (getDeviceTypesAttr())
4634 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4635
4636 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4637 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4638}
4639
4640//===----------------------------------------------------------------------===//
4641// SetOp
4642//===----------------------------------------------------------------------===//
4643
4644LogicalResult acc::SetOp::verify() {
4645 Operation *currOp = *this;
4646 while ((currOp = currOp->getParentOp()))
4647 if (isComputeOperation(currOp))
4648 return emitOpError("cannot be nested in a compute operation");
4649 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4650 return emitOpError("at least one default_async, device_num, or device_type "
4651 "operand must appear");
4652 return success();
4653}
4654
4655//===----------------------------------------------------------------------===//
4656// UpdateOp
4657//===----------------------------------------------------------------------===//
4658
4659LogicalResult acc::UpdateOp::verify() {
4660 // At least one of host or device should have a value.
4661 if (getDataClauseOperands().empty())
4662 return emitError("at least one value must be present in dataOperands");
4663
4665 getAsyncOperandsDeviceTypeAttr(),
4666 "async")))
4667 return failure();
4668
4670 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4671 getWaitOperandsDeviceTypeAttr(), "wait")))
4672 return failure();
4673
4675 return failure();
4676
4677 for (mlir::Value operand : getDataClauseOperands())
4678 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4679 operand.getDefiningOp()))
4680 return emitError("expect data entry/exit operation or acc.getdeviceptr "
4681 "as defining op");
4682
4683 return success();
4684}
4685
4686unsigned UpdateOp::getNumDataOperands() {
4687 return getDataClauseOperands().size();
4688}
4689
4690Value UpdateOp::getDataOperand(unsigned i) {
4691 unsigned numOptional = getAsyncOperands().size();
4692 numOptional += getIfCond() ? 1 : 0;
4693 return getOperand(getWaitOperands().size() + numOptional + i);
4694}
4695
4696void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
4697 MLIRContext *context) {
4698 results.add<RemoveConstantIfCondition<UpdateOp>>(context);
4699}
4700
4701bool UpdateOp::hasAsyncOnly() {
4702 return hasAsyncOnly(mlir::acc::DeviceType::None);
4703}
4704
4705bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4706 return hasDeviceType(getAsyncOnly(), deviceType);
4707}
4708
4709mlir::Value UpdateOp::getAsyncValue() {
4710 return getAsyncValue(mlir::acc::DeviceType::None);
4711}
4712
4713mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4715 return {};
4716
4717 if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
4718 return getAsyncOperands()[*pos];
4719
4720 return {};
4721}
4722
4723bool UpdateOp::hasWaitOnly() {
4724 return hasWaitOnly(mlir::acc::DeviceType::None);
4725}
4726
4727bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4728 return hasDeviceType(getWaitOnly(), deviceType);
4729}
4730
4731mlir::Operation::operand_range UpdateOp::getWaitValues() {
4732 return getWaitValues(mlir::acc::DeviceType::None);
4733}
4734
4736UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4738 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4739 getHasWaitDevnum(), deviceType);
4740}
4741
4742mlir::Value UpdateOp::getWaitDevnum() {
4743 return getWaitDevnum(mlir::acc::DeviceType::None);
4744}
4745
4746mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4747 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
4748 getWaitOperandsSegments(), getHasWaitDevnum(),
4749 deviceType);
4750}
4751
4752void UpdateOp::addAsyncOnly(MLIRContext *context,
4753 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4754 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4755 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4756}
4757
4758void UpdateOp::addAsyncOperand(
4759 MLIRContext *context, mlir::Value newValue,
4760 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4761 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4762 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4763 getAsyncOperandsMutable()));
4764}
4765
4766void UpdateOp::addWaitOnly(MLIRContext *context,
4767 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4768 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4769 effectiveDeviceTypes));
4770}
4771
4772void UpdateOp::addWaitOperands(
4773 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4774 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4775
4777 if (getWaitOperandsSegments())
4778 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4779
4780 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4781 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4782 getWaitOperandsMutable(), segments));
4783 setWaitOperandsSegments(segments);
4784
4786 if (getHasWaitDevnumAttr())
4787 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4788 hasDevnums.insert(
4789 hasDevnums.end(),
4790 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
4791 mlir::BoolAttr::get(context, hasDevnum));
4792 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4793}
4794
4795//===----------------------------------------------------------------------===//
4796// WaitOp
4797//===----------------------------------------------------------------------===//
4798
4799LogicalResult acc::WaitOp::verify() {
4800 // The async attribute represent the async clause without value. Therefore the
4801 // attribute and operand cannot appear at the same time.
4802 if (getAsyncOperand() && getAsync())
4803 return emitError("async attribute cannot appear with asyncOperand");
4804
4805 if (getWaitDevnum() && getWaitOperands().empty())
4806 return emitError("wait_devnum cannot appear without waitOperands");
4807
4808 return success();
4809}
4810
4811#define GET_OP_CLASSES
4812#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
4813
4814#define GET_ATTRDEF_CLASSES
4815#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
4816
4817#define GET_TYPEDEF_CLASSES
4818#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
4819
4820//===----------------------------------------------------------------------===//
4821// acc dialect utilities
4822//===----------------------------------------------------------------------===//
4823
4826 auto varPtr{llvm::TypeSwitch<mlir::Operation *,
4828 accDataClauseOp)
4829 .Case<ACC_DATA_ENTRY_OPS>(
4830 [&](auto entry) { return entry.getVarPtr(); })
4831 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4832 [&](auto exit) { return exit.getVarPtr(); })
4833 .Default([&](mlir::Operation *) {
4835 })};
4836 return varPtr;
4837}
4838
4840 auto varPtr{
4842 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
4843 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4844 return varPtr;
4845}
4846
4848 auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
4849 .Case<ACC_DATA_ENTRY_OPS>(
4850 [&](auto entry) { return entry.getVarType(); })
4851 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4852 [&](auto exit) { return exit.getVarType(); })
4853 .Default([&](mlir::Operation *) { return mlir::Type(); })};
4854 return varType;
4855}
4856
4859 auto accPtr{llvm::TypeSwitch<mlir::Operation *,
4861 accDataClauseOp)
4862 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4863 [&](auto dataClause) { return dataClause.getAccPtr(); })
4864 .Default([&](mlir::Operation *) {
4866 })};
4867 return accPtr;
4868}
4869
4871 auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
4873 [&](auto dataClause) { return dataClause.getAccVar(); })
4874 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4875 return accPtr;
4876}
4877
4879 auto varPtrPtr{
4881 .Case<ACC_DATA_ENTRY_OPS>(
4882 [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
4883 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4884 return varPtrPtr;
4885}
4886
4891 accDataClauseOp)
4892 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4894 dataClause.getBounds().begin(), dataClause.getBounds().end());
4895 })
4896 .Default([&](mlir::Operation *) {
4898 })};
4899 return bounds;
4900}
4901
4905 accDataClauseOp)
4906 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4908 dataClause.getAsyncOperands().begin(),
4909 dataClause.getAsyncOperands().end());
4910 })
4911 .Default([&](mlir::Operation *) {
4913 });
4914}
4915
4916mlir::ArrayAttr
4919 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4920 return dataClause.getAsyncOperandsDeviceTypeAttr();
4921 })
4922 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4923}
4924
4925mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
4928 [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
4929 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4930}
4931
4932std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
4933 auto name{
4935 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
4936 .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
4937 return {};
4938 })};
4939 return name;
4940}
4941
4942std::optional<mlir::acc::DataClause>
4944 auto dataClause{
4946 accDataEntryOp)
4947 .Case<ACC_DATA_ENTRY_OPS>(
4948 [&](auto entry) { return entry.getDataClause(); })
4949 .Default([&](mlir::Operation *) { return std::nullopt; })};
4950 return dataClause;
4951}
4952
4954 auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
4955 .Case<ACC_DATA_ENTRY_OPS>(
4956 [&](auto entry) { return entry.getImplicit(); })
4957 .Default([&](mlir::Operation *) { return false; })};
4958 return implicit;
4959}
4960
4962 auto dataOperands{
4965 [&](auto entry) { return entry.getDataClauseOperands(); })
4966 .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
4967 return dataOperands;
4968}
4969
4972 auto dataOperands{
4975 [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
4976 .Default([&](mlir::Operation *) { return nullptr; })};
4977 return dataOperands;
4978}
4979
4980mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) {
4981 auto recipe{
4983 .Case<ACC_DATA_ENTRY_OPS>(
4984 [&](auto entry) { return entry.getRecipeAttr(); })
4985 .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })};
4986 return recipe;
4987}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition OpenACC.cpp:4356
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition OpenACC.cpp:1179
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition OpenACC.cpp:3107
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition OpenACC.cpp:1725
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
Definition OpenACC.cpp:4211
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
Definition OpenACC.cpp:780
static bool isComputeOperation(Operation *op)
Definition OpenACC.cpp:1193
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:560
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition OpenACC.cpp:2221
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
Definition OpenACC.cpp:773
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition OpenACC.cpp:713
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:544
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition OpenACC.cpp:682
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition OpenACC.cpp:2232
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition OpenACC.cpp:2137
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition OpenACC.cpp:486
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:4413
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition OpenACC.cpp:2916
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition OpenACC.cpp:2472
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition OpenACC.cpp:3123
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition OpenACC.cpp:4118
static LogicalResult checkVarAndAccVar(Op op)
Definition OpenACC.cpp:620
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2426
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:504
static LogicalResult checkVarAndVarType(Op op)
Definition OpenACC.cpp:602
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
Definition OpenACC.cpp:636
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition OpenACC.cpp:3491
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition OpenACC.cpp:1680
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:2268
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:1807
static LogicalResult checkNoModifier(Op op)
Definition OpenACC.cpp:628
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition OpenACC.cpp:691
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:515
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:528
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2007
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
Definition OpenACC.cpp:406
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition OpenACC.cpp:667
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition OpenACC.cpp:3522
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:4386
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition OpenACC.cpp:4295
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2120
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:2295
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
Definition OpenACC.cpp:2411
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2074
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2387
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition OpenACC.cpp:754
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition OpenACC.cpp:2935
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition OpenACC.cpp:1460
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
Definition OpenACC.cpp:2456
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition OpenACC.cpp:2051
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
Definition OpenACC.cpp:646
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
Definition OpenACC.cpp:1694
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition OpenACC.cpp:2368
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:490
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition OpenACC.cpp:3062
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition OpenACC.cpp:2306
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition OpenACC.cpp:725
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition OpenACC.cpp:580
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition OpenACC.cpp:1735
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition OpenACC.cpp:4177
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2057
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition OpenACC.cpp:2492
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
Definition OpenACC.cpp:4265
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition SCF.cpp:137
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
ArrayAttr()
if(!isCopyOut)
b getContext())
false
Parses a map_entries map type from a string format back into its numeric value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:165
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
Definition Builders.h:56
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperandRange operand_range
Definition Operation.h:371
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:120
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition OpenACC.h:68
#define ACC_DATA_ENTRY_OPS
Definition OpenACC.h:45
#define ACC_DATA_EXIT_OPS
Definition OpenACC.h:53
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:4870
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:4839
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:4858
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition OpenACC.cpp:4943
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition OpenACC.cpp:4971
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition OpenACC.cpp:4888
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition OpenACC.cpp:4961
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition OpenACC.cpp:4932
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition OpenACC.cpp:4953
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
Definition OpenACC.cpp:4980
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition OpenACC.cpp:4903
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
Definition OpenACC.h:166
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition OpenACC.cpp:4878
static constexpr StringLiteral getVarNameAttrName()
Definition OpenACC.h:203
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:4925
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtains the varType from a data clause operation which records the type of variable.
Definition OpenACC.cpp:4847
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:4825
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:4917
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.