MLIR 23.0.0git
OpenACC.cpp
Go to the documentation of this file.
1//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "mlir/IR/Builders.h"
19#include "mlir/IR/IRMapping.h"
20#include "mlir/IR/Matchers.h"
22#include "mlir/IR/SymbolTable.h"
23#include "mlir/Support/LLVM.h"
25#include "llvm/ADT/SmallSet.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/LogicalResult.h"
28#include <variant>
29
30using namespace mlir;
31using namespace acc;
32
33#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
36#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
37#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
38
39namespace {
40
41static bool isScalarLikeType(Type type) {
42 return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
43}
44
45/// Helper function to attach the `VarName` attribute to an operation
46/// if a variable name is provided.
47static void attachVarNameAttr(Operation *op, OpBuilder &builder,
48 StringRef varName) {
49 if (!varName.empty()) {
50 auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName);
51 op->setAttr(acc::getVarNameAttrName(), varNameAttr);
52 }
53}
54
55template <typename T>
56struct MemRefPointerLikeModel
57 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
58 Type getElementType(Type pointer) const {
59 return cast<T>(pointer).getElementType();
60 }
61
62 mlir::acc::VariableTypeCategory
63 getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
64 Type varType) const {
65 if (auto mappableTy = dyn_cast<MappableType>(varType)) {
66 return mappableTy.getTypeCategory(varPtr);
67 }
68 auto memrefTy = cast<T>(pointer);
69 if (!memrefTy.hasRank()) {
70 // This memref is unranked - aka it could have any rank, including a
71 // rank of 0 which could mean scalar. For now, return uncategorized.
72 return mlir::acc::VariableTypeCategory::uncategorized;
73 }
74
75 if (memrefTy.getRank() == 0) {
76 if (isScalarLikeType(memrefTy.getElementType())) {
77 return mlir::acc::VariableTypeCategory::scalar;
78 }
79 // Zero-rank non-scalar - need further analysis to determine the type
80 // category. For now, return uncategorized.
81 return mlir::acc::VariableTypeCategory::uncategorized;
82 }
83
84 // It has a rank - must be an array.
85 assert(memrefTy.getRank() > 0 && "rank expected to be positive");
86 return mlir::acc::VariableTypeCategory::array;
87 }
88
89 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
90 StringRef varName, Type varType, Value originalVar,
91 bool &needsFree) const {
92 auto memrefTy = cast<MemRefType>(pointer);
93
94 // Check if this is a static memref (all dimensions are known) - if yes
95 // then we can generate an alloca operation.
96 if (memrefTy.hasStaticShape()) {
97 needsFree = false; // alloca doesn't need deallocation
98 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
99 attachVarNameAttr(allocaOp, builder, varName);
100 return allocaOp.getResult();
101 }
102
103 // For dynamic memrefs, extract sizes from the original variable if
104 // provided. Otherwise they cannot be handled.
105 if (originalVar && originalVar.getType() == memrefTy &&
106 memrefTy.hasRank()) {
107 SmallVector<Value> dynamicSizes;
108 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
109 if (memrefTy.isDynamicDim(i)) {
110 // Extract the size of dimension i from the original variable
111 auto indexValue = arith::ConstantIndexOp::create(builder, loc, i);
112 auto dimSize =
113 memref::DimOp::create(builder, loc, originalVar, indexValue);
114 dynamicSizes.push_back(dimSize);
115 }
116 // Note: We only add dynamic sizes to the dynamicSizes array
117 // Static dimensions are handled automatically by AllocOp
118 }
119 needsFree = true; // alloc needs deallocation
120 auto allocOp =
121 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
122 attachVarNameAttr(allocOp, builder, varName);
123 return allocOp.getResult();
124 }
125
126 // TODO: Unranked not yet supported.
127 return {};
128 }
129
130 bool genFree(Type pointer, OpBuilder &builder, Location loc,
131 TypedValue<PointerLikeType> varToFree, Value allocRes,
132 Type varType) const {
133 if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
134 // Use allocRes if provided to determine the allocation type
135 Value valueToInspect = allocRes ? allocRes : memrefValue;
136
137 // Walk through casts to find the original allocation
138 Value currentValue = valueToInspect;
139 Operation *originalAlloc = nullptr;
140
141 // Follow the chain of operations to find the original allocation
142 // even if a casted result is provided.
143 while (currentValue) {
144 if (auto *definingOp = currentValue.getDefiningOp()) {
145 // Check if this is an allocation operation
146 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
147 originalAlloc = definingOp;
148 break;
149 }
150
151 // Check if this is a cast operation we can look through
152 if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
153 currentValue = castOp.getSource();
154 continue;
155 }
156
157 // Check for other cast-like operations
158 if (auto reinterpretCastOp =
159 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
160 currentValue = reinterpretCastOp.getSource();
161 continue;
162 }
163
164 // If we can't look through this operation, stop
165 break;
166 }
167 // This is a block argument or similar - can't trace further.
168 break;
169 }
170
171 if (originalAlloc) {
172 if (isa<memref::AllocaOp>(originalAlloc)) {
173 // This is an alloca - no dealloc needed, but return true (success)
174 return true;
175 }
176 if (isa<memref::AllocOp>(originalAlloc)) {
177 // This is an alloc - generate dealloc on varToFree
178 memref::DeallocOp::create(builder, loc, memrefValue);
179 return true;
180 }
181 }
182 }
183
184 return false;
185 }
186
187 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
188 TypedValue<PointerLikeType> destination,
189 TypedValue<PointerLikeType> source, Type varType) const {
190 // Generate a copy operation between two memrefs
191 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
192 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
193
194 // As per memref documentation, source and destination must have same
195 // element type and shape in order to be compatible. We do not want to fail
196 // with an IR verification error - thus check that before generating the
197 // copy operation.
198 if (destMemref && srcMemref &&
199 destMemref.getType().getElementType() ==
200 srcMemref.getType().getElementType() &&
201 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
202 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
203 return true;
204 }
205
206 return false;
207 }
208
209 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
211 Type valueType) const {
212 // Load from a memref - only valid for scalar memrefs (rank 0).
213 // This is because the address computation for memrefs is part of the load
214 // (and not computed separately), but the API does not have arguments for
215 // indexing.
216 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
217 if (!memrefValue)
218 return {};
219
220 auto memrefTy = memrefValue.getType();
221
222 // Only load from scalar memrefs (rank 0)
223 if (memrefTy.getRank() != 0)
224 return {};
225
226 return memref::LoadOp::create(builder, loc, memrefValue);
227 }
228
229 bool genStore(Type pointer, OpBuilder &builder, Location loc,
230 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
231 // Store to a memref - only valid for scalar memrefs (rank 0)
232 // This is because the address computation for memrefs is part of the store
233 // (and not computed separately), but the API does not have arguments for
234 // indexing.
235 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
236 if (!memrefValue)
237 return false;
238
239 auto memrefTy = memrefValue.getType();
240
241 // Only store to scalar memrefs (rank 0)
242 if (memrefTy.getRank() != 0)
243 return false;
244
245 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
246 return true;
247 }
248
249 bool isDeviceData(Type pointer, Value var) const {
250 auto memrefTy = cast<T>(pointer);
251 Attribute memSpace = memrefTy.getMemorySpace();
252 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
253 }
254};
255
256struct LLVMPointerPointerLikeModel
257 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
258 LLVM::LLVMPointerType> {
259 Type getElementType(Type pointer) const { return Type(); }
260
261 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
263 Type valueType) const {
264 // For LLVM pointers, we need the valueType to determine what to load
265 if (!valueType)
266 return {};
267
268 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
269 }
270
271 bool genStore(Type pointer, OpBuilder &builder, Location loc,
272 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
273 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
274 return true;
275 }
276};
277
278struct MemrefAddressOfGlobalModel
279 : public AddressOfGlobalOpInterface::ExternalModel<
280 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
281 SymbolRefAttr getSymbol(Operation *op) const {
282 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
283 return getGlobalOp.getNameAttr();
284 }
285};
286
287struct MemrefGlobalVariableModel
288 : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
289 memref::GlobalOp> {
290 bool isConstant(Operation *op) const {
291 auto globalOp = cast<memref::GlobalOp>(op);
292 return globalOp.getConstant();
293 }
294
295 Region *getInitRegion(Operation *op) const {
296 // GlobalOp uses attributes for initialization, not regions
297 return nullptr;
298 }
299
300 bool isDeviceData(Operation *op) const {
301 auto globalOp = cast<memref::GlobalOp>(op);
302 Attribute memSpace = globalOp.getType().getMemorySpace();
303 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
304 }
305};
306
307struct GPULaunchOffloadRegionModel
308 : public acc::OffloadRegionOpInterface::ExternalModel<
309 GPULaunchOffloadRegionModel, gpu::LaunchOp> {
310 mlir::Region &getOffloadRegion(mlir::Operation *op) const {
311 return cast<gpu::LaunchOp>(op).getBody();
312 }
313};
314
315/// Helper function for any of the times we need to modify an ArrayAttr based on
316/// a device type list. Returns a new ArrayAttr with all of the
317/// existingDeviceTypes, plus the effective new ones(or an added none if hte new
318/// list is empty).
319mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
320 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
321 llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
323 if (existingDeviceTypes)
324 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
325
326 if (newDeviceTypes.empty())
327 deviceTypes.push_back(
328 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
329
330 for (DeviceType dt : newDeviceTypes)
331 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
332
333 return mlir::ArrayAttr::get(context, deviceTypes);
334}
335
336/// Helper function for any of the times we need to add operands that are
337/// affected by a device type list. Returns a new ArrayAttr with all of the
338/// existingDeviceTypes, plus the effective new ones (or an added none, if the
339/// new list is empty). Additionally, adds the arguments to the argCollection
340/// the correct number of times. This will also update a 'segments' array, even
341/// if it won't be used.
342mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
343 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
344 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
345 mlir::MutableOperandRange argCollection,
346 llvm::SmallVector<int32_t> &segments) {
348 if (existingDeviceTypes)
349 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
350
351 if (newDeviceTypes.empty()) {
352 argCollection.append(arguments);
353 segments.push_back(arguments.size());
354 deviceTypes.push_back(
355 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
356 }
357
358 for (DeviceType dt : newDeviceTypes) {
359 argCollection.append(arguments);
360 segments.push_back(arguments.size());
361 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
362 }
363
364 return mlir::ArrayAttr::get(context, deviceTypes);
365}
366
367/// Overload for when the 'segments' aren't needed.
368mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
369 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
370 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
371 mlir::MutableOperandRange argCollection) {
373 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
374 newDeviceTypes, arguments,
375 argCollection, segments);
376}
377} // namespace
378
379//===----------------------------------------------------------------------===//
380// OpenACC operations
381//===----------------------------------------------------------------------===//
382
383void OpenACCDialect::initialize() {
384 addOperations<
385#define GET_OP_LIST
386#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
387 >();
388 addAttributes<
389#define GET_ATTRDEF_LIST
390#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
391 >();
392 addTypes<
393#define GET_TYPEDEF_LIST
394#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
395 >();
396
397 // By attaching interfaces here, we make the OpenACC dialect dependent on
398 // the other dialects. This is probably better than having dialects like LLVM
399 // and memref be dependent on OpenACC.
400 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
401 *getContext());
402 UnrankedMemRefType::attachInterface<
403 MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
404 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
405 *getContext());
406
407 // Attach operation interfaces
408 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
409 *getContext());
410 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
411 gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*getContext());
412}
413
414//===----------------------------------------------------------------------===//
415// RegionBranchOpInterface for acc.kernels / acc.parallel / acc.serial /
416// acc.kernel_environment / acc.data / acc.host_data / acc.loop
417//===----------------------------------------------------------------------===//
418
419/// Generic helper for single-region OpenACC ops that execute their body once
420/// and then return to the parent operation with their results (if any).
421static void
423 RegionBranchPoint point,
425 if (point.isParent()) {
426 regions.push_back(RegionSuccessor(&region));
427 return;
428 }
429
430 regions.push_back(RegionSuccessor::parent());
431}
432
434 RegionSuccessor successor) {
435 return successor.isParent() ? ValueRange(op->getResults()) : ValueRange();
436}
437
438void KernelsOp::getSuccessorRegions(RegionBranchPoint point,
440 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
441 regions);
442}
443
444ValueRange KernelsOp::getSuccessorInputs(RegionSuccessor successor) {
445 return getSingleRegionSuccessorInputs(getOperation(), successor);
446}
447
448void ParallelOp::getSuccessorRegions(
450 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
451 regions);
452}
453
454ValueRange ParallelOp::getSuccessorInputs(RegionSuccessor successor) {
455 return getSingleRegionSuccessorInputs(getOperation(), successor);
456}
457
458void SerialOp::getSuccessorRegions(RegionBranchPoint point,
460 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
461 regions);
462}
463
464ValueRange SerialOp::getSuccessorInputs(RegionSuccessor successor) {
465 return getSingleRegionSuccessorInputs(getOperation(), successor);
466}
467
468void DataOp::getSuccessorRegions(RegionBranchPoint point,
470 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
471 regions);
472}
473
474ValueRange DataOp::getSuccessorInputs(RegionSuccessor successor) {
475 return getSingleRegionSuccessorInputs(getOperation(), successor);
476}
477
478void HostDataOp::getSuccessorRegions(
480 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
481 regions);
482}
483
484ValueRange HostDataOp::getSuccessorInputs(RegionSuccessor successor) {
485 return getSingleRegionSuccessorInputs(getOperation(), successor);
486}
487
488void LoopOp::getSuccessorRegions(RegionBranchPoint point,
490 // Unstructured loops: the body may contain arbitrary CFG and early exits.
491 // At the RegionBranch level, only model entry into the body and exit to the
492 // parent; any backedges are represented inside the region CFG.
493 if (getUnstructured()) {
494 if (point.isParent()) {
495 regions.push_back(RegionSuccessor(&getRegion()));
496 return;
497 }
498 regions.push_back(RegionSuccessor::parent());
499 return;
500 }
501
502 // Structured loops: model a loop-shaped region graph similar to scf.for.
503 regions.push_back(RegionSuccessor(&getRegion()));
504 regions.push_back(RegionSuccessor::parent());
505}
506
507ValueRange LoopOp::getSuccessorInputs(RegionSuccessor successor) {
508 return getSingleRegionSuccessorInputs(getOperation(), successor);
509}
510
511//===----------------------------------------------------------------------===//
512// RegionBranchTerminatorOpInterface
513//===----------------------------------------------------------------------===//
514
516TerminatorOp::getMutableSuccessorOperands(RegionSuccessor /*point*/) {
517 // `acc.terminator` does not forward operands.
518 return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
519}
520
521//===----------------------------------------------------------------------===//
522// device_type support helpers
523//===----------------------------------------------------------------------===//
524
525static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
526 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
527}
528
529static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
530 mlir::acc::DeviceType deviceType) {
531 if (!hasDeviceTypeValues(arrayAttr))
532 return false;
533
534 for (auto attr : *arrayAttr) {
535 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
536 if (deviceTypeAttr.getValue() == deviceType)
537 return true;
538 }
539
540 return false;
541}
542
544 std::optional<mlir::ArrayAttr> deviceTypes) {
545 if (!hasDeviceTypeValues(deviceTypes))
546 return;
547
548 p << "[";
549 llvm::interleaveComma(*deviceTypes, p,
550 [&](mlir::Attribute attr) { p << attr; });
551 p << "]";
552}
553
554static std::optional<unsigned> findSegment(ArrayAttr segments,
555 mlir::acc::DeviceType deviceType) {
556 unsigned segmentIdx = 0;
557 for (auto attr : segments) {
558 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
559 if (deviceTypeAttr.getValue() == deviceType)
560 return std::make_optional(segmentIdx);
561 ++segmentIdx;
562 }
563 return std::nullopt;
564}
565
567getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
569 std::optional<llvm::ArrayRef<int32_t>> segments,
570 mlir::acc::DeviceType deviceType) {
571 if (!arrayAttr)
572 return range.take_front(0);
573 if (auto pos = findSegment(*arrayAttr, deviceType)) {
574 int32_t nbOperandsBefore = 0;
575 for (unsigned i = 0; i < *pos; ++i)
576 nbOperandsBefore += (*segments)[i];
577 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
578 }
579 return range.take_front(0);
580}
581
582static mlir::Value
583getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
585 std::optional<llvm::ArrayRef<int32_t>> segments,
586 std::optional<mlir::ArrayAttr> hasWaitDevnum,
587 mlir::acc::DeviceType deviceType) {
588 if (!hasDeviceTypeValues(deviceTypeAttr))
589 return {};
590 if (auto pos = findSegment(*deviceTypeAttr, deviceType))
591 if (hasWaitDevnum->getValue()[*pos])
592 return getValuesFromSegments(deviceTypeAttr, operands, segments,
593 deviceType)
594 .front();
595 return {};
596}
597
599getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
601 std::optional<llvm::ArrayRef<int32_t>> segments,
602 std::optional<mlir::ArrayAttr> hasWaitDevnum,
603 mlir::acc::DeviceType deviceType) {
604 auto range =
605 getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
606 if (range.empty())
607 return range;
608 if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
609 if (hasWaitDevnum && *hasWaitDevnum) {
610 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
611 if (boolAttr.getValue())
612 return range.drop_front(1); // first value is devnum
613 }
614 }
615 return range;
616}
617
618template <typename Op>
619static LogicalResult checkWaitAndAsyncConflict(Op op) {
620 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
621 ++dtypeInt) {
622 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
623
624 // The asyncOnly attribute represent the async clause without value.
625 // Therefore the attribute and operand cannot appear at the same time.
626 if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
627 op.hasAsyncOnly(dtype))
628 return op.emitError(
629 "asyncOnly attribute cannot appear with asyncOperand");
630
631 // The wait attribute represent the wait clause without values. Therefore
632 // the attribute and operands cannot appear at the same time.
633 if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
634 op.hasWaitOnly(dtype))
635 return op.emitError("wait attribute cannot appear with waitOperands");
636 }
637 return success();
638}
639
640template <typename Op>
641static LogicalResult checkVarAndVarType(Op op) {
642 if (!op.getVar())
643 return op.emitError("must have var operand");
644
645 // A variable must have a type that is either pointer-like or mappable.
646 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
647 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
648 return op.emitError("var must be mappable or pointer-like");
649
650 // When it is a pointer-like type, the varType must capture the target type.
651 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
652 op.getVarType() == op.getVar().getType())
653 return op.emitError("varType must capture the element type of var");
654
655 return success();
656}
657
658template <typename Op>
659static LogicalResult checkVarAndAccVar(Op op) {
660 if (op.getVar().getType() != op.getAccVar().getType())
661 return op.emitError("input and output types must match");
662
663 return success();
664}
665
666template <typename Op>
667static LogicalResult checkNoModifier(Op op) {
668 if (op.getModifiers() != acc::DataClauseModifier::none)
669 return op.emitError("no data clause modifiers are allowed");
670 return success();
671}
672
673template <typename Op>
674static LogicalResult
675checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
676 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
677 return op.emitError(
678 "invalid data clause modifiers: " +
679 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
680
681 return success();
682}
683
684template <typename OpT, typename RecipeOpT>
685static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) {
686 // Mappable types do not need a recipe because it is possible to generate one
687 // from its API. Reject reductions though because no API is available for them
688 // at this time.
689 if (mlir::acc::isMappableType(op.getVar().getType()) &&
690 !std::is_same_v<OpT, acc::ReductionOp>)
691 return success();
692
693 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
694 if (!operandRecipe)
695 return op->emitOpError() << "recipe expected for " << operandName;
696
697 auto decl =
699 if (!decl)
700 return op->emitOpError()
701 << "expected symbol reference " << operandRecipe << " to point to a "
702 << operandName << " declaration";
703 return success();
704}
705
706static ParseResult parseVar(mlir::OpAsmParser &parser,
708 // Either `var` or `varPtr` keyword is required.
709 if (failed(parser.parseOptionalKeyword("varPtr"))) {
710 if (failed(parser.parseKeyword("var")))
711 return failure();
712 }
713 if (failed(parser.parseLParen()))
714 return failure();
715 if (failed(parser.parseOperand(var)))
716 return failure();
717
718 return success();
719}
720
722 mlir::Value var) {
723 if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
724 p << "varPtr(";
725 else
726 p << "var(";
727 p.printOperand(var);
728}
729
730static ParseResult parseAccVar(mlir::OpAsmParser &parser,
732 mlir::Type &accVarType) {
733 // Either `accVar` or `accPtr` keyword is required.
734 if (failed(parser.parseOptionalKeyword("accPtr"))) {
735 if (failed(parser.parseKeyword("accVar")))
736 return failure();
737 }
738 if (failed(parser.parseLParen()))
739 return failure();
740 if (failed(parser.parseOperand(var)))
741 return failure();
742 if (failed(parser.parseColon()))
743 return failure();
744 if (failed(parser.parseType(accVarType)))
745 return failure();
746 if (failed(parser.parseRParen()))
747 return failure();
748
749 return success();
750}
751
753 mlir::Value accVar, mlir::Type accVarType) {
754 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
755 p << "accPtr(";
756 else
757 p << "accVar(";
758 p.printOperand(accVar);
759 p << " : ";
760 p.printType(accVarType);
761 p << ")";
762}
763
764static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
765 mlir::Type &varPtrType,
766 mlir::TypeAttr &varTypeAttr) {
767 if (failed(parser.parseType(varPtrType)))
768 return failure();
769 if (failed(parser.parseRParen()))
770 return failure();
771
772 if (succeeded(parser.parseOptionalKeyword("varType"))) {
773 if (failed(parser.parseLParen()))
774 return failure();
775 mlir::Type varType;
776 if (failed(parser.parseType(varType)))
777 return failure();
778 varTypeAttr = mlir::TypeAttr::get(varType);
779 if (failed(parser.parseRParen()))
780 return failure();
781 } else {
782 // Set `varType` from the element type of the type of `varPtr`.
783 if (auto ptrTy = dyn_cast<acc::PointerLikeType>(varPtrType)) {
784 Type elementType = ptrTy.getElementType();
785 // Opaque pointers (e.g. !llvm.ptr) have no element type; fall back to
786 // using varPtrType itself so that the attribute is always valid.
787 varTypeAttr = mlir::TypeAttr::get(elementType ? elementType : varPtrType);
788 } else {
789 varTypeAttr = mlir::TypeAttr::get(varPtrType);
790 }
791 }
792
793 return success();
794}
795
797 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
798 p.printType(varPtrType);
799 p << ")";
800
801 // Print the `varType` only if it differs from the element type of
802 // `varPtr`'s type.
803 mlir::Type varType = varTypeAttr.getValue();
804 mlir::Type typeToCheckAgainst =
805 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
806 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
807 : varPtrType;
808 // Opaque pointers (e.g. !llvm.ptr) have no element type; use varPtrType as
809 // the baseline so that the inferred varType is not redundantly printed.
810 if (!typeToCheckAgainst)
811 typeToCheckAgainst = varPtrType;
812 if (typeToCheckAgainst != varType) {
813 p << " varType(";
814 p.printType(varType);
815 p << ")";
816 }
817}
818
819static ParseResult parseRecipeSym(mlir::OpAsmParser &parser,
820 mlir::SymbolRefAttr &recipeAttr) {
821 if (failed(parser.parseAttribute(recipeAttr)))
822 return failure();
823 return success();
824}
825
827 mlir::SymbolRefAttr recipeAttr) {
828 p << recipeAttr;
829}
830
831//===----------------------------------------------------------------------===//
832// DataBoundsOp
833//===----------------------------------------------------------------------===//
834LogicalResult acc::DataBoundsOp::verify() {
835 auto extent = getExtent();
836 auto upperbound = getUpperbound();
837 if (!extent && !upperbound)
838 return emitError("expected extent or upperbound.");
839 return success();
840}
841
842//===----------------------------------------------------------------------===//
843// PrivateOp
844//===----------------------------------------------------------------------===//
845LogicalResult acc::PrivateOp::verify() {
846 if (getDataClause() != acc::DataClause::acc_private)
847 return emitError(
848 "data clause associated with private operation must match its intent");
849 if (failed(checkVarAndVarType(*this)))
850 return failure();
851 if (failed(checkNoModifier(*this)))
852 return failure();
853 if (failed(
855 return failure();
856 return success();
857}
858
859//===----------------------------------------------------------------------===//
860// FirstprivateOp
861//===----------------------------------------------------------------------===//
862LogicalResult acc::FirstprivateOp::verify() {
863 if (getDataClause() != acc::DataClause::acc_firstprivate)
864 return emitError("data clause associated with firstprivate operation must "
865 "match its intent");
866 if (failed(checkVarAndVarType(*this)))
867 return failure();
868 if (failed(checkNoModifier(*this)))
869 return failure();
871 *this, "firstprivate")))
872 return failure();
873 return success();
874}
875
876//===----------------------------------------------------------------------===//
877// ReductionOp
878//===----------------------------------------------------------------------===//
879LogicalResult acc::ReductionOp::verify() {
880 if (getDataClause() != acc::DataClause::acc_reduction)
881 return emitError("data clause associated with reduction operation must "
882 "match its intent");
883 if (failed(checkVarAndVarType(*this)))
884 return failure();
885 if (failed(checkNoModifier(*this)))
886 return failure();
888 *this, "reduction")))
889 return failure();
890 return success();
891}
892
893//===----------------------------------------------------------------------===//
894// DevicePtrOp
895//===----------------------------------------------------------------------===//
896LogicalResult acc::DevicePtrOp::verify() {
897 if (getDataClause() != acc::DataClause::acc_deviceptr)
898 return emitError("data clause associated with deviceptr operation must "
899 "match its intent");
900 if (failed(checkVarAndVarType(*this)))
901 return failure();
902 if (failed(checkVarAndAccVar(*this)))
903 return failure();
904 if (failed(checkNoModifier(*this)))
905 return failure();
906 return success();
907}
908
909//===----------------------------------------------------------------------===//
910// PresentOp
911//===----------------------------------------------------------------------===//
912LogicalResult acc::PresentOp::verify() {
913 if (getDataClause() != acc::DataClause::acc_present)
914 return emitError(
915 "data clause associated with present operation must match its intent");
916 if (failed(checkVarAndVarType(*this)))
917 return failure();
918 if (failed(checkVarAndAccVar(*this)))
919 return failure();
920 if (failed(checkNoModifier(*this)))
921 return failure();
922 return success();
923}
924
925//===----------------------------------------------------------------------===//
926// CopyinOp
927//===----------------------------------------------------------------------===//
928LogicalResult acc::CopyinOp::verify() {
929 // Test for all clauses this operation can be decomposed from:
930 if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
931 getDataClause() != acc::DataClause::acc_copyin_readonly &&
932 getDataClause() != acc::DataClause::acc_copy &&
933 getDataClause() != acc::DataClause::acc_reduction)
934 return emitError(
935 "data clause associated with copyin operation must match its intent"
936 " or specify original clause this operation was decomposed from");
937 if (failed(checkVarAndVarType(*this)))
938 return failure();
939 if (failed(checkVarAndAccVar(*this)))
940 return failure();
941 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly |
942 acc::DataClauseModifier::always |
943 acc::DataClauseModifier::capture)))
944 return failure();
945 return success();
946}
947
948bool acc::CopyinOp::isCopyinReadonly() {
949 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
950 acc::bitEnumContainsAny(getModifiers(),
951 acc::DataClauseModifier::readonly);
952}
953
954//===----------------------------------------------------------------------===//
955// CreateOp
956//===----------------------------------------------------------------------===//
957LogicalResult acc::CreateOp::verify() {
958 // Test for all clauses this operation can be decomposed from:
959 if (getDataClause() != acc::DataClause::acc_create &&
960 getDataClause() != acc::DataClause::acc_create_zero &&
961 getDataClause() != acc::DataClause::acc_copyout &&
962 getDataClause() != acc::DataClause::acc_copyout_zero)
963 return emitError(
964 "data clause associated with create operation must match its intent"
965 " or specify original clause this operation was decomposed from");
966 if (failed(checkVarAndVarType(*this)))
967 return failure();
968 if (failed(checkVarAndAccVar(*this)))
969 return failure();
970 // this op is the entry part of copyout, so it also needs to allow all
971 // modifiers allowed on copyout.
972 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
973 acc::DataClauseModifier::always |
974 acc::DataClauseModifier::capture)))
975 return failure();
976 return success();
977}
978
979bool acc::CreateOp::isCreateZero() {
980 // The zero modifier is encoded in the data clause.
981 return getDataClause() == acc::DataClause::acc_create_zero ||
982 getDataClause() == acc::DataClause::acc_copyout_zero ||
983 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
984}
985
986//===----------------------------------------------------------------------===//
987// NoCreateOp
988//===----------------------------------------------------------------------===//
989LogicalResult acc::NoCreateOp::verify() {
990 if (getDataClause() != acc::DataClause::acc_no_create)
991 return emitError("data clause associated with no_create operation must "
992 "match its intent");
993 if (failed(checkVarAndVarType(*this)))
994 return failure();
995 if (failed(checkVarAndAccVar(*this)))
996 return failure();
997 if (failed(checkNoModifier(*this)))
998 return failure();
999 return success();
1000}
1001
1002//===----------------------------------------------------------------------===//
1003// AttachOp
1004//===----------------------------------------------------------------------===//
1005LogicalResult acc::AttachOp::verify() {
1006 if (getDataClause() != acc::DataClause::acc_attach)
1007 return emitError(
1008 "data clause associated with attach operation must match its intent");
1009 if (failed(checkVarAndVarType(*this)))
1010 return failure();
1011 if (failed(checkVarAndAccVar(*this)))
1012 return failure();
1013 if (failed(checkNoModifier(*this)))
1014 return failure();
1015 return success();
1016}
1017
1018//===----------------------------------------------------------------------===//
1019// DeclareDeviceResidentOp
1020//===----------------------------------------------------------------------===//
1021
1022LogicalResult acc::DeclareDeviceResidentOp::verify() {
1023 if (getDataClause() != acc::DataClause::acc_declare_device_resident)
1024 return emitError("data clause associated with device_resident operation "
1025 "must match its intent");
1026 if (failed(checkVarAndVarType(*this)))
1027 return failure();
1028 if (failed(checkVarAndAccVar(*this)))
1029 return failure();
1030 if (failed(checkNoModifier(*this)))
1031 return failure();
1032 return success();
1033}
1034
1035//===----------------------------------------------------------------------===//
1036// DeclareLinkOp
1037//===----------------------------------------------------------------------===//
1038
1039LogicalResult acc::DeclareLinkOp::verify() {
1040 if (getDataClause() != acc::DataClause::acc_declare_link)
1041 return emitError(
1042 "data clause associated with link operation must match its intent");
1043 if (failed(checkVarAndVarType(*this)))
1044 return failure();
1045 if (failed(checkVarAndAccVar(*this)))
1046 return failure();
1047 if (failed(checkNoModifier(*this)))
1048 return failure();
1049 return success();
1050}
1051
1052//===----------------------------------------------------------------------===//
1053// CopyoutOp
1054//===----------------------------------------------------------------------===//
1055LogicalResult acc::CopyoutOp::verify() {
1056 // Test for all clauses this operation can be decomposed from:
1057 if (getDataClause() != acc::DataClause::acc_copyout &&
1058 getDataClause() != acc::DataClause::acc_copyout_zero &&
1059 getDataClause() != acc::DataClause::acc_copy &&
1060 getDataClause() != acc::DataClause::acc_reduction)
1061 return emitError(
1062 "data clause associated with copyout operation must match its intent"
1063 " or specify original clause this operation was decomposed from");
1064 if (!getVar() || !getAccVar())
1065 return emitError("must have both host and device pointers");
1066 if (failed(checkVarAndVarType(*this)))
1067 return failure();
1068 if (failed(checkVarAndAccVar(*this)))
1069 return failure();
1070 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1071 acc::DataClauseModifier::always |
1072 acc::DataClauseModifier::capture)))
1073 return failure();
1074 return success();
1075}
1076
1077bool acc::CopyoutOp::isCopyoutZero() {
1078 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1079 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1080}
1081
1082//===----------------------------------------------------------------------===//
1083// DeleteOp
1084//===----------------------------------------------------------------------===//
1085LogicalResult acc::DeleteOp::verify() {
1086 // Test for all clauses this operation can be decomposed from:
1087 if (getDataClause() != acc::DataClause::acc_delete &&
1088 getDataClause() != acc::DataClause::acc_create &&
1089 getDataClause() != acc::DataClause::acc_create_zero &&
1090 getDataClause() != acc::DataClause::acc_copyin &&
1091 getDataClause() != acc::DataClause::acc_copyin_readonly &&
1092 getDataClause() != acc::DataClause::acc_present &&
1093 getDataClause() != acc::DataClause::acc_no_create &&
1094 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1095 getDataClause() != acc::DataClause::acc_declare_link)
1096 return emitError(
1097 "data clause associated with delete operation must match its intent"
1098 " or specify original clause this operation was decomposed from");
1099 if (!getAccVar())
1100 return emitError("must have device pointer");
1101 // This op is the exit part of copyin and create - thus allow all modifiers
1102 // allowed on either case.
1103 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1104 acc::DataClauseModifier::readonly |
1105 acc::DataClauseModifier::always |
1106 acc::DataClauseModifier::capture)))
1107 return failure();
1108 return success();
1109}
1110
1111//===----------------------------------------------------------------------===//
1112// DetachOp
1113//===----------------------------------------------------------------------===//
1114LogicalResult acc::DetachOp::verify() {
1115 // Test for all clauses this operation can be decomposed from:
1116 if (getDataClause() != acc::DataClause::acc_detach &&
1117 getDataClause() != acc::DataClause::acc_attach)
1118 return emitError(
1119 "data clause associated with detach operation must match its intent"
1120 " or specify original clause this operation was decomposed from");
1121 if (!getAccVar())
1122 return emitError("must have device pointer");
1123 if (failed(checkNoModifier(*this)))
1124 return failure();
1125 return success();
1126}
1127
1128//===----------------------------------------------------------------------===//
// UpdateHostOp
1130//===----------------------------------------------------------------------===//
1131LogicalResult acc::UpdateHostOp::verify() {
1132 // Test for all clauses this operation can be decomposed from:
1133 if (getDataClause() != acc::DataClause::acc_update_host &&
1134 getDataClause() != acc::DataClause::acc_update_self)
1135 return emitError(
1136 "data clause associated with host operation must match its intent"
1137 " or specify original clause this operation was decomposed from");
1138 if (!getVar() || !getAccVar())
1139 return emitError("must have both host and device pointers");
1140 if (failed(checkVarAndVarType(*this)))
1141 return failure();
1142 if (failed(checkVarAndAccVar(*this)))
1143 return failure();
1144 if (failed(checkNoModifier(*this)))
1145 return failure();
1146 return success();
1147}
1148
1149//===----------------------------------------------------------------------===//
// UpdateDeviceOp
1151//===----------------------------------------------------------------------===//
1152LogicalResult acc::UpdateDeviceOp::verify() {
1153 // Test for all clauses this operation can be decomposed from:
1154 if (getDataClause() != acc::DataClause::acc_update_device)
1155 return emitError(
1156 "data clause associated with device operation must match its intent"
1157 " or specify original clause this operation was decomposed from");
1158 if (failed(checkVarAndVarType(*this)))
1159 return failure();
1160 if (failed(checkVarAndAccVar(*this)))
1161 return failure();
1162 if (failed(checkNoModifier(*this)))
1163 return failure();
1164 return success();
1165}
1166
1167//===----------------------------------------------------------------------===//
1168// UseDeviceOp
1169//===----------------------------------------------------------------------===//
1170LogicalResult acc::UseDeviceOp::verify() {
1171 // Test for all clauses this operation can be decomposed from:
1172 if (getDataClause() != acc::DataClause::acc_use_device)
1173 return emitError(
1174 "data clause associated with use_device operation must match its intent"
1175 " or specify original clause this operation was decomposed from");
1176 if (failed(checkVarAndVarType(*this)))
1177 return failure();
1178 if (failed(checkVarAndAccVar(*this)))
1179 return failure();
1180 if (failed(checkNoModifier(*this)))
1181 return failure();
1182 return success();
1183}
1184
1185//===----------------------------------------------------------------------===//
1186// CacheOp
1187//===----------------------------------------------------------------------===//
1188LogicalResult acc::CacheOp::verify() {
1189 // Test for all clauses this operation can be decomposed from:
1190 if (getDataClause() != acc::DataClause::acc_cache &&
1191 getDataClause() != acc::DataClause::acc_cache_readonly)
1192 return emitError(
1193 "data clause associated with cache operation must match its intent"
1194 " or specify original clause this operation was decomposed from");
1195 if (failed(checkVarAndVarType(*this)))
1196 return failure();
1197 if (failed(checkVarAndAccVar(*this)))
1198 return failure();
1199 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly)))
1200 return failure();
1201 return success();
1202}
1203
1204bool acc::CacheOp::isCacheReadonly() {
1205 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1206 acc::bitEnumContainsAny(getModifiers(),
1207 acc::DataClauseModifier::readonly);
1208}
1209
1210//===----------------------------------------------------------------------===//
1211// Data entry/exit operations - getEffects implementations
1212//===----------------------------------------------------------------------===//
1213
// This function returns true iff the given operation is enclosed
// in any ACC_COMPUTE_CONSTRUCT_OPS operation.
// It is similar to the acc::getEnclosingComputeOp() utility,
// but that utility cannot be used here.
1221
// Appends one `EffectTy` effect instance per element of `operand` to
// `effects`, each attached to the address of that element.
// NOTE(review): the line(s) carrying the return type, the function name and
// the type of `effects` are missing from this listing; the name
// `addOperandEffect` is taken from call sites later in the file — confirm
// against the upstream source.
1222/// Helper to add an effect on an operand, referenced by its mutable range.
1223template <typename EffectTy>
1226 &effects,
1227 MutableOperandRange operand) {
1228 for (unsigned i = 0, e = operand.size(); i < e; ++i)
1229 effects.emplace_back(EffectTy::get(), &operand[i]);
1230}
1231
// Appends a single `EffectTy` effect instance to `effects`, attached to
// `result` after casting it to an OpResult.
// NOTE(review): the line(s) carrying the return type, the function name and
// the type of `effects` are missing from this listing; the helper's name is
// not visible here (presumably `addResultEffect`) — confirm against the
// upstream source.
1232/// Helper to add an effect on a result value.
1233template <typename EffectTy>
1236 &effects,
1237 Value result) {
1238 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(result));
1239}
1240
1241// PrivateOp: accVar result write.
// Effect registration for PrivateOp. Visible here: a Read effect is added only
// when the op is not enclosed in a compute operation.
// NOTE(review): lines are missing from this listing (the type of `effects`,
// the resource argument of the truncated emplace_back, and the statement that
// follows the TODO); restore them from the upstream file before editing.
1242void acc::PrivateOp::getEffects(
1244 &effects) {
1245 // If acc.private is enclosed into a compute operation,
1246 // then it denotes the device side privatization, hence
1247 // it does not access the CurrentDeviceIdResource.
1248 if (!isEnclosedIntoComputeOp(getOperation()))
1249 effects.emplace_back(MemoryEffects::Read::get(),
1251 // TODO: should this be MemoryEffects::Allocate?
1253}
1254
1255// FirstprivateOp: var read, accVar result write.
// Effect registration for FirstprivateOp. Visible here: a Read effect added
// only when the op is not enclosed in a compute operation, and a Read effect
// on the `var` operand.
// NOTE(review): lines are missing from this listing (the type of `effects`,
// the resource argument of the truncated emplace_back, and the last
// statement); restore them from the upstream file before editing.
1256void acc::FirstprivateOp::getEffects(
1258 &effects) {
1259 // If acc.firstprivate is enclosed into a compute operation,
1260 // then it denotes the device side privatization, hence
1261 // it does not access the CurrentDeviceIdResource.
1262 if (!isEnclosedIntoComputeOp(getOperation()))
1263 effects.emplace_back(MemoryEffects::Read::get(),
1265 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1267}
1268
1269// ReductionOp: var read, accVar result write.
// Effect registration for ReductionOp. Visible here: a Read effect added only
// when the op is not enclosed in a compute operation, and a Read effect on the
// `var` operand.
// NOTE(review): lines are missing from this listing (the type of `effects`,
// the resource argument of the truncated emplace_back, and the last
// statement); restore them from the upstream file before editing.
1270void acc::ReductionOp::getEffects(
1272 &effects) {
1273 // If acc.reduction is enclosed into a compute operation,
1274 // then it denotes the device side reduction, hence
1275 // it does not access the CurrentDeviceIdResource.
1276 if (!isEnclosedIntoComputeOp(getOperation()))
1277 effects.emplace_back(MemoryEffects::Read::get(),
1279 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1281}
1282
1283// DevicePtrOp: RuntimeCounters read.
// Effect registration for DevicePtrOp. Visible here: a Read effect on
// acc::RuntimeCounters followed by a second Read effect.
// NOTE(review): the type of `effects` and the resource argument of the second
// emplace_back are missing from this listing (presumably the current-device-id
// resource mentioned elsewhere in this file) — confirm against upstream.
1284void acc::DevicePtrOp::getEffects(
1286 &effects) {
1287 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1288 effects.emplace_back(MemoryEffects::Read::get(),
1290}
1291
1292// PresentOp: RuntimeCounters read+write.
// Effect registration for PresentOp. Visible here: a Read effect on
// acc::RuntimeCounters, then a Write effect and another Read effect.
// NOTE(review): the type of `effects` and the resource arguments of the two
// truncated emplace_back calls are missing from this listing — confirm
// against the upstream file before editing.
1293void acc::PresentOp::getEffects(
1295 &effects) {
1296 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1297 effects.emplace_back(MemoryEffects::Write::get(),
1299 effects.emplace_back(MemoryEffects::Read::get(),
1301}
1302
1303// CopyinOp: RuntimeCounters read+write, var read, accVar result write.
// Effect registration for CopyinOp. Visible here: a Read effect on
// acc::RuntimeCounters, a Write effect, another Read effect, and a Read effect
// on the `var` operand.
// NOTE(review): the type of `effects`, the resource arguments of the two
// truncated emplace_back calls and the last statement are missing from this
// listing — confirm against the upstream file before editing.
1304void acc::CopyinOp::getEffects(
1306 &effects) {
1307 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1308 effects.emplace_back(MemoryEffects::Write::get(),
1310 effects.emplace_back(MemoryEffects::Read::get(),
1312 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1314}
1315
1316// CreateOp: RuntimeCounters read+write, accVar result write.
// Effect registration for CreateOp. Visible here: a Read effect on
// acc::RuntimeCounters, then a Write effect and another Read effect.
// NOTE(review): the type of `effects`, the resource arguments of the two
// truncated emplace_back calls and the statement after the TODO are missing
// from this listing — confirm against the upstream file before editing.
1317void acc::CreateOp::getEffects(
1319 &effects) {
1320 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1321 effects.emplace_back(MemoryEffects::Write::get(),
1323 effects.emplace_back(MemoryEffects::Read::get(),
1325 // TODO: should this be MemoryEffects::Allocate?
1327}
1328
1329// NoCreateOp: RuntimeCounters read+write.
// Effect registration for NoCreateOp. Visible here: a Read effect on
// acc::RuntimeCounters, then a Write effect and another Read effect.
// NOTE(review): the type of `effects` and the resource arguments of the two
// truncated emplace_back calls are missing from this listing — confirm
// against the upstream file before editing.
1330void acc::NoCreateOp::getEffects(
1332 &effects) {
1333 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1334 effects.emplace_back(MemoryEffects::Write::get(),
1336 effects.emplace_back(MemoryEffects::Read::get(),
1338}
1339
1340// AttachOp: RuntimeCounters read+write, var read.
// Effect registration for AttachOp. Visible here: a Read effect on
// acc::RuntimeCounters, a Write effect, another Read effect, and a Read effect
// on the `var` operand.
// NOTE(review): the type of `effects` and the resource arguments of the two
// truncated emplace_back calls are missing from this listing — confirm
// against the upstream file before editing.
1341void acc::AttachOp::getEffects(
1343 &effects) {
1344 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1345 effects.emplace_back(MemoryEffects::Write::get(),
1347 effects.emplace_back(MemoryEffects::Read::get(),
1349 // TODO: should we also add MemoryEffects::Write?
1350 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1351}
1352
1353// GetDevicePtrOp: RuntimeCounters read.
// Effect registration for GetDevicePtrOp. Visible here: a Read effect on
// acc::RuntimeCounters followed by a second Read effect.
// NOTE(review): the type of `effects` and the resource argument of the second
// emplace_back are missing from this listing — confirm against the upstream
// file before editing.
1354void acc::GetDevicePtrOp::getEffects(
1356 &effects) {
1357 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1358 effects.emplace_back(MemoryEffects::Read::get(),
1360}
1361
1362// UpdateDeviceOp: var read, accVar result write.
// Effect registration for UpdateDeviceOp. Visible here: one Read effect and a
// Read effect on the `var` operand; no RuntimeCounters effect is visible.
// NOTE(review): the type of `effects`, the resource argument of the truncated
// emplace_back and the last statement are missing from this listing — confirm
// against the upstream file before editing.
1363void acc::UpdateDeviceOp::getEffects(
1365 &effects) {
1366 effects.emplace_back(MemoryEffects::Read::get(),
1368 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1370}
1371
1372// UseDeviceOp: RuntimeCounters read.
// Effect registration for UseDeviceOp. Visible here: a Read effect on
// acc::RuntimeCounters followed by a second Read effect.
// NOTE(review): the type of `effects` and the resource argument of the second
// emplace_back are missing from this listing — confirm against the upstream
// file before editing.
1373void acc::UseDeviceOp::getEffects(
1375 &effects) {
1376 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1377 effects.emplace_back(MemoryEffects::Read::get(),
1379}
1380
1381// DeclareDeviceResidentOp: RuntimeCounters write, var read.
// Effect registration for DeclareDeviceResidentOp. Visible here: one Write
// effect, one Read effect, and a Read effect on the `var` operand.
// NOTE(review): the type of `effects` and the resource arguments of the two
// truncated emplace_back calls are missing from this listing — confirm
// against the upstream file before editing.
1382void acc::DeclareDeviceResidentOp::getEffects(
1384 &effects) {
1385 effects.emplace_back(MemoryEffects::Write::get(),
1387 effects.emplace_back(MemoryEffects::Read::get(),
1389 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1390}
1391
1392// DeclareLinkOp: RuntimeCounters write, var read.
// Effect registration for DeclareLinkOp. Visible here: one Write effect, one
// Read effect, and a Read effect on the `var` operand.
// NOTE(review): the type of `effects` and the resource arguments of the two
// truncated emplace_back calls are missing from this listing — confirm
// against the upstream file before editing.
1393void acc::DeclareLinkOp::getEffects(
1395 &effects) {
1396 effects.emplace_back(MemoryEffects::Write::get(),
1398 effects.emplace_back(MemoryEffects::Read::get(),
1400 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1401}
1402
1403// CacheOp: NoMemoryEffect
// Effect registration for CacheOp: the body is empty, so nothing is appended
// to `effects`.
// NOTE(review): the line carrying the type of `effects` is missing from this
// listing — confirm against the upstream file before editing.
1404void acc::CacheOp::getEffects(
1406 &effects) {}
1407
1408// CopyoutOp: RuntimeCounters read+write, accVar read, var write.
// Effect registration for CopyoutOp. Visible here: a Read effect on
// acc::RuntimeCounters, a Write effect, another Read effect, a Read effect on
// the `accVar` operand and a Write effect on the `var` operand.
// NOTE(review): the type of `effects` and the resource arguments of the two
// truncated emplace_back calls are missing from this listing — confirm
// against the upstream file before editing.
1409void acc::CopyoutOp::getEffects(
1411 &effects) {
1412 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1413 effects.emplace_back(MemoryEffects::Write::get(),
1415 effects.emplace_back(MemoryEffects::Read::get(),
1417 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1418 addOperandEffect<MemoryEffects::Write>(effects, getVarMutable());
1419}
1420
1421// DeleteOp: RuntimeCounters read+write, accVar read.
// Effect registration for DeleteOp. Visible here: a Read effect on
// acc::RuntimeCounters, a Write effect, another Read effect, and a Read effect
// on the `accVar` operand.
// NOTE(review): the type of `effects` and the resource arguments of the two
// truncated emplace_back calls are missing from this listing — confirm
// against the upstream file before editing.
1422void acc::DeleteOp::getEffects(
1424 &effects) {
1425 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1426 effects.emplace_back(MemoryEffects::Write::get(),
1428 effects.emplace_back(MemoryEffects::Read::get(),
1430 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1431}
1432
1433// DetachOp: RuntimeCounters read+write, accVar read.
// Effect registration for DetachOp. Visible here: a Read effect on
// acc::RuntimeCounters, a Write effect, another Read effect, and a Read effect
// on the `accVar` operand.
// NOTE(review): the type of `effects` and the resource arguments of the two
// truncated emplace_back calls are missing from this listing — confirm
// against the upstream file before editing.
1434void acc::DetachOp::getEffects(
1436 &effects) {
1437 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1438 effects.emplace_back(MemoryEffects::Write::get(),
1440 effects.emplace_back(MemoryEffects::Read::get(),
1442 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1443}
1444
1445// UpdateHostOp: RuntimeCounters read+write, accVar read, var write.
// Effect registration for UpdateHostOp. Visible here: a Read effect on
// acc::RuntimeCounters, a Write effect, another Read effect, a Read effect on
// the `accVar` operand and a Write effect on the `var` operand.
// NOTE(review): the type of `effects` and the resource arguments of the two
// truncated emplace_back calls are missing from this listing — confirm
// against the upstream file before editing.
1446void acc::UpdateHostOp::getEffects(
1448 &effects) {
1449 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1450 effects.emplace_back(MemoryEffects::Write::get(),
1452 effects.emplace_back(MemoryEffects::Read::get(),
1454 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1455 addOperandEffect<MemoryEffects::Write>(effects, getVarMutable());
1456}
1457
// Adds `nRegions` regions to `state` and parses each of them, in order, with
// no entry-block arguments. Returns failure as soon as one region fails to
// parse, success otherwise. `StructureOp` is not referenced in the visible
// body.
// NOTE(review): the declaration of the local `regions` container is missing
// from this listing (presumably a small vector of `Region *`) — confirm
// against the upstream file before editing.
1458template <typename StructureOp>
1459static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
1460 unsigned nRegions = 1) {
1461
1463 for (unsigned i = 0; i < nRegions; ++i)
1464 regions.push_back(state.addRegion());
1465
1466 for (Region *region : regions)
1467 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
1468 return failure();
1469
1470 return success();
1471}
1472
1473namespace {
1474/// Pattern to remove operation without region that have constant false `ifCond`
1475/// and remove the condition from the operation if the `ifCond` is a true
1476/// constant.
1477template <typename OpTy>
1478struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
1479 using OpRewritePattern<OpTy>::OpRewritePattern;
1480
1481 LogicalResult matchAndRewrite(OpTy op,
1482 PatternRewriter &rewriter) const override {
1483 // Early return if there is no condition.
1484 Value ifCond = op.getIfCond();
1485 if (!ifCond)
1486 return failure();
1487
1488 IntegerAttr constAttr;
1489 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1490 return failure();
1491 if (constAttr.getInt())
1492 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1493 else
1494 rewriter.eraseOp(op);
1495
1496 return success();
1497 }
1498};
1499
1500/// Replaces the given op with the contents of the given single-block region,
1501/// using the operands of the block terminator to replace operation results.
1502static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
1503 Region &region, ValueRange blockArgs = {}) {
1504 assert(region.hasOneBlock() && "expected single-block region");
1505 Block *block = &region.front();
1506 Operation *terminator = block->getTerminator();
1507 ValueRange results = terminator->getOperands();
1508 rewriter.inlineBlockBefore(block, op, blockArgs);
1509 rewriter.replaceOp(op, results);
1510 rewriter.eraseOp(terminator);
1511}
1512
1513/// Pattern to remove operation with region that have constant false `ifCond`
1514/// and remove the condition from the operation if the `ifCond` is constant
1515/// true.
1516template <typename OpTy>
1517struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
1518 using OpRewritePattern<OpTy>::OpRewritePattern;
1519
1520 LogicalResult matchAndRewrite(OpTy op,
1521 PatternRewriter &rewriter) const override {
1522 // Early return if there is no condition.
1523 Value ifCond = op.getIfCond();
1524 if (!ifCond)
1525 return failure();
1526
1527 IntegerAttr constAttr;
1528 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1529 return failure();
1530 if (constAttr.getInt())
1531 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1532 else
1533 replaceOpWithRegion(rewriter, op, op.getRegion());
1534
1535 return success();
1536 }
1537};
1538
1539//===----------------------------------------------------------------------===//
1540// Recipe Region Helpers
1541//===----------------------------------------------------------------------===//
1542
1543/// Create and populate an init region for privatization recipes.
1544/// Returns success if the region is populated, failure otherwise.
1545/// Sets needsFree to indicate if the allocated memory requires deallocation.
1546/// The `hostVar` is the original host variable used to derive
1547/// language-specific metadata via `genPrivateVariableInfo`.
1548/// The `varInfo` output parameter is set to the variable info produced.
1549static LogicalResult createInitRegion(OpBuilder &builder, Location loc,
1550 Region &initRegion, Value hostVar,
1551 StringRef varName, ValueRange bounds,
1552 bool &needsFree,
1553 acc::VariableInfoAttr &varInfo) {
1554 Type varType = hostVar.getType();
1555
1556 // Create init block with arguments: original value + bounds
1557 SmallVector<Type> argTypes{varType};
1558 SmallVector<Location> argLocs{loc};
1559 for (Value bound : bounds) {
1560 argTypes.push_back(bound.getType());
1561 argLocs.push_back(loc);
1562 }
1563
1564 Block *initBlock = builder.createBlock(&initRegion);
1565 initBlock->addArguments(argTypes, argLocs);
1566 builder.setInsertionPointToStart(initBlock);
1567
1568 Value privatizedValue;
1569
1570 // Get the block argument that represents the original variable
1571 Value blockArgVar = initBlock->getArgument(0);
1572
1573 // Generate init region body based on variable type
1574 if (isa<MappableType>(varType)) {
1575 auto mappableTy = cast<MappableType>(varType);
1576 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1577 auto typedHostVar = cast<TypedValue<MappableType>>(hostVar);
1578 varInfo = mappableTy.genPrivateVariableInfo(typedHostVar);
1579 privatizedValue = mappableTy.generatePrivateInit(
1580 builder, loc, typedVar, varName, bounds, {}, varInfo, needsFree);
1581 if (!privatizedValue)
1582 return failure();
1583 } else {
1584 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1585 auto pointerLikeTy = cast<PointerLikeType>(varType);
1586 // Use PointerLikeType's allocation API with the block argument
1587 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1588 blockArgVar, needsFree);
1589 if (!privatizedValue)
1590 return failure();
1591 }
1592
1593 // Add yield operation to init block
1594 acc::YieldOp::create(builder, loc, privatizedValue);
1595
1596 return success();
1597}
1598
1599/// Create and populate a copy region for firstprivate recipes.
1600/// Returns success if the region is populated, failure otherwise.
1601/// TODO: Handle MappableType - it does not yet have a copy API.
1602static LogicalResult createCopyRegion(OpBuilder &builder, Location loc,
1603 Region &copyRegion, Type varType,
1604 ValueRange bounds) {
1605 // Create copy block with arguments: original value + privatized value +
1606 // bounds
1607 SmallVector<Type> copyArgTypes{varType, varType};
1608 SmallVector<Location> copyArgLocs{loc, loc};
1609 for (Value bound : bounds) {
1610 copyArgTypes.push_back(bound.getType());
1611 copyArgLocs.push_back(loc);
1612 }
1613
1614 Block *copyBlock = builder.createBlock(&copyRegion);
1615 copyBlock->addArguments(copyArgTypes, copyArgLocs);
1616 builder.setInsertionPointToStart(copyBlock);
1617
1618 bool isMappable = isa<MappableType>(varType);
1619 bool isPointerLike = isa<PointerLikeType>(varType);
1620 // TODO: Handle MappableType - it does not yet have a copy API.
1621 // Otherwise, for now just fallback to pointer-like behavior.
1622 if (isMappable && !isPointerLike)
1623 return failure();
1624
1625 // Generate copy region body based on variable type
1626 if (isPointerLike) {
1627 auto pointerLikeTy = cast<PointerLikeType>(varType);
1628 Value originalArg = copyBlock->getArgument(0);
1629 Value privatizedArg = copyBlock->getArgument(1);
1630
1631 // Generate copy operation using PointerLikeType interface
1632 if (!pointerLikeTy.genCopy(
1633 builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
1634 cast<TypedValue<PointerLikeType>>(originalArg), varType))
1635 return failure();
1636 }
1637
1638 // Add terminator to copy block
1639 acc::TerminatorOp::create(builder, loc);
1640
1641 return success();
1642}
1643
1644/// Create and populate a destroy region for privatization recipes.
1645/// Returns success if the region is populated, failure otherwise.
1646/// The `varInfo` carries language-specific metadata produced by
1647/// `createInitRegion`.
1648static LogicalResult createDestroyRegion(OpBuilder &builder, Location loc,
1649 Region &destroyRegion, Type varType,
1650 Value allocRes, ValueRange bounds,
1651 acc::VariableInfoAttr varInfo) {
1652 // Create destroy block with arguments: original value + privatized value +
1653 // bounds
1654 SmallVector<Type> destroyArgTypes{varType, varType};
1655 SmallVector<Location> destroyArgLocs{loc, loc};
1656 for (Value bound : bounds) {
1657 destroyArgTypes.push_back(bound.getType());
1658 destroyArgLocs.push_back(loc);
1659 }
1660
1661 Block *destroyBlock = builder.createBlock(&destroyRegion);
1662 destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
1663 builder.setInsertionPointToStart(destroyBlock);
1664
1665 auto varToFree =
1666 cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
1667 if (isa<MappableType>(varType)) {
1668 auto mappableTy = cast<MappableType>(varType);
1669 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree, bounds,
1670 varInfo))
1671 return failure();
1672 } else {
1673 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1674 auto pointerLikeTy = cast<PointerLikeType>(varType);
1675 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1676 return failure();
1677 }
1678
1679 acc::TerminatorOp::create(builder, loc);
1680 return success();
1681}
1682
1683} // namespace
1684
1685//===----------------------------------------------------------------------===//
1686// PrivateRecipeOp
1687//===----------------------------------------------------------------------===//
1688
// Shared region check used by the recipe verifiers in this file.
//  - With `optional` set, an empty region is accepted immediately.
//  - Otherwise the region must be non-empty and the first argument of its
//    first block must have type `type`.
//  - With `verifyYield` set, every acc::YieldOp found by `region.getOps` must
//    yield exactly one value of type `type`.
// `regionName` and `regionType` are only used to build the diagnostics, which
// are emitted on `op`.
// NOTE(review): the first line of the signature (return type and function
// name) is missing from this listing; the name `verifyInitLikeSingleArgRegion`
// is taken from the call sites — confirm against the upstream file.
1690 Operation *op, Region &region, StringRef regionType, StringRef regionName,
1691 Type type, bool verifyYield, bool optional = false) {
1692 if (optional && region.empty())
1693 return success();
1694
1695 if (region.empty())
1696 return op->emitOpError() << "expects non-empty " << regionName << " region";
1697 Block &firstBlock = region.front();
1698 if (firstBlock.getNumArguments() < 1 ||
1699 firstBlock.getArgument(0).getType() != type)
1700 return op->emitOpError() << "expects " << regionName
1701 << " region first "
1702 "argument of the "
1703 << regionType << " type";
1704
1705 if (verifyYield) {
1706 for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
1707 if (yieldOp.getOperands().size() != 1 ||
1708 yieldOp.getOperands().getTypes()[0] != type)
1709 return op->emitOpError() << "expects " << regionName
1710 << " region to "
1711 "yield a value of the "
1712 << regionType << " type";
1713 }
1714 }
1715 return success();
1716}
1717
// Region verifier for PrivateRecipeOp: the init region is checked as a
// mandatory "init" region of the recipe type (yield checking disabled).
// NOTE(review): the line that opens the second check is missing from this
// listing; from the visible arguments it presumably applies the same helper to
// the destroy region with `optional` set — confirm against the upstream file.
1718LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1719 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1720 "privatization", "init", getType(),
1721 /*verifyYield=*/false)))
1722 return failure();
1724 *this, getDestroyRegion(), "privatization", "destroy", getType(),
1725 /*verifyYield=*/false, /*optional=*/true)))
1726 return failure();
1727 return success();
1728}
1729
1730std::optional<PrivateRecipeOp>
1731PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1732 StringRef recipeName, Value hostVar,
1733 StringRef varName, ValueRange bounds) {
1734 Type varType = hostVar.getType();
1735
1736 // First, validate that we can handle this variable type
1737 bool isMappable = isa<MappableType>(varType);
1738 bool isPointerLike = isa<PointerLikeType>(varType);
1739
1740 // Unsupported type
1741 if (!isMappable && !isPointerLike)
1742 return std::nullopt;
1743
1744 OpBuilder::InsertionGuard guard(builder);
1745
1746 // Create the recipe operation first so regions have proper parent context
1747 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1748
1749 // Populate the init region
1750 bool needsFree = false;
1751 acc::VariableInfoAttr varInfo;
1752 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), hostVar,
1753 varName, bounds, needsFree, varInfo))) {
1754 recipe.erase();
1755 return std::nullopt;
1756 }
1757
1758 // Only create destroy region if the allocation needs deallocation
1759 if (needsFree) {
1760 // Extract the allocated value from the init block's yield operation
1761 auto yieldOp =
1762 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1763 Value allocRes = yieldOp.getOperand(0);
1764
1765 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1766 varType, allocRes, bounds, varInfo))) {
1767 recipe.erase();
1768 return std::nullopt;
1769 }
1770 }
1771
1772 return recipe;
1773}
1774
1775std::optional<PrivateRecipeOp>
1776PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1777 StringRef recipeName,
1778 FirstprivateRecipeOp firstprivRecipe) {
1779 // Create the private.recipe op with the same type as the firstprivate.recipe.
1780 OpBuilder::InsertionGuard guard(builder);
1781 auto varType = firstprivRecipe.getType();
1782 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1783
1784 // Clone the init region
1785 IRMapping mapping;
1786 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1787
1788 // Clone destroy region if the firstprivate.recipe has one.
1789 if (!firstprivRecipe.getDestroyRegion().empty()) {
1790 IRMapping mapping;
1791 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1792 mapping);
1793 }
1794 return recipe;
1795}
1796
1797//===----------------------------------------------------------------------===//
1798// FirstprivateRecipeOp
1799//===----------------------------------------------------------------------===//
1800
1801LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1802 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1803 "privatization", "init", getType(),
1804 /*verifyYield=*/false)))
1805 return failure();
1806
1807 if (getCopyRegion().empty())
1808 return emitOpError() << "expects non-empty copy region";
1809
1810 Block &firstBlock = getCopyRegion().front();
1811 if (firstBlock.getNumArguments() < 2 ||
1812 firstBlock.getArgument(0).getType() != getType())
1813 return emitOpError() << "expects copy region with two arguments of the "
1814 "privatization type";
1815
1816 if (getDestroyRegion().empty())
1817 return success();
1818
1819 if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
1820 "privatization", "destroy",
1821 getType(), /*verifyYield=*/false)))
1822 return failure();
1823
1824 return success();
1825}
1826
1827std::optional<FirstprivateRecipeOp>
1828FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1829 StringRef recipeName, Value hostVar,
1830 StringRef varName, ValueRange bounds) {
1831 Type varType = hostVar.getType();
1832
1833 // First, validate that we can handle this variable type
1834 bool isMappable = isa<MappableType>(varType);
1835 bool isPointerLike = isa<PointerLikeType>(varType);
1836
1837 // Unsupported type
1838 if (!isMappable && !isPointerLike)
1839 return std::nullopt;
1840
1841 OpBuilder::InsertionGuard guard(builder);
1842
1843 // Create the recipe operation first so regions have proper parent context
1844 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1845
1846 // Populate the init region
1847 bool needsFree = false;
1848 acc::VariableInfoAttr varInfo;
1849 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), hostVar,
1850 varName, bounds, needsFree, varInfo))) {
1851 recipe.erase();
1852 return std::nullopt;
1853 }
1854
1855 // Populate the copy region
1856 if (failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1857 bounds))) {
1858 recipe.erase();
1859 return std::nullopt;
1860 }
1861
1862 // Only create destroy region if the allocation needs deallocation
1863 if (needsFree) {
1864 // Extract the allocated value from the init block's yield operation
1865 auto yieldOp =
1866 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1867 Value allocRes = yieldOp.getOperand(0);
1868
1869 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1870 varType, allocRes, bounds, varInfo))) {
1871 recipe.erase();
1872 return std::nullopt;
1873 }
1874 }
1875
1876 return recipe;
1877}
1878
1879//===----------------------------------------------------------------------===//
1880// ReductionRecipeOp
1881//===----------------------------------------------------------------------===//
1882
1883LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1884 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
1885 "init", getType(),
1886 /*verifyYield=*/false)))
1887 return failure();
1888
1889 if (getCombinerRegion().empty())
1890 return emitOpError() << "expects non-empty combiner region";
1891
1892 Block &reductionBlock = getCombinerRegion().front();
1893 if (reductionBlock.getNumArguments() < 2 ||
1894 reductionBlock.getArgument(0).getType() != getType() ||
1895 reductionBlock.getArgument(1).getType() != getType())
1896 return emitOpError() << "expects combiner region with the first two "
1897 << "arguments of the reduction type";
1898
1899 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1900 if (yieldOp.getOperands().size() != 1 ||
1901 yieldOp.getOperands().getTypes()[0] != getType())
1902 return emitOpError() << "expects combiner region to yield a value "
1903 "of the reduction type";
1904 }
1905
1906 return success();
1907}
1908
1909//===----------------------------------------------------------------------===//
1910// ParallelOp
1911//===----------------------------------------------------------------------===//
1912
1913/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
1914template <typename Op>
1915static LogicalResult checkDataOperands(Op op,
1916 const mlir::ValueRange &operands) {
1917 for (mlir::Value operand : operands)
1918 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1919 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1920 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1921 operand.getDefiningOp()))
1922 return op.emitError(
1923 "expect data entry/exit operation or acc.getdeviceptr "
1924 "as defining op");
1925 return success();
1926}
1927
1928template <typename OpT, typename RecipeOpT>
1929static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp,
1930 const mlir::ValueRange &operands,
1931 llvm::StringRef operandName) {
1933 for (mlir::Value operand : operands) {
1934 if (!mlir::isa<OpT>(operand.getDefiningOp()))
1935 return accConstructOp->emitOpError()
1936 << "expected " << operandName << " as defining op";
1937 if (!set.insert(operand).second)
1938 return accConstructOp->emitOpError()
1939 << operandName << " operand appears more than once";
1940 }
1941 return success();
1942}
1943
1944unsigned ParallelOp::getNumDataOperands() {
1945 return getReductionOperands().size() + getPrivateOperands().size() +
1946 getFirstprivateOperands().size() + getDataClauseOperands().size();
1947}
1948
1949Value ParallelOp::getDataOperand(unsigned i) {
1950 unsigned numOptional = getAsyncOperands().size();
1951 numOptional += getNumGangs().size();
1952 numOptional += getNumWorkers().size();
1953 numOptional += getVectorLength().size();
1954 numOptional += getIfCond() ? 1 : 0;
1955 numOptional += getSelfCond() ? 1 : 0;
1956 return getOperand(getWaitOperands().size() + numOptional + i);
1957}
1958
1959template <typename Op>
1960static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
1961 ArrayAttr deviceTypes,
1962 llvm::StringRef keyword) {
1963 if (!operands.empty() &&
1964 (!deviceTypes || deviceTypes.getValue().size() != operands.size()))
1965 return op.emitOpError() << keyword << " operands count must match "
1966 << keyword << " device_type count";
1967 return success();
1968}
1969
1970template <typename Op>
1972 Op op, OperandRange operands, DenseI32ArrayAttr segments,
1973 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1974 std::size_t numOperandsInSegments = 0;
1975 std::size_t nbOfSegments = 0;
1976
1977 if (segments) {
1978 for (auto segCount : segments.asArrayRef()) {
1979 if (maxInSegment != 0 && segCount > maxInSegment)
1980 return op.emitOpError() << keyword << " expects a maximum of "
1981 << maxInSegment << " values per segment";
1982 numOperandsInSegments += segCount;
1983 ++nbOfSegments;
1984 }
1985 }
1986
1987 if ((numOperandsInSegments != operands.size()) ||
1988 (!deviceTypes && !operands.empty()))
1989 return op.emitOpError()
1990 << keyword << " operand count does not match count in segments";
1991 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1992 return op.emitOpError()
1993 << keyword << " segment count does not match device_type count";
1994 return success();
1995}
1996
1997LogicalResult acc::ParallelOp::verify() {
1998 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
1999 mlir::acc::PrivateRecipeOp>(
2000 *this, getPrivateOperands(), "private")))
2001 return failure();
2002 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
2003 mlir::acc::FirstprivateRecipeOp>(
2004 *this, getFirstprivateOperands(), "firstprivate")))
2005 return failure();
2006 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
2007 mlir::acc::ReductionRecipeOp>(
2008 *this, getReductionOperands(), "reduction")))
2009 return failure();
2010
2012 *this, getNumGangs(), getNumGangsSegmentsAttr(),
2013 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
2014 return failure();
2015
2017 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2018 getWaitOperandsDeviceTypeAttr(), "wait")))
2019 return failure();
2020
2021 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
2022 getNumWorkersDeviceTypeAttr(),
2023 "num_workers")))
2024 return failure();
2025
2026 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
2027 getVectorLengthDeviceTypeAttr(),
2028 "vector_length")))
2029 return failure();
2030
2032 getAsyncOperandsDeviceTypeAttr(),
2033 "async")))
2034 return failure();
2035
2037 return failure();
2038
2039 return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
2040}
2041
2042static mlir::Value
2043getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
2045 mlir::acc::DeviceType deviceType) {
2046 if (!arrayAttr)
2047 return {};
2048 if (auto pos = findSegment(*arrayAttr, deviceType))
2049 return range[*pos];
2050 return {};
2051}
2052
2053bool acc::ParallelOp::hasAsyncOnly() {
2054 return hasAsyncOnly(mlir::acc::DeviceType::None);
2055}
2056
2057bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2058 return hasDeviceType(getAsyncOnly(), deviceType);
2059}
2060
2061mlir::Value acc::ParallelOp::getAsyncValue() {
2062 return getAsyncValue(mlir::acc::DeviceType::None);
2063}
2064
2065mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2067 getAsyncOperands(), deviceType);
2068}
2069
2070mlir::Value acc::ParallelOp::getNumWorkersValue() {
2071 return getNumWorkersValue(mlir::acc::DeviceType::None);
2072}
2073
2075acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2076 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
2077 deviceType);
2078}
2079
2080mlir::Value acc::ParallelOp::getVectorLengthValue() {
2081 return getVectorLengthValue(mlir::acc::DeviceType::None);
2082}
2083
2085acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2086 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
2087 getVectorLength(), deviceType);
2088}
2089
2090mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
2091 return getNumGangsValues(mlir::acc::DeviceType::None);
2092}
2093
2095ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2096 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2097 getNumGangsSegments(), deviceType);
2098}
2099
2100bool acc::ParallelOp::hasWaitOnly() {
2101 return hasWaitOnly(mlir::acc::DeviceType::None);
2102}
2103
2104bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2105 return hasDeviceType(getWaitOnly(), deviceType);
2106}
2107
2108mlir::Operation::operand_range ParallelOp::getWaitValues() {
2109 return getWaitValues(mlir::acc::DeviceType::None);
2110}
2111
2113ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2115 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2116 getHasWaitDevnum(), deviceType);
2117}
2118
2119mlir::Value ParallelOp::getWaitDevnum() {
2120 return getWaitDevnum(mlir::acc::DeviceType::None);
2121}
2122
2123mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2124 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2125 getWaitOperandsSegments(), getHasWaitDevnum(),
2126 deviceType);
2127}
2128
2129void ParallelOp::build(mlir::OpBuilder &odsBuilder,
2130 mlir::OperationState &odsState,
2131 mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
2132 mlir::ValueRange vectorLength,
2133 mlir::ValueRange asyncOperands,
2134 mlir::ValueRange waitOperands, mlir::Value ifCond,
2135 mlir::Value selfCond, mlir::ValueRange reductionOperands,
2136 mlir::ValueRange gangPrivateOperands,
2137 mlir::ValueRange gangFirstPrivateOperands,
2138 mlir::ValueRange dataClauseOperands) {
2139 ParallelOp::build(
2140 odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
2141 /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
2142 /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
2143 /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
2144 /*numGangsDeviceType=*/nullptr, numWorkers,
2145 /*numWorkersDeviceType=*/nullptr, vectorLength,
2146 /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
2147 /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands,
2148 gangFirstPrivateOperands, dataClauseOperands,
2149 /*defaultAttr=*/nullptr, /*combined=*/nullptr);
2150}
2151
2152void acc::ParallelOp::addNumWorkersOperand(
2153 MLIRContext *context, mlir::Value newValue,
2154 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2155 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2156 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2157 getNumWorkersMutable()));
2158}
2159void acc::ParallelOp::addVectorLengthOperand(
2160 MLIRContext *context, mlir::Value newValue,
2161 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2162 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2163 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2164 getVectorLengthMutable()));
2165}
2166
2167void acc::ParallelOp::addAsyncOnly(
2168 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2169 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2170 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2171}
2172
2173void acc::ParallelOp::addAsyncOperand(
2174 MLIRContext *context, mlir::Value newValue,
2175 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2176 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2177 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2178 getAsyncOperandsMutable()));
2179}
2180
2181void acc::ParallelOp::addNumGangsOperands(
2182 MLIRContext *context, mlir::ValueRange newValues,
2183 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2185 if (getNumGangsSegments())
2186 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2187
2188 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2189 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2190 getNumGangsMutable(), segments));
2191
2192 setNumGangsSegments(segments);
2193}
2194void acc::ParallelOp::addWaitOnly(
2195 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2196 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2197 effectiveDeviceTypes));
2198}
2199void acc::ParallelOp::addWaitOperands(
2200 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2201 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2202
2204 if (getWaitOperandsSegments())
2205 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2206
2207 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2208 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2209 getWaitOperandsMutable(), segments));
2210 setWaitOperandsSegments(segments);
2211
2213 if (getHasWaitDevnumAttr())
2214 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2215 hasDevnums.insert(
2216 hasDevnums.end(),
2217 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2218 mlir::BoolAttr::get(context, hasDevnum));
2219 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2220}
2221
2222void acc::ParallelOp::addPrivatization(MLIRContext *context,
2223 mlir::acc::PrivateOp op,
2224 mlir::acc::PrivateRecipeOp recipe) {
2225 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2226 getPrivateOperandsMutable().append(op.getResult());
2227}
2228
2229void acc::ParallelOp::addFirstPrivatization(
2230 MLIRContext *context, mlir::acc::FirstprivateOp op,
2231 mlir::acc::FirstprivateRecipeOp recipe) {
2232 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2233 getFirstprivateOperandsMutable().append(op.getResult());
2234}
2235
2236void acc::ParallelOp::addReduction(MLIRContext *context,
2237 mlir::acc::ReductionOp op,
2238 mlir::acc::ReductionRecipeOp recipe) {
2239 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2240 getReductionOperandsMutable().append(op.getResult());
2241}
2242
2243static ParseResult parseNumGangs(
2244 mlir::OpAsmParser &parser,
2246 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2247 mlir::DenseI32ArrayAttr &segments) {
2250
2251 do {
2252 if (failed(parser.parseLBrace()))
2253 return failure();
2254
2255 int32_t crtOperandsSize = operands.size();
2256 if (failed(parser.parseCommaSeparatedList(
2258 if (parser.parseOperand(operands.emplace_back()) ||
2259 parser.parseColonType(types.emplace_back()))
2260 return failure();
2261 return success();
2262 })))
2263 return failure();
2264 seg.push_back(operands.size() - crtOperandsSize);
2265
2266 if (failed(parser.parseRBrace()))
2267 return failure();
2268
2269 if (succeeded(parser.parseOptionalLSquare())) {
2270 if (parser.parseAttribute(attributes.emplace_back()) ||
2271 parser.parseRSquare())
2272 return failure();
2273 } else {
2274 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2275 parser.getContext(), mlir::acc::DeviceType::None));
2276 }
2277 } while (succeeded(parser.parseOptionalComma()));
2278
2279 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2280 attributes.end());
2281 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2282 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2283
2284 return success();
2285}
2286
2288 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2289 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2290 p << " [" << attr << "]";
2291}
2292
2294 mlir::OperandRange operands, mlir::TypeRange types,
2295 std::optional<mlir::ArrayAttr> deviceTypes,
2296 std::optional<mlir::DenseI32ArrayAttr> segments) {
2297 unsigned opIdx = 0;
2298 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2299 p << "{";
2300 llvm::interleaveComma(
2301 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2302 p << operands[opIdx] << " : " << operands[opIdx].getType();
2303 ++opIdx;
2304 });
2305 p << "}";
2306 printSingleDeviceType(p, it.value());
2307 });
2308}
2309
2311 mlir::OpAsmParser &parser,
2313 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2314 mlir::DenseI32ArrayAttr &segments) {
2317
2318 do {
2319 if (failed(parser.parseLBrace()))
2320 return failure();
2321
2322 int32_t crtOperandsSize = operands.size();
2323
2324 if (failed(parser.parseCommaSeparatedList(
2326 if (parser.parseOperand(operands.emplace_back()) ||
2327 parser.parseColonType(types.emplace_back()))
2328 return failure();
2329 return success();
2330 })))
2331 return failure();
2332
2333 seg.push_back(operands.size() - crtOperandsSize);
2334
2335 if (failed(parser.parseRBrace()))
2336 return failure();
2337
2338 if (succeeded(parser.parseOptionalLSquare())) {
2339 if (parser.parseAttribute(attributes.emplace_back()) ||
2340 parser.parseRSquare())
2341 return failure();
2342 } else {
2343 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2344 parser.getContext(), mlir::acc::DeviceType::None));
2345 }
2346 } while (succeeded(parser.parseOptionalComma()));
2347
2348 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2349 attributes.end());
2350 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2351 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2352
2353 return success();
2354}
2355
2358 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2359 std::optional<mlir::DenseI32ArrayAttr> segments) {
2360 unsigned opIdx = 0;
2361 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2362 p << "{";
2363 llvm::interleaveComma(
2364 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2365 p << operands[opIdx] << " : " << operands[opIdx].getType();
2366 ++opIdx;
2367 });
2368 p << "}";
2369 printSingleDeviceType(p, it.value());
2370 });
2371}
2372
2373static ParseResult parseWaitClause(
2374 mlir::OpAsmParser &parser,
2376 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2377 mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
2378 mlir::ArrayAttr &keywordOnly) {
2379 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
2381
2382 bool needCommaBeforeOperands = false;
2383
2384 // Keyword only
2385 if (failed(parser.parseOptionalLParen())) {
2386 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2387 parser.getContext(), mlir::acc::DeviceType::None));
2388 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2389 return success();
2390 }
2391
2392 // Parse keyword only attributes
2393 if (succeeded(parser.parseOptionalLSquare())) {
2394 if (failed(parser.parseCommaSeparatedList([&]() {
2395 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2396 return failure();
2397 return success();
2398 })))
2399 return failure();
2400 if (parser.parseRSquare())
2401 return failure();
2402 needCommaBeforeOperands = true;
2403 }
2404
2405 if (needCommaBeforeOperands && failed(parser.parseComma()))
2406 return failure();
2407
2408 do {
2409 if (failed(parser.parseLBrace()))
2410 return failure();
2411
2412 int32_t crtOperandsSize = operands.size();
2413
2414 if (succeeded(parser.parseOptionalKeyword("devnum"))) {
2415 if (failed(parser.parseColon()))
2416 return failure();
2417 devnum.push_back(BoolAttr::get(parser.getContext(), true));
2418 } else {
2419 devnum.push_back(BoolAttr::get(parser.getContext(), false));
2420 }
2421
2422 if (failed(parser.parseCommaSeparatedList(
2424 if (parser.parseOperand(operands.emplace_back()) ||
2425 parser.parseColonType(types.emplace_back()))
2426 return failure();
2427 return success();
2428 })))
2429 return failure();
2430
2431 seg.push_back(operands.size() - crtOperandsSize);
2432
2433 if (failed(parser.parseRBrace()))
2434 return failure();
2435
2436 if (succeeded(parser.parseOptionalLSquare())) {
2437 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2438 parser.parseRSquare())
2439 return failure();
2440 } else {
2441 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2442 parser.getContext(), mlir::acc::DeviceType::None));
2443 }
2444 } while (succeeded(parser.parseOptionalComma()));
2445
2446 if (failed(parser.parseRParen()))
2447 return failure();
2448
2449 deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2450 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2451 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2452 hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
2453
2454 return success();
2455}
2456
2457static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
2458 if (!hasDeviceTypeValues(attrs))
2459 return false;
2460 if (attrs->size() != 1)
2461 return false;
2462 if (auto deviceTypeAttr =
2463 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2464 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2465 return false;
2466}
2467
2469 mlir::OperandRange operands, mlir::TypeRange types,
2470 std::optional<mlir::ArrayAttr> deviceTypes,
2471 std::optional<mlir::DenseI32ArrayAttr> segments,
2472 std::optional<mlir::ArrayAttr> hasDevNum,
2473 std::optional<mlir::ArrayAttr> keywordOnly) {
2474
2475 if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
2476 return;
2477
2478 p << "(";
2479
2480 printDeviceTypes(p, keywordOnly);
2481 if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
2482 p << ", ";
2483
2484 if (hasDeviceTypeValues(deviceTypes)) {
2485 unsigned opIdx = 0;
2486 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2487 p << "{";
2488 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2489 if (boolAttr && boolAttr.getValue())
2490 p << "devnum: ";
2491 llvm::interleaveComma(
2492 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2493 p << operands[opIdx] << " : " << operands[opIdx].getType();
2494 ++opIdx;
2495 });
2496 p << "}";
2497 printSingleDeviceType(p, it.value());
2498 });
2499 }
2500
2501 p << ")";
2502}
2503
2504static ParseResult parseDeviceTypeOperands(
2505 mlir::OpAsmParser &parser,
2507 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
2509 if (failed(parser.parseCommaSeparatedList([&]() {
2510 if (parser.parseOperand(operands.emplace_back()) ||
2511 parser.parseColonType(types.emplace_back()))
2512 return failure();
2513 if (succeeded(parser.parseOptionalLSquare())) {
2514 if (parser.parseAttribute(attributes.emplace_back()) ||
2515 parser.parseRSquare())
2516 return failure();
2517 } else {
2518 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2519 parser.getContext(), mlir::acc::DeviceType::None));
2520 }
2521 return success();
2522 })))
2523 return failure();
2524 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2525 attributes.end());
2526 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2527 return success();
2528}
2529
2530static void
2532 mlir::OperandRange operands, mlir::TypeRange types,
2533 std::optional<mlir::ArrayAttr> deviceTypes) {
2534 if (!hasDeviceTypeValues(deviceTypes))
2535 return;
2536 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
2537 p << std::get<1>(it) << " : " << std::get<1>(it).getType();
2538 printSingleDeviceType(p, std::get<0>(it));
2539 });
2540}
2541
2543 mlir::OpAsmParser &parser,
2545 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2546 mlir::ArrayAttr &keywordOnlyDeviceType) {
2547
2548 llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
2549 bool needCommaBeforeOperands = false;
2550
2551 if (failed(parser.parseOptionalLParen())) {
2552 // Keyword only
2553 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2554 parser.getContext(), mlir::acc::DeviceType::None));
2555 keywordOnlyDeviceType =
2556 ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
2557 return success();
2558 }
2559
2560 // Parse keyword only attributes
2561 if (succeeded(parser.parseOptionalLSquare())) {
2562 // Parse keyword only attributes
2563 if (failed(parser.parseCommaSeparatedList([&]() {
2564 if (parser.parseAttribute(
2565 keywordOnlyDeviceTypeAttributes.emplace_back()))
2566 return failure();
2567 return success();
2568 })))
2569 return failure();
2570 if (parser.parseRSquare())
2571 return failure();
2572 needCommaBeforeOperands = true;
2573 }
2574
2575 if (needCommaBeforeOperands && failed(parser.parseComma()))
2576 return failure();
2577
2579 if (failed(parser.parseCommaSeparatedList([&]() {
2580 if (parser.parseOperand(operands.emplace_back()) ||
2581 parser.parseColonType(types.emplace_back()))
2582 return failure();
2583 if (succeeded(parser.parseOptionalLSquare())) {
2584 if (parser.parseAttribute(attributes.emplace_back()) ||
2585 parser.parseRSquare())
2586 return failure();
2587 } else {
2588 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2589 parser.getContext(), mlir::acc::DeviceType::None));
2590 }
2591 return success();
2592 })))
2593 return failure();
2594
2595 if (failed(parser.parseRParen()))
2596 return failure();
2597
2598 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2599 attributes.end());
2600 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2601 return success();
2602}
2603
2606 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2607 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2608
2609 if (operands.begin() == operands.end() &&
2610 hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
2611 return;
2612 }
2613
2614 p << "(";
2615 printDeviceTypes(p, keywordOnlyDeviceTypes);
2616 if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
2617 hasDeviceTypeValues(deviceTypes))
2618 p << ", ";
2619 printDeviceTypeOperands(p, op, operands, types, deviceTypes);
2620 p << ")";
2621}
2622
2624 mlir::OpAsmParser &parser,
2625 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2626 mlir::Type &operandType, mlir::UnitAttr &attr) {
2627 // Keyword only
2628 if (failed(parser.parseOptionalLParen())) {
2629 attr = mlir::UnitAttr::get(parser.getContext());
2630 return success();
2631 }
2632
2634 if (failed(parser.parseOperand(op)))
2635 return failure();
2636 operand = op;
2637 if (failed(parser.parseColon()))
2638 return failure();
2639 if (failed(parser.parseType(operandType)))
2640 return failure();
2641 if (failed(parser.parseRParen()))
2642 return failure();
2643
2644 return success();
2645}
2646
2648 mlir::Operation *op,
2649 std::optional<mlir::Value> operand,
2650 mlir::Type operandType,
2651 mlir::UnitAttr attr) {
2652 if (attr)
2653 return;
2654
2655 p << "(";
2656 p.printOperand(*operand);
2657 p << " : ";
2658 p.printType(operandType);
2659 p << ")";
2660}
2661
2663 mlir::OpAsmParser &parser,
2665 llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
2666 // Keyword only
2667 if (failed(parser.parseOptionalLParen())) {
2668 attr = mlir::UnitAttr::get(parser.getContext());
2669 return success();
2670 }
2671
2672 if (failed(parser.parseCommaSeparatedList([&]() {
2673 if (parser.parseOperand(operands.emplace_back()))
2674 return failure();
2675 return success();
2676 })))
2677 return failure();
2678 if (failed(parser.parseColon()))
2679 return failure();
2680 if (failed(parser.parseCommaSeparatedList([&]() {
2681 if (parser.parseType(types.emplace_back()))
2682 return failure();
2683 return success();
2684 })))
2685 return failure();
2686 if (failed(parser.parseRParen()))
2687 return failure();
2688
2689 return success();
2690}
2691
2693 mlir::Operation *op,
2694 mlir::OperandRange operands,
2695 mlir::TypeRange types,
2696 mlir::UnitAttr attr) {
2697 if (attr)
2698 return;
2699
2700 p << "(";
2701 llvm::interleaveComma(operands, p, [&](auto it) { p << it; });
2702 p << " : ";
2703 llvm::interleaveComma(types, p, [&](auto it) { p << it; });
2704 p << ")";
2705}
2706
2707static ParseResult
2709 mlir::acc::CombinedConstructsTypeAttr &attr) {
2710 if (succeeded(parser.parseOptionalKeyword("kernels"))) {
2711 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2712 parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2713 } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
2714 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2715 parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2716 } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
2717 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2718 parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2719 } else {
2720 parser.emitError(parser.getCurrentLocation(),
2721 "expected compute construct name");
2722 return failure();
2723 }
2724 return success();
2725}
2726
2727static void
2729 mlir::acc::CombinedConstructsTypeAttr attr) {
2730 if (attr) {
2731 switch (attr.getValue()) {
2732 case mlir::acc::CombinedConstructsType::KernelsLoop:
2733 p << "kernels";
2734 break;
2735 case mlir::acc::CombinedConstructsType::ParallelLoop:
2736 p << "parallel";
2737 break;
2738 case mlir::acc::CombinedConstructsType::SerialLoop:
2739 p << "serial";
2740 break;
2741 };
2742 }
2743}
2744
2745//===----------------------------------------------------------------------===//
2746// SerialOp
2747//===----------------------------------------------------------------------===//
2748
2749unsigned SerialOp::getNumDataOperands() {
2750 return getReductionOperands().size() + getPrivateOperands().size() +
2751 getFirstprivateOperands().size() + getDataClauseOperands().size();
2752}
2753
2754Value SerialOp::getDataOperand(unsigned i) {
2755 unsigned numOptional = getAsyncOperands().size();
2756 numOptional += getIfCond() ? 1 : 0;
2757 numOptional += getSelfCond() ? 1 : 0;
2758 return getOperand(getWaitOperands().size() + numOptional + i);
2759}
2760
2761bool acc::SerialOp::hasAsyncOnly() {
2762 return hasAsyncOnly(mlir::acc::DeviceType::None);
2763}
2764
2765bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2766 return hasDeviceType(getAsyncOnly(), deviceType);
2767}
2768
2769mlir::Value acc::SerialOp::getAsyncValue() {
2770 return getAsyncValue(mlir::acc::DeviceType::None);
2771}
2772
2773mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2775 getAsyncOperands(), deviceType);
2776}
2777
2778bool acc::SerialOp::hasWaitOnly() {
2779 return hasWaitOnly(mlir::acc::DeviceType::None);
2780}
2781
2782bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2783 return hasDeviceType(getWaitOnly(), deviceType);
2784}
2785
2786mlir::Operation::operand_range SerialOp::getWaitValues() {
2787 return getWaitValues(mlir::acc::DeviceType::None);
2788}
2789
2791SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2793 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2794 getHasWaitDevnum(), deviceType);
2795}
2796
2797mlir::Value SerialOp::getWaitDevnum() {
2798 return getWaitDevnum(mlir::acc::DeviceType::None);
2799}
2800
2801mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2802 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2803 getWaitOperandsSegments(), getHasWaitDevnum(),
2804 deviceType);
2805}
2806
2807LogicalResult acc::SerialOp::verify() {
2808 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
2809 mlir::acc::PrivateRecipeOp>(
2810 *this, getPrivateOperands(), "private")))
2811 return failure();
2812 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
2813 mlir::acc::FirstprivateRecipeOp>(
2814 *this, getFirstprivateOperands(), "firstprivate")))
2815 return failure();
2816 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
2817 mlir::acc::ReductionRecipeOp>(
2818 *this, getReductionOperands(), "reduction")))
2819 return failure();
2820
2822 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2823 getWaitOperandsDeviceTypeAttr(), "wait")))
2824 return failure();
2825
2827 getAsyncOperandsDeviceTypeAttr(),
2828 "async")))
2829 return failure();
2830
2832 return failure();
2833
2834 return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
2835}
2836
2837void acc::SerialOp::addAsyncOnly(
2838 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2839 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2840 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2841}
2842
2843void acc::SerialOp::addAsyncOperand(
2844 MLIRContext *context, mlir::Value newValue,
2845 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2846 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2847 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2848 getAsyncOperandsMutable()));
2849}
2850
2851void acc::SerialOp::addWaitOnly(
2852 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2853 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2854 effectiveDeviceTypes));
2855}
2856void acc::SerialOp::addWaitOperands(
2857 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2858 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2859
2861 if (getWaitOperandsSegments())
2862 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2863
2864 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2865 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2866 getWaitOperandsMutable(), segments));
2867 setWaitOperandsSegments(segments);
2868
2870 if (getHasWaitDevnumAttr())
2871 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2872 hasDevnums.insert(
2873 hasDevnums.end(),
2874 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2875 mlir::BoolAttr::get(context, hasDevnum));
2876 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2877}
2878
2879void acc::SerialOp::addPrivatization(MLIRContext *context,
2880 mlir::acc::PrivateOp op,
2881 mlir::acc::PrivateRecipeOp recipe) {
2882 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2883 getPrivateOperandsMutable().append(op.getResult());
2884}
2885
2886void acc::SerialOp::addFirstPrivatization(
2887 MLIRContext *context, mlir::acc::FirstprivateOp op,
2888 mlir::acc::FirstprivateRecipeOp recipe) {
2889 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2890 getFirstprivateOperandsMutable().append(op.getResult());
2891}
2892
2893void acc::SerialOp::addReduction(MLIRContext *context,
2894 mlir::acc::ReductionOp op,
2895 mlir::acc::ReductionRecipeOp recipe) {
2896 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2897 getReductionOperandsMutable().append(op.getResult());
2898}
2899
2900//===----------------------------------------------------------------------===//
2901// KernelsOp
2902//===----------------------------------------------------------------------===//
2903
2904unsigned KernelsOp::getNumDataOperands() {
2905 return getDataClauseOperands().size();
2906}
2907
2908Value KernelsOp::getDataOperand(unsigned i) {
2909 unsigned numOptional = getAsyncOperands().size();
2910 numOptional += getWaitOperands().size();
2911 numOptional += getNumGangs().size();
2912 numOptional += getNumWorkers().size();
2913 numOptional += getVectorLength().size();
2914 numOptional += getIfCond() ? 1 : 0;
2915 numOptional += getSelfCond() ? 1 : 0;
2916 return getOperand(numOptional + i);
2917}
2918
2919bool acc::KernelsOp::hasAsyncOnly() {
2920 return hasAsyncOnly(mlir::acc::DeviceType::None);
2921}
2922
2923bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2924 return hasDeviceType(getAsyncOnly(), deviceType);
2925}
2926
2927mlir::Value acc::KernelsOp::getAsyncValue() {
2928 return getAsyncValue(mlir::acc::DeviceType::None);
2929}
2930
2931mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2933 getAsyncOperands(), deviceType);
2934}
2935
2936mlir::Value acc::KernelsOp::getNumWorkersValue() {
2937 return getNumWorkersValue(mlir::acc::DeviceType::None);
2938}
2939
2941acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2942 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
2943 deviceType);
2944}
2945
2946mlir::Value acc::KernelsOp::getVectorLengthValue() {
2947 return getVectorLengthValue(mlir::acc::DeviceType::None);
2948}
2949
2951acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2952 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
2953 getVectorLength(), deviceType);
2954}
2955
2956mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
2957 return getNumGangsValues(mlir::acc::DeviceType::None);
2958}
2959
2961KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2962 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2963 getNumGangsSegments(), deviceType);
2964}
2965
2966bool acc::KernelsOp::hasWaitOnly() {
2967 return hasWaitOnly(mlir::acc::DeviceType::None);
2968}
2969
2970bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2971 return hasDeviceType(getWaitOnly(), deviceType);
2972}
2973
2974mlir::Operation::operand_range KernelsOp::getWaitValues() {
2975 return getWaitValues(mlir::acc::DeviceType::None);
2976}
2977
2979KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2981 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2982 getHasWaitDevnum(), deviceType);
2983}
2984
2985mlir::Value KernelsOp::getWaitDevnum() {
2986 return getWaitDevnum(mlir::acc::DeviceType::None);
2987}
2988
2989mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2990 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2991 getWaitOperandsSegments(), getHasWaitDevnum(),
2992 deviceType);
2993}
2994
2995LogicalResult acc::KernelsOp::verify() {
2997 *this, getNumGangs(), getNumGangsSegmentsAttr(),
2998 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
2999 return failure();
3000
3002 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
3003 getWaitOperandsDeviceTypeAttr(), "wait")))
3004 return failure();
3005
3006 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
3007 getNumWorkersDeviceTypeAttr(),
3008 "num_workers")))
3009 return failure();
3010
3011 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
3012 getVectorLengthDeviceTypeAttr(),
3013 "vector_length")))
3014 return failure();
3015
3017 getAsyncOperandsDeviceTypeAttr(),
3018 "async")))
3019 return failure();
3020
3022 return failure();
3023
3024 return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
3025}
3026
3027void acc::KernelsOp::addPrivatization(MLIRContext *context,
3028 mlir::acc::PrivateOp op,
3029 mlir::acc::PrivateRecipeOp recipe) {
3030 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3031 getPrivateOperandsMutable().append(op.getResult());
3032}
3033
3034void acc::KernelsOp::addFirstPrivatization(
3035 MLIRContext *context, mlir::acc::FirstprivateOp op,
3036 mlir::acc::FirstprivateRecipeOp recipe) {
3037 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3038 getFirstprivateOperandsMutable().append(op.getResult());
3039}
3040
3041void acc::KernelsOp::addReduction(MLIRContext *context,
3042 mlir::acc::ReductionOp op,
3043 mlir::acc::ReductionRecipeOp recipe) {
3044 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3045 getReductionOperandsMutable().append(op.getResult());
3046}
3047
3048void acc::KernelsOp::addNumWorkersOperand(
3049 MLIRContext *context, mlir::Value newValue,
3050 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3051 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3052 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3053 getNumWorkersMutable()));
3054}
3055
3056void acc::KernelsOp::addVectorLengthOperand(
3057 MLIRContext *context, mlir::Value newValue,
3058 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3059 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3060 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3061 getVectorLengthMutable()));
3062}
3063void acc::KernelsOp::addAsyncOnly(
3064 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3065 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3066 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3067}
3068
3069void acc::KernelsOp::addAsyncOperand(
3070 MLIRContext *context, mlir::Value newValue,
3071 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3072 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3073 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3074 getAsyncOperandsMutable()));
3075}
3076
3077void acc::KernelsOp::addNumGangsOperands(
3078 MLIRContext *context, mlir::ValueRange newValues,
3079 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3081 if (getNumGangsSegmentsAttr())
3082 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
3083
3084 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3085 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3086 getNumGangsMutable(), segments));
3087
3088 setNumGangsSegments(segments);
3089}
3090
3091void acc::KernelsOp::addWaitOnly(
3092 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3093 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3094 effectiveDeviceTypes));
3095}
3096void acc::KernelsOp::addWaitOperands(
3097 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3098 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3099
3101 if (getWaitOperandsSegments())
3102 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3103
3104 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3105 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3106 getWaitOperandsMutable(), segments));
3107 setWaitOperandsSegments(segments);
3108
3110 if (getHasWaitDevnumAttr())
3111 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3112 hasDevnums.insert(
3113 hasDevnums.end(),
3114 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
3115 mlir::BoolAttr::get(context, hasDevnum));
3116 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3117}
3118
3119//===----------------------------------------------------------------------===//
3120// HostDataOp
3121//===----------------------------------------------------------------------===//
3122
3123LogicalResult acc::HostDataOp::verify() {
3124 if (getDataClauseOperands().empty())
3125 return emitError("at least one operand must appear on the host_data "
3126 "operation");
3127
3129 for (mlir::Value operand : getDataClauseOperands()) {
3130 auto useDeviceOp =
3131 mlir::dyn_cast<acc::UseDeviceOp>(operand.getDefiningOp());
3132 if (!useDeviceOp)
3133 return emitError("expect data entry operation as defining op");
3134
3135 // Check for duplicate use_device clauses
3136 if (!seenVars.insert(useDeviceOp.getVar()).second)
3137 return emitError("duplicate use_device variable");
3138 }
3139 return success();
3140}
3141
3142void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3143 MLIRContext *context) {
3144 results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
3145}
3146
3147//===----------------------------------------------------------------------===//
3148// LoopOp
3149//===----------------------------------------------------------------------===//
3150
3151static ParseResult parseGangValue(
3152 OpAsmParser &parser, llvm::StringRef keyword,
3155 llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
3156 bool &needCommaBetweenValues, bool &newValue) {
3157 if (succeeded(parser.parseOptionalKeyword(keyword))) {
3158 if (parser.parseEqual())
3159 return failure();
3160 if (parser.parseOperand(operands.emplace_back()) ||
3161 parser.parseColonType(types.emplace_back()))
3162 return failure();
3163 attributes.push_back(gangArgType);
3164 needCommaBetweenValues = true;
3165 newValue = true;
3166 }
3167 return success();
3168}
3169
3170static ParseResult parseGangClause(
3171 OpAsmParser &parser,
3173 llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
3174 mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
3175 mlir::ArrayAttr &gangOnlyDeviceType) {
3176 llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
3177 llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
3178 llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
3180 bool needCommaBetweenValues = false;
3181 bool needCommaBeforeOperands = false;
3182
3183 if (failed(parser.parseOptionalLParen())) {
3184 // Gang only keyword
3185 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3186 parser.getContext(), mlir::acc::DeviceType::None));
3187 gangOnlyDeviceType =
3188 ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
3189 return success();
3190 }
3191
3192 // Parse gang only attributes
3193 if (succeeded(parser.parseOptionalLSquare())) {
3194 // Parse gang only attributes
3195 if (failed(parser.parseCommaSeparatedList([&]() {
3196 if (parser.parseAttribute(
3197 gangOnlyDeviceTypeAttributes.emplace_back()))
3198 return failure();
3199 return success();
3200 })))
3201 return failure();
3202 if (parser.parseRSquare())
3203 return failure();
3204 needCommaBeforeOperands = true;
3205 }
3206
3207 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
3208 mlir::acc::GangArgType::Num);
3209 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
3210 mlir::acc::GangArgType::Dim);
3211 auto argStatic = mlir::acc::GangArgTypeAttr::get(
3212 parser.getContext(), mlir::acc::GangArgType::Static);
3213
3214 do {
3215 if (needCommaBeforeOperands) {
3216 needCommaBeforeOperands = false;
3217 continue;
3218 }
3219
3220 if (failed(parser.parseLBrace()))
3221 return failure();
3222
3223 int32_t crtOperandsSize = gangOperands.size();
3224 while (true) {
3225 bool newValue = false;
3226 bool needValue = false;
3227 if (needCommaBetweenValues) {
3228 if (succeeded(parser.parseOptionalComma()))
3229 needValue = true; // expect a new value after comma.
3230 else
3231 break;
3232 }
3233
3234 if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
3235 gangOperands, gangOperandsType,
3236 gangArgTypeAttributes, argNum,
3237 needCommaBetweenValues, newValue)))
3238 return failure();
3239 if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
3240 gangOperands, gangOperandsType,
3241 gangArgTypeAttributes, argDim,
3242 needCommaBetweenValues, newValue)))
3243 return failure();
3244 if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
3245 gangOperands, gangOperandsType,
3246 gangArgTypeAttributes, argStatic,
3247 needCommaBetweenValues, newValue)))
3248 return failure();
3249
3250 if (!newValue && needValue) {
3251 parser.emitError(parser.getCurrentLocation(),
3252 "new value expected after comma");
3253 return failure();
3254 }
3255
3256 if (!newValue)
3257 break;
3258 }
3259
3260 if (gangOperands.empty())
3261 return parser.emitError(
3262 parser.getCurrentLocation(),
3263 "expect at least one of num, dim or static values");
3264
3265 if (failed(parser.parseRBrace()))
3266 return failure();
3267
3268 if (succeeded(parser.parseOptionalLSquare())) {
3269 if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
3270 parser.parseRSquare())
3271 return failure();
3272 } else {
3273 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3274 parser.getContext(), mlir::acc::DeviceType::None));
3275 }
3276
3277 seg.push_back(gangOperands.size() - crtOperandsSize);
3278
3279 } while (succeeded(parser.parseOptionalComma()));
3280
3281 if (failed(parser.parseRParen()))
3282 return failure();
3283
3284 llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
3285 gangArgTypeAttributes.end());
3286 gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
3287 deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
3288
3290 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3291 gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
3292
3293 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
3294 return success();
3295}
3296
3298 mlir::OperandRange operands, mlir::TypeRange types,
3299 std::optional<mlir::ArrayAttr> gangArgTypes,
3300 std::optional<mlir::ArrayAttr> deviceTypes,
3301 std::optional<mlir::DenseI32ArrayAttr> segments,
3302 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3303
3304 if (operands.begin() == operands.end() &&
3305 hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
3306 return;
3307 }
3308
3309 p << "(";
3310
3311 printDeviceTypes(p, gangOnlyDeviceTypes);
3312
3313 if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
3314 hasDeviceTypeValues(deviceTypes))
3315 p << ", ";
3316
3317 if (hasDeviceTypeValues(deviceTypes)) {
3318 unsigned opIdx = 0;
3319 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
3320 p << "{";
3321 llvm::interleaveComma(
3322 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
3323 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3324 (*gangArgTypes)[opIdx]);
3325 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3326 p << LoopOp::getGangNumKeyword();
3327 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3328 p << LoopOp::getGangDimKeyword();
3329 else if (gangArgTypeAttr.getValue() ==
3330 mlir::acc::GangArgType::Static)
3331 p << LoopOp::getGangStaticKeyword();
3332 p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
3333 ++opIdx;
3334 });
3335 p << "}";
3336 printSingleDeviceType(p, it.value());
3337 });
3338 }
3339 p << ")";
3340}
3341
3343 std::optional<mlir::ArrayAttr> segments,
3344 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3345 if (!segments)
3346 return false;
3347 for (auto attr : *segments) {
3348 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3349 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3350 return true;
3351 }
3352 return false;
3353}
3354
3355/// Check for duplicates in the DeviceType array attribute.
3356/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found.
3357static std::optional<mlir::acc::DeviceType>
3358checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
3359 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3360 if (!deviceTypes)
3361 return std::nullopt;
3362 for (auto attr : deviceTypes) {
3363 auto deviceTypeAttr =
3364 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3365 if (!deviceTypeAttr)
3366 return mlir::acc::DeviceType::None;
3367 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3368 return deviceTypeAttr.getValue();
3369 }
3370 return std::nullopt;
3371}
3372
3373LogicalResult acc::LoopOp::verify() {
3374 if (getUpperbound().size() != getStep().size())
3375 return emitError() << "number of upperbounds expected to be the same as "
3376 "number of steps";
3377
3378 if (getUpperbound().size() != getLowerbound().size())
3379 return emitError() << "number of upperbounds expected to be the same as "
3380 "number of lowerbounds";
3381
3382 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3383 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3384 return emitError() << "inclusiveUpperbound size is expected to be the same"
3385 << " as upperbound size";
3386
3387 // Check collapse
3388 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3389 return emitOpError() << "collapse device_type attr must be define when"
3390 << " collapse attr is present";
3391
3392 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3393 getCollapseAttr().getValue().size() !=
3394 getCollapseDeviceTypeAttr().getValue().size())
3395 return emitOpError() << "collapse attribute count must match collapse"
3396 << " device_type count";
3397 if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr()))
3398 return emitOpError() << "duplicate device_type `"
3399 << acc::stringifyDeviceType(*duplicateDeviceType)
3400 << "` found in collapseDeviceType attribute";
3401
3402 // Check gang
3403 if (!getGangOperands().empty()) {
3404 if (!getGangOperandsArgType())
3405 return emitOpError() << "gangOperandsArgType attribute must be defined"
3406 << " when gang operands are present";
3407
3408 if (getGangOperands().size() !=
3409 getGangOperandsArgTypeAttr().getValue().size())
3410 return emitOpError() << "gangOperandsArgType attribute count must match"
3411 << " gangOperands count";
3412 }
3413 if (getGangAttr()) {
3414 if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr()))
3415 return emitOpError() << "duplicate device_type `"
3416 << acc::stringifyDeviceType(*duplicateDeviceType)
3417 << "` found in gang attribute";
3418 }
3419
3421 *this, getGangOperands(), getGangOperandsSegmentsAttr(),
3422 getGangOperandsDeviceTypeAttr(), "gang")))
3423 return failure();
3424
3425 // Check worker
3426 if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr()))
3427 return emitOpError() << "duplicate device_type `"
3428 << acc::stringifyDeviceType(*duplicateDeviceType)
3429 << "` found in worker attribute";
3430 if (auto duplicateDeviceType =
3431 checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))
3432 return emitOpError() << "duplicate device_type `"
3433 << acc::stringifyDeviceType(*duplicateDeviceType)
3434 << "` found in workerNumOperandsDeviceType attribute";
3435 if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
3436 getWorkerNumOperandsDeviceTypeAttr(),
3437 "worker")))
3438 return failure();
3439
3440 // Check vector
3441 if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr()))
3442 return emitOpError() << "duplicate device_type `"
3443 << acc::stringifyDeviceType(*duplicateDeviceType)
3444 << "` found in vector attribute";
3445 if (auto duplicateDeviceType =
3446 checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))
3447 return emitOpError() << "duplicate device_type `"
3448 << acc::stringifyDeviceType(*duplicateDeviceType)
3449 << "` found in vectorOperandsDeviceType attribute";
3450 if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
3451 getVectorOperandsDeviceTypeAttr(),
3452 "vector")))
3453 return failure();
3454
3456 *this, getTileOperands(), getTileOperandsSegmentsAttr(),
3457 getTileOperandsDeviceTypeAttr(), "tile")))
3458 return failure();
3459
3460 // auto, independent and seq attribute are mutually exclusive.
3461 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3462 if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
3463 hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
3464 hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
3465 return emitError() << "only one of auto, independent, seq can be present "
3466 "at the same time";
3467 }
3468
3469 // Check that at least one of auto, independent, or seq is present
3470 // for the device-independent default clauses.
3471 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
3472 return attr.getValue() == mlir::acc::DeviceType::None;
3473 };
3474 bool hasDefaultSeq =
3475 getSeqAttr()
3476 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3477 hasDeviceNone)
3478 : false;
3479 bool hasDefaultIndependent =
3480 getIndependentAttr()
3481 ? llvm::any_of(
3482 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3483 hasDeviceNone)
3484 : false;
3485 bool hasDefaultAuto =
3486 getAuto_Attr()
3487 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3488 hasDeviceNone)
3489 : false;
3490 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3491 return emitError()
3492 << "at least one of auto, independent, seq must be present";
3493 }
3494
3495 // Gang, worker and vector are incompatible with seq.
3496 if (getSeqAttr()) {
3497 for (auto attr : getSeqAttr()) {
3498 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3499 if (hasVector(deviceTypeAttr.getValue()) ||
3500 getVectorValue(deviceTypeAttr.getValue()) ||
3501 hasWorker(deviceTypeAttr.getValue()) ||
3502 getWorkerValue(deviceTypeAttr.getValue()) ||
3503 hasGang(deviceTypeAttr.getValue()) ||
3504 getGangValue(mlir::acc::GangArgType::Num,
3505 deviceTypeAttr.getValue()) ||
3506 getGangValue(mlir::acc::GangArgType::Dim,
3507 deviceTypeAttr.getValue()) ||
3508 getGangValue(mlir::acc::GangArgType::Static,
3509 deviceTypeAttr.getValue()))
3510 return emitError() << "gang, worker or vector cannot appear with seq";
3511 }
3512 }
3513
3514 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
3515 mlir::acc::PrivateRecipeOp>(
3516 *this, getPrivateOperands(), "private")))
3517 return failure();
3518
3519 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
3520 mlir::acc::FirstprivateRecipeOp>(
3521 *this, getFirstprivateOperands(), "firstprivate")))
3522 return failure();
3523
3524 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
3525 mlir::acc::ReductionRecipeOp>(
3526 *this, getReductionOperands(), "reduction")))
3527 return failure();
3528
3529 if (getCombined().has_value() &&
3530 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3531 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3532 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3533 return emitError("unexpected combined constructs attribute");
3534 }
3535
3536 // Check non-empty body().
3537 if (getRegion().empty())
3538 return emitError("expected non-empty body.");
3539
3540 if (getUnstructured()) {
3541 if (!isContainerLike())
3542 return emitError(
3543 "unstructured acc.loop must not have induction variables");
3544 } else if (isContainerLike()) {
3545 // When it is container-like - it is expected to hold a loop-like operation.
3546 // Obtain the maximum collapse count - we use this to check that there
3547 // are enough loops contained.
3548 uint64_t collapseCount = getCollapseValue().value_or(1);
3549 if (getCollapseAttr()) {
3550 for (auto collapseEntry : getCollapseAttr()) {
3551 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3552 if (intAttr.getValue().getZExtValue() > collapseCount)
3553 collapseCount = intAttr.getValue().getZExtValue();
3554 }
3555 }
3556
3557 // We want to check that we find enough loop-like operations inside.
3558 // PreOrder walk allows us to walk in a breadth-first manner at each nesting
3559 // level.
3560 mlir::Operation *expectedParent = this->getOperation();
3561 bool foundSibling = false;
3562 getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
3563 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3564 // This effectively checks that we are not looking at a sibling loop.
3565 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3566 expectedParent) {
3567 foundSibling = true;
3569 }
3570
3571 collapseCount--;
3572 expectedParent = op;
3573 }
3574 // We found enough contained loops.
3575 if (collapseCount == 0)
3578 });
3579
3580 if (foundSibling)
3581 return emitError("found sibling loops inside container-like acc.loop");
3582 if (collapseCount != 0)
3583 return emitError("failed to find enough loop-like operations inside "
3584 "container-like acc.loop");
3585 }
3586
3587 return success();
3588}
3589
3590unsigned LoopOp::getNumDataOperands() {
3591 return getReductionOperands().size() + getPrivateOperands().size() +
3592 getFirstprivateOperands().size();
3593}
3594
3595Value LoopOp::getDataOperand(unsigned i) {
3596 unsigned numOptional =
3597 getLowerbound().size() + getUpperbound().size() + getStep().size();
3598 numOptional += getGangOperands().size();
3599 numOptional += getVectorOperands().size();
3600 numOptional += getWorkerNumOperands().size();
3601 numOptional += getTileOperands().size();
3602 numOptional += getCacheOperands().size();
3603 return getOperand(numOptional + i);
3604}
3605
3606bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
3607
3608bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3609 return hasDeviceType(getAuto_(), deviceType);
3610}
3611
3612bool LoopOp::hasIndependent() {
3613 return hasIndependent(mlir::acc::DeviceType::None);
3614}
3615
3616bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3617 return hasDeviceType(getIndependent(), deviceType);
3618}
3619
3620bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
3621
3622bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3623 return hasDeviceType(getSeq(), deviceType);
3624}
3625
3626mlir::Value LoopOp::getVectorValue() {
3627 return getVectorValue(mlir::acc::DeviceType::None);
3628}
3629
3630mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3631 return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
3632 getVectorOperands(), deviceType);
3633}
3634
3635bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
3636
3637bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3638 return hasDeviceType(getVector(), deviceType);
3639}
3640
3641mlir::Value LoopOp::getWorkerValue() {
3642 return getWorkerValue(mlir::acc::DeviceType::None);
3643}
3644
3645mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3646 return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
3647 getWorkerNumOperands(), deviceType);
3648}
3649
3650bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
3651
3652bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3653 return hasDeviceType(getWorker(), deviceType);
3654}
3655
3656mlir::Operation::operand_range LoopOp::getTileValues() {
3657 return getTileValues(mlir::acc::DeviceType::None);
3658}
3659
3661LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3662 return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
3663 getTileOperandsSegments(), deviceType);
3664}
3665
3666std::optional<int64_t> LoopOp::getCollapseValue() {
3667 return getCollapseValue(mlir::acc::DeviceType::None);
3668}
3669
3670std::optional<int64_t>
3671LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3672 if (!getCollapseAttr())
3673 return std::nullopt;
3674 if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3675 auto intAttr =
3676 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3677 return intAttr.getValue().getZExtValue();
3678 }
3679 return std::nullopt;
3680}
3681
3682mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3683 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3684}
3685
3686mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3687 mlir::acc::DeviceType deviceType) {
3688 if (getGangOperands().empty())
3689 return {};
3690 if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
3691 int32_t nbOperandsBefore = 0;
3692 for (unsigned i = 0; i < *pos; ++i)
3693 nbOperandsBefore += (*getGangOperandsSegments())[i];
3695 getGangOperands()
3696 .drop_front(nbOperandsBefore)
3697 .take_front((*getGangOperandsSegments())[*pos]);
3698
3699 int32_t argTypeIdx = nbOperandsBefore;
3700 for (auto value : values) {
3701 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3702 (*getGangOperandsArgType())[argTypeIdx]);
3703 if (gangArgTypeAttr.getValue() == gangArgType)
3704 return value;
3705 ++argTypeIdx;
3706 }
3707 }
3708 return {};
3709}
3710
3711bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
3712
3713bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3714 return hasDeviceType(getGang(), deviceType);
3715}
3716
3717llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
3718 return {&getRegion()};
3719}
3720
3721/// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
3722/// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
3723/// `(` ssa-id-and-type-list `)`
3724/// region
3725ParseResult
3728 SmallVectorImpl<Type> &lowerboundType,
3730 SmallVectorImpl<Type> &upperboundType,
3732 SmallVectorImpl<Type> &stepType) {
3733
3735 if (succeeded(
3736 parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
3737 if (parser.parseLParen() ||
3738 parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
3739 /*allowType=*/true) ||
3740 parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
3741 parser.parseOperandList(lowerbound, inductionVars.size(),
3743 parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
3744 parser.parseKeyword("to") || parser.parseLParen() ||
3745 parser.parseOperandList(upperbound, inductionVars.size(),
3747 parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
3748 parser.parseKeyword("step") || parser.parseLParen() ||
3749 parser.parseOperandList(step, inductionVars.size(),
3751 parser.parseColonTypeList(stepType) || parser.parseRParen())
3752 return failure();
3753 }
3754 return parser.parseRegion(region, inductionVars);
3755}
3756
3758 ValueRange lowerbound, TypeRange lowerboundType,
3759 ValueRange upperbound, TypeRange upperboundType,
3760 ValueRange steps, TypeRange stepType) {
3761 ValueRange regionArgs = region.front().getArguments();
3762 if (!regionArgs.empty()) {
3763 p << acc::LoopOp::getControlKeyword() << "(";
3764 llvm::interleaveComma(regionArgs, p,
3765 [&p](Value v) { p << v << " : " << v.getType(); });
3766 p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
3767 << upperbound << " : " << upperboundType << ") " << " step (" << steps
3768 << " : " << stepType << ") ";
3769 }
3770 p.printRegion(region, /*printEntryBlockArgs=*/false);
3771}
3772
3773void acc::LoopOp::addSeq(MLIRContext *context,
3774 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3775 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3776 effectiveDeviceTypes));
3777}
3778
3779void acc::LoopOp::addIndependent(
3780 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3781 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3782 context, getIndependentAttr(), effectiveDeviceTypes));
3783}
3784
3785void acc::LoopOp::addAuto(MLIRContext *context,
3786 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3787 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3788 effectiveDeviceTypes));
3789}
3790
3791void acc::LoopOp::setCollapseForDeviceTypes(
3792 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3793 llvm::APInt value) {
3796
3797 assert((getCollapseAttr() == nullptr) ==
3798 (getCollapseDeviceTypeAttr() == nullptr));
3799 assert(value.getBitWidth() == 64);
3800
3801 if (getCollapseAttr()) {
3802 for (const auto &existing :
3803 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3804 newValues.push_back(std::get<0>(existing));
3805 newDeviceTypes.push_back(std::get<1>(existing));
3806 }
3807 }
3808
3809 if (effectiveDeviceTypes.empty()) {
3810 // If the effective device-types list is empty, this is before there are any
3811 // being applied by device_type, so this should be added as a 'none'.
3812 newValues.push_back(
3813 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3814 newDeviceTypes.push_back(
3815 acc::DeviceTypeAttr::get(context, DeviceType::None));
3816 } else {
3817 for (DeviceType dt : effectiveDeviceTypes) {
3818 newValues.push_back(
3819 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3820 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3821 }
3822 }
3823
3824 setCollapseAttr(ArrayAttr::get(context, newValues));
3825 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3826}
3827
3828void acc::LoopOp::setTileForDeviceTypes(
3829 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3830 ValueRange values) {
3832 if (getTileOperandsSegments())
3833 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3834
3835 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3836 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3837 getTileOperandsMutable(), segments));
3838
3839 setTileOperandsSegments(segments);
3840}
3841
3842void acc::LoopOp::addVectorOperand(
3843 MLIRContext *context, mlir::Value newValue,
3844 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3845 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3846 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3847 newValue, getVectorOperandsMutable()));
3848}
3849
3850void acc::LoopOp::addEmptyVector(
3851 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3852 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3853 effectiveDeviceTypes));
3854}
3855
3856void acc::LoopOp::addWorkerNumOperand(
3857 MLIRContext *context, mlir::Value newValue,
3858 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3859 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3860 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3861 newValue, getWorkerNumOperandsMutable()));
3862}
3863
3864void acc::LoopOp::addEmptyWorker(
3865 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3866 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3867 effectiveDeviceTypes));
3868}
3869
3870void acc::LoopOp::addEmptyGang(
3871 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3872 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3873 effectiveDeviceTypes));
3874}
3875
3876bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3877 auto hasDevice = [=](DeviceTypeAttr attr) -> bool {
3878 return attr.getValue() == dt;
3879 };
3880 auto testFromArr = [=](ArrayAttr arr) -> bool {
3881 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3882 };
3883
3884 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3885 return true;
3886 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3887 return true;
3888 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3889 return true;
3890
3891 return false;
3892}
3893
3894bool acc::LoopOp::hasDefaultGangWorkerVector() {
3895 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3896 hasGang() || getGangValue(GangArgType::Num) ||
3897 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3898}
3899
3900acc::LoopParMode
3901acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3902 if (hasSeq(deviceType))
3903 return LoopParMode::loop_seq;
3904 if (hasAuto(deviceType))
3905 return LoopParMode::loop_auto;
3906 if (hasIndependent(deviceType))
3907 return LoopParMode::loop_independent;
3908 if (hasSeq())
3909 return LoopParMode::loop_seq;
3910 if (hasAuto())
3911 return LoopParMode::loop_auto;
3912 assert(hasIndependent() &&
3913 "loop must have default auto, seq, or independent");
3914 return LoopParMode::loop_independent;
3915}
3916
3917void acc::LoopOp::addGangOperands(
3918 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3921 if (std::optional<ArrayRef<int32_t>> existingSegments =
3922 getGangOperandsSegments())
3923 llvm::copy(*existingSegments, std::back_inserter(segments));
3924
3925 unsigned beforeCount = segments.size();
3926
3927 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3928 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3929 getGangOperandsMutable(), segments));
3930
3931 setGangOperandsSegments(segments);
3932
3933 // This is a bit of extra work to make sure we update the 'types' correctly by
3934 // adding to the types collection the correct number of times. We could
3935 // potentially add something similar to the
3936 // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
3937 // excessive for a one-off case.
3938 unsigned numAdded = segments.size() - beforeCount;
3939
3940 if (numAdded > 0) {
3942 if (getGangOperandsArgTypeAttr())
3943 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3944
3945 for (auto i : llvm::index_range(0u, numAdded)) {
3946 llvm::transform(argTypes, std::back_inserter(gangTypes),
3947 [=](mlir::acc::GangArgType gangTy) {
3948 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3949 });
3950 (void)i;
3951 }
3952
3953 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3954 }
3955}
3956
3957void acc::LoopOp::addPrivatization(MLIRContext *context,
3958 mlir::acc::PrivateOp op,
3959 mlir::acc::PrivateRecipeOp recipe) {
3960 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3961 getPrivateOperandsMutable().append(op.getResult());
3962}
3963
3964void acc::LoopOp::addFirstPrivatization(
3965 MLIRContext *context, mlir::acc::FirstprivateOp op,
3966 mlir::acc::FirstprivateRecipeOp recipe) {
3967 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3968 getFirstprivateOperandsMutable().append(op.getResult());
3969}
3970
3971void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
3972 mlir::acc::ReductionRecipeOp recipe) {
3973 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3974 getReductionOperandsMutable().append(op.getResult());
3975}
3976
3977//===----------------------------------------------------------------------===//
3978// DataOp
3979//===----------------------------------------------------------------------===//
3980
3981LogicalResult acc::DataOp::verify() {
3982 // 2.6.5. Data Construct restriction
3983 // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
3984 // attach, or default clause must appear on a data construct.
3985 if (getOperands().empty() && !getDefaultAttr())
3986 return emitError("at least one operand or the default attribute "
3987 "must appear on the data operation");
3988
3989 for (mlir::Value operand : getDataClauseOperands())
3990 if (isa<BlockArgument>(operand) ||
3991 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3992 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3993 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3994 operand.getDefiningOp()))
3995 return emitError("expect data entry/exit operation or acc.getdeviceptr "
3996 "as defining op");
3997
3999 return failure();
4000
4001 return success();
4002}
4003
4004unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
4005
4006Value DataOp::getDataOperand(unsigned i) {
4007 unsigned numOptional = getIfCond() ? 1 : 0;
4008 numOptional += getAsyncOperands().size() ? 1 : 0;
4009 numOptional += getWaitOperands().size();
4010 return getOperand(numOptional + i);
4011}
4012
4013bool acc::DataOp::hasAsyncOnly() {
4014 return hasAsyncOnly(mlir::acc::DeviceType::None);
4015}
4016
4017bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4018 return hasDeviceType(getAsyncOnly(), deviceType);
4019}
4020
4021mlir::Value DataOp::getAsyncValue() {
4022 return getAsyncValue(mlir::acc::DeviceType::None);
4023}
4024
4025mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4027 getAsyncOperands(), deviceType);
4028}
4029
4030bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
4031
4032bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4033 return hasDeviceType(getWaitOnly(), deviceType);
4034}
4035
4036mlir::Operation::operand_range DataOp::getWaitValues() {
4037 return getWaitValues(mlir::acc::DeviceType::None);
4038}
4039
4041DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4043 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4044 getHasWaitDevnum(), deviceType);
4045}
4046
4047mlir::Value DataOp::getWaitDevnum() {
4048 return getWaitDevnum(mlir::acc::DeviceType::None);
4049}
4050
4051mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4052 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
4053 getWaitOperandsSegments(), getHasWaitDevnum(),
4054 deviceType);
4055}
4056
4057void acc::DataOp::addAsyncOnly(
4058 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4059 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4060 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4061}
4062
4063void acc::DataOp::addAsyncOperand(
4064 MLIRContext *context, mlir::Value newValue,
4065 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4066 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4067 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4068 getAsyncOperandsMutable()));
4069}
4070
4071void acc::DataOp::addWaitOnly(MLIRContext *context,
4072 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4073 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4074 effectiveDeviceTypes));
4075}
4076
4077void acc::DataOp::addWaitOperands(
4078 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4079 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4080
4082 if (getWaitOperandsSegments())
4083 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4084
4085 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4086 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4087 getWaitOperandsMutable(), segments));
4088 setWaitOperandsSegments(segments);
4089
4091 if (getHasWaitDevnumAttr())
4092 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4093 hasDevnums.insert(
4094 hasDevnums.end(),
4095 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
4096 mlir::BoolAttr::get(context, hasDevnum));
4097 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4098}
4099
4100//===----------------------------------------------------------------------===//
4101// ExitDataOp
4102//===----------------------------------------------------------------------===//
4103
4104LogicalResult acc::ExitDataOp::verify() {
4105 // 2.6.6. Data Exit Directive restriction
4106 // At least one copyout, delete, or detach clause must appear on an exit data
4107 // directive.
4108 if (getDataClauseOperands().empty())
4109 return emitError("at least one operand must be present in dataOperands on "
4110 "the exit data operation");
4111
4112 // The async attribute represent the async clause without value. Therefore the
4113 // attribute and operand cannot appear at the same time.
4114 if (getAsyncOperand() && getAsync())
4115 return emitError("async attribute cannot appear with asyncOperand");
4116
4117 // The wait attribute represent the wait clause without values. Therefore the
4118 // attribute and operands cannot appear at the same time.
4119 if (!getWaitOperands().empty() && getWait())
4120 return emitError("wait attribute cannot appear with waitOperands");
4121
4122 if (getWaitDevnum() && getWaitOperands().empty())
4123 return emitError("wait_devnum cannot appear without waitOperands");
4124
4125 return success();
4126}
4127
4128unsigned ExitDataOp::getNumDataOperands() {
4129 return getDataClauseOperands().size();
4130}
4131
4132Value ExitDataOp::getDataOperand(unsigned i) {
4133 unsigned numOptional = getIfCond() ? 1 : 0;
4134 numOptional += getAsyncOperand() ? 1 : 0;
4135 numOptional += getWaitDevnum() ? 1 : 0;
4136 return getOperand(getWaitOperands().size() + numOptional + i);
4137}
4138
4139void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
4140 MLIRContext *context) {
4141 results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
4142}
4143
4144void ExitDataOp::addAsyncOnly(MLIRContext *context,
4145 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4146 assert(effectiveDeviceTypes.empty());
4147 assert(!getAsyncAttr());
4148 assert(!getAsyncOperand());
4149
4150 setAsyncAttr(mlir::UnitAttr::get(context));
4151}
4152
4153void ExitDataOp::addAsyncOperand(
4154 MLIRContext *context, mlir::Value newValue,
4155 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4156 assert(effectiveDeviceTypes.empty());
4157 assert(!getAsyncAttr());
4158 assert(!getAsyncOperand());
4159
4160 getAsyncOperandMutable().append(newValue);
4161}
4162
4163void ExitDataOp::addWaitOnly(MLIRContext *context,
4164 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4165 assert(effectiveDeviceTypes.empty());
4166 assert(!getWaitAttr());
4167 assert(getWaitOperands().empty());
4168 assert(!getWaitDevnum());
4169
4170 setWaitAttr(mlir::UnitAttr::get(context));
4171}
4172
4173void ExitDataOp::addWaitOperands(
4174 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4175 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4176 assert(effectiveDeviceTypes.empty());
4177 assert(!getWaitAttr());
4178 assert(getWaitOperands().empty());
4179 assert(!getWaitDevnum());
4180
4181 // if hasDevnum, the first value is the devnum. The 'rest' go into the
4182 // operands list.
4183 if (hasDevnum) {
4184 getWaitDevnumMutable().append(newValues.front());
4185 newValues = newValues.drop_front();
4186 }
4187
4188 getWaitOperandsMutable().append(newValues);
4189}
4190
4191//===----------------------------------------------------------------------===//
4192// EnterDataOp
4193//===----------------------------------------------------------------------===//
4194
4195LogicalResult acc::EnterDataOp::verify() {
4196 // 2.6.6. Data Enter Directive restriction
4197 // At least one copyin, create, or attach clause must appear on an enter data
4198 // directive.
4199 if (getDataClauseOperands().empty())
4200 return emitError("at least one operand must be present in dataOperands on "
4201 "the enter data operation");
4202
4203 // The async attribute represent the async clause without value. Therefore the
4204 // attribute and operand cannot appear at the same time.
4205 if (getAsyncOperand() && getAsync())
4206 return emitError("async attribute cannot appear with asyncOperand");
4207
4208 // The wait attribute represent the wait clause without values. Therefore the
4209 // attribute and operands cannot appear at the same time.
4210 if (!getWaitOperands().empty() && getWait())
4211 return emitError("wait attribute cannot appear with waitOperands");
4212
4213 if (getWaitDevnum() && getWaitOperands().empty())
4214 return emitError("wait_devnum cannot appear without waitOperands");
4215
4216 for (mlir::Value operand : getDataClauseOperands())
4217 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
4218 operand.getDefiningOp()))
4219 return emitError("expect data entry operation as defining op");
4220
4221 return success();
4222}
4223
4224unsigned EnterDataOp::getNumDataOperands() {
4225 return getDataClauseOperands().size();
4226}
4227
4228Value EnterDataOp::getDataOperand(unsigned i) {
4229 unsigned numOptional = getIfCond() ? 1 : 0;
4230 numOptional += getAsyncOperand() ? 1 : 0;
4231 numOptional += getWaitDevnum() ? 1 : 0;
4232 return getOperand(getWaitOperands().size() + numOptional + i);
4233}
4234
4235void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
4236 MLIRContext *context) {
4237 results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
4238}
4239
4240void EnterDataOp::addAsyncOnly(
4241 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4242 assert(effectiveDeviceTypes.empty());
4243 assert(!getAsyncAttr());
4244 assert(!getAsyncOperand());
4245
4246 setAsyncAttr(mlir::UnitAttr::get(context));
4247}
4248
4249void EnterDataOp::addAsyncOperand(
4250 MLIRContext *context, mlir::Value newValue,
4251 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4252 assert(effectiveDeviceTypes.empty());
4253 assert(!getAsyncAttr());
4254 assert(!getAsyncOperand());
4255
4256 getAsyncOperandMutable().append(newValue);
4257}
4258
4259void EnterDataOp::addWaitOnly(MLIRContext *context,
4260 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4261 assert(effectiveDeviceTypes.empty());
4262 assert(!getWaitAttr());
4263 assert(getWaitOperands().empty());
4264 assert(!getWaitDevnum());
4265
4266 setWaitAttr(mlir::UnitAttr::get(context));
4267}
4268
4269void EnterDataOp::addWaitOperands(
4270 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4271 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4272 assert(effectiveDeviceTypes.empty());
4273 assert(!getWaitAttr());
4274 assert(getWaitOperands().empty());
4275 assert(!getWaitDevnum());
4276
4277 // if hasDevnum, the first value is the devnum. The 'rest' go into the
4278 // operands list.
4279 if (hasDevnum) {
4280 getWaitDevnumMutable().append(newValues.front());
4281 newValues = newValues.drop_front();
4282 }
4283
4284 getWaitOperandsMutable().append(newValues);
4285}
4286
4287//===----------------------------------------------------------------------===//
4288// AtomicReadOp
4289//===----------------------------------------------------------------------===//
4290
4291LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
4292
4293//===----------------------------------------------------------------------===//
4294// AtomicWriteOp
4295//===----------------------------------------------------------------------===//
4296
4297LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
4298
4299//===----------------------------------------------------------------------===//
4300// AtomicUpdateOp
4301//===----------------------------------------------------------------------===//
4302
4303LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4304 PatternRewriter &rewriter) {
4305 if (op.isNoOp()) {
4306 rewriter.eraseOp(op);
4307 return success();
4308 }
4309
4310 if (Value writeVal = op.getWriteOpVal()) {
4311 rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
4312 op.getIfCond());
4313 return success();
4314 }
4315
4316 return failure();
4317}
4318
4319LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
4320
4321LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4322
4323//===----------------------------------------------------------------------===//
4324// AtomicCaptureOp
4325//===----------------------------------------------------------------------===//
4326
4327AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4328 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4329 return op;
4330 return dyn_cast<AtomicReadOp>(getSecondOp());
4331}
4332
4333AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4334 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4335 return op;
4336 return dyn_cast<AtomicWriteOp>(getSecondOp());
4337}
4338
4339AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4340 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4341 return op;
4342 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4343}
4344
4345LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
4346
4347//===----------------------------------------------------------------------===//
4348// DeclareEnterOp
4349//===----------------------------------------------------------------------===//
4350
4351template <typename Op>
4352static LogicalResult
4354 bool requireAtLeastOneOperand = true) {
4355 if (operands.empty() && requireAtLeastOneOperand)
4356 return emitError(
4357 op->getLoc(),
4358 "at least one operand must appear on the declare operation");
4359
4360 for (mlir::Value operand : operands) {
4361 if (isa<BlockArgument>(operand) ||
4362 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4363 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4364 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4365 operand.getDefiningOp()))
4366 return op.emitError(
4367 "expect valid declare data entry operation or acc.getdeviceptr "
4368 "as defining op");
4369
4370 mlir::Value var{getVar(operand.getDefiningOp())};
4371 assert(var && "declare operands can only be data entry operations which "
4372 "must have var");
4373 (void)var;
4374 std::optional<mlir::acc::DataClause> dataClauseOptional{
4375 getDataClause(operand.getDefiningOp())};
4376 assert(dataClauseOptional.has_value() &&
4377 "declare operands can only be data entry operations which must have "
4378 "dataClause");
4379 (void)dataClauseOptional;
4380 }
4381
4382 return success();
4383}
4384
4385LogicalResult acc::DeclareEnterOp::verify() {
4386 return checkDeclareOperands(*this, this->getDataClauseOperands());
4387}
4388
4389//===----------------------------------------------------------------------===//
4390// DeclareExitOp
4391//===----------------------------------------------------------------------===//
4392
4393LogicalResult acc::DeclareExitOp::verify() {
4394 if (getToken())
4395 return checkDeclareOperands(*this, this->getDataClauseOperands(),
4396 /*requireAtLeastOneOperand=*/false);
4397 return checkDeclareOperands(*this, this->getDataClauseOperands());
4398}
4399
4400//===----------------------------------------------------------------------===//
4401// DeclareOp
4402//===----------------------------------------------------------------------===//
4403
4404LogicalResult acc::DeclareOp::verify() {
4405 return checkDeclareOperands(*this, this->getDataClauseOperands());
4406}
4407
4408//===----------------------------------------------------------------------===//
4409// RoutineOp
4410//===----------------------------------------------------------------------===//
4411
4412static unsigned getParallelismForDeviceType(acc::RoutineOp op,
4413 acc::DeviceType dtype) {
4414 unsigned parallelism = 0;
4415 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4416 parallelism += op.hasWorker(dtype) ? 1 : 0;
4417 parallelism += op.hasVector(dtype) ? 1 : 0;
4418 parallelism += op.hasSeq(dtype) ? 1 : 0;
4419 return parallelism;
4420}
4421
4422LogicalResult acc::RoutineOp::verify() {
4423 unsigned baseParallelism =
4424 getParallelismForDeviceType(*this, acc::DeviceType::None);
4425
4426 if (baseParallelism > 1)
4427 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4428 "be present at the same time";
4429
4430 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4431 ++dtypeInt) {
4432 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
4433 if (dtype == acc::DeviceType::None)
4434 continue;
4435 unsigned parallelism = getParallelismForDeviceType(*this, dtype);
4436
4437 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4438 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4439 "be present at the same time for device_type `"
4440 << acc::stringifyDeviceType(dtype) << "`";
4441 }
4442
4443 return success();
4444}
4445
4446static ParseResult parseBindName(OpAsmParser &parser,
4447 mlir::ArrayAttr &bindIdName,
4448 mlir::ArrayAttr &bindStrName,
4449 mlir::ArrayAttr &deviceIdTypes,
4450 mlir::ArrayAttr &deviceStrTypes) {
4451 llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
4452 llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
4453 llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
4454 llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
4455
4456 if (failed(parser.parseCommaSeparatedList([&]() {
4457 mlir::Attribute newAttr;
4458 bool isSymbolRefAttr;
4459 auto parseResult = parser.parseAttribute(newAttr);
4460 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4461 bindIdNameAttrs.push_back(symbolRefAttr);
4462 isSymbolRefAttr = true;
4463 } else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4464 bindStrNameAttrs.push_back(stringAttr);
4465 isSymbolRefAttr = false;
4466 }
4467 if (parseResult)
4468 return failure();
4469 if (failed(parser.parseOptionalLSquare())) {
4470 if (isSymbolRefAttr) {
4471 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4472 parser.getContext(), mlir::acc::DeviceType::None));
4473 } else {
4474 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4475 parser.getContext(), mlir::acc::DeviceType::None));
4476 }
4477 } else {
4478 if (isSymbolRefAttr) {
4479 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4480 parser.parseRSquare())
4481 return failure();
4482 } else {
4483 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4484 parser.parseRSquare())
4485 return failure();
4486 }
4487 }
4488 return success();
4489 })))
4490 return failure();
4491
4492 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4493 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4494 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4495 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4496
4497 return success();
4498}
4499
4501 std::optional<mlir::ArrayAttr> bindIdName,
4502 std::optional<mlir::ArrayAttr> bindStrName,
4503 std::optional<mlir::ArrayAttr> deviceIdTypes,
4504 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4505 // Create combined vectors for all bind names and device types
4508
4509 // Append bindIdName and deviceIdTypes
4510 if (hasDeviceTypeValues(deviceIdTypes)) {
4511 allBindNames.append(bindIdName->begin(), bindIdName->end());
4512 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4513 }
4514
4515 // Append bindStrName and deviceStrTypes
4516 if (hasDeviceTypeValues(deviceStrTypes)) {
4517 allBindNames.append(bindStrName->begin(), bindStrName->end());
4518 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4519 }
4520
4521 // Print the combined sequence
4522 if (!allBindNames.empty())
4523 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4524 [&](const auto &pair) {
4525 p << std::get<0>(pair);
4526 printSingleDeviceType(p, std::get<1>(pair));
4527 });
4528}
4529
4530static ParseResult parseRoutineGangClause(OpAsmParser &parser,
4531 mlir::ArrayAttr &gang,
4532 mlir::ArrayAttr &gangDim,
4533 mlir::ArrayAttr &gangDimDeviceTypes) {
4534
4535 llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
4536 gangDimDeviceTypeAttrs;
4537 bool needCommaBeforeOperands = false;
4538
4539 // Gang keyword only
4540 if (failed(parser.parseOptionalLParen())) {
4541 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4542 parser.getContext(), mlir::acc::DeviceType::None));
4543 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4544 return success();
4545 }
4546
4547 // Parse keyword only attributes
4548 if (succeeded(parser.parseOptionalLSquare())) {
4549 if (failed(parser.parseCommaSeparatedList([&]() {
4550 if (parser.parseAttribute(gangAttrs.emplace_back()))
4551 return failure();
4552 return success();
4553 })))
4554 return failure();
4555 if (parser.parseRSquare())
4556 return failure();
4557 needCommaBeforeOperands = true;
4558 }
4559
4560 if (needCommaBeforeOperands && failed(parser.parseComma()))
4561 return failure();
4562
4563 if (failed(parser.parseCommaSeparatedList([&]() {
4564 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4565 parser.parseColon() ||
4566 parser.parseAttribute(gangDimAttrs.emplace_back()))
4567 return failure();
4568 if (succeeded(parser.parseOptionalLSquare())) {
4569 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4570 parser.parseRSquare())
4571 return failure();
4572 } else {
4573 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4574 parser.getContext(), mlir::acc::DeviceType::None));
4575 }
4576 return success();
4577 })))
4578 return failure();
4579
4580 if (failed(parser.parseRParen()))
4581 return failure();
4582
4583 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4584 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4585 gangDimDeviceTypes =
4586 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4587
4588 return success();
4589}
4590
4592 std::optional<mlir::ArrayAttr> gang,
4593 std::optional<mlir::ArrayAttr> gangDim,
4594 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4595
4596 if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
4597 gang->size() == 1) {
4598 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4599 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4600 return;
4601 }
4602
4603 p << "(";
4604
4605 printDeviceTypes(p, gang);
4606
4607 if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
4608 p << ", ";
4609
4610 if (hasDeviceTypeValues(gangDimDeviceTypes))
4611 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4612 [&](const auto &pair) {
4613 p << acc::RoutineOp::getGangDimKeyword() << ": ";
4614 p << std::get<0>(pair);
4615 printSingleDeviceType(p, std::get<1>(pair));
4616 });
4617
4618 p << ")";
4619}
4620
4621static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
4622 mlir::ArrayAttr &deviceTypes) {
4624 // Keyword only
4625 if (failed(parser.parseOptionalLParen())) {
4626 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4627 parser.getContext(), mlir::acc::DeviceType::None));
4628 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4629 return success();
4630 }
4631
4632 // Parse device type attributes
4633 if (succeeded(parser.parseOptionalLSquare())) {
4634 if (failed(parser.parseCommaSeparatedList([&]() {
4635 if (parser.parseAttribute(attributes.emplace_back()))
4636 return failure();
4637 return success();
4638 })))
4639 return failure();
4640 if (parser.parseRSquare() || parser.parseRParen())
4641 return failure();
4642 }
4643 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4644 return success();
4645}
4646
4647static void
4649 std::optional<mlir::ArrayAttr> deviceTypes) {
4650
4651 if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
4652 auto deviceTypeAttr =
4653 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4654 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4655 return;
4656 }
4657
4658 if (!hasDeviceTypeValues(deviceTypes))
4659 return;
4660
4661 p << "([";
4662 llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
4663 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4664 p << dTypeAttr;
4665 });
4666 p << "])";
4667}
4668
4669bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
4670
4671bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4672 return hasDeviceType(getWorker(), deviceType);
4673}
4674
4675bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
4676
4677bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4678 return hasDeviceType(getVector(), deviceType);
4679}
4680
4681bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
4682
4683bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4684 return hasDeviceType(getSeq(), deviceType);
4685}
4686
4687std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4688RoutineOp::getBindNameValue() {
4689 return getBindNameValue(mlir::acc::DeviceType::None);
4690}
4691
4692std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4693RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4694 if (hasDeviceTypeValues(getBindIdNameDeviceType())) {
4695 if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
4696 auto attr = (*getBindIdName())[*pos];
4697 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4698 assert(symbolRefAttr && "expected SymbolRef");
4699 return symbolRefAttr;
4700 }
4701 }
4702
4703 if (hasDeviceTypeValues(getBindStrNameDeviceType())) {
4704 if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
4705 auto attr = (*getBindStrName())[*pos];
4706 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4707 assert(stringAttr && "expected String");
4708 return stringAttr;
4709 }
4710 }
4711
4712 return std::nullopt;
4713}
4714
4715bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
4716
4717bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4718 return hasDeviceType(getGang(), deviceType);
4719}
4720
4721std::optional<int64_t> RoutineOp::getGangDimValue() {
4722 return getGangDimValue(mlir::acc::DeviceType::None);
4723}
4724
4725std::optional<int64_t>
4726RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4727 if (!hasDeviceTypeValues(getGangDimDeviceType()))
4728 return std::nullopt;
4729 if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
4730 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4731 return intAttr.getInt();
4732 }
4733 return std::nullopt;
4734}
4735
4736void RoutineOp::addSeq(MLIRContext *context,
4737 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4738 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4739 effectiveDeviceTypes));
4740}
4741
4742void RoutineOp::addVector(MLIRContext *context,
4743 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4744 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4745 effectiveDeviceTypes));
4746}
4747
4748void RoutineOp::addWorker(MLIRContext *context,
4749 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4750 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4751 effectiveDeviceTypes));
4752}
4753
4754void RoutineOp::addGang(MLIRContext *context,
4755 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4756 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4757 effectiveDeviceTypes));
4758}
4759
4760void RoutineOp::addGang(MLIRContext *context,
4761 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4762 uint64_t val) {
4765
4766 if (getGangDimAttr())
4767 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4768 if (getGangDimDeviceTypeAttr())
4769 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4770
4771 assert(dimValues.size() == deviceTypes.size());
4772
4773 if (effectiveDeviceTypes.empty()) {
4774 dimValues.push_back(
4775 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4776 deviceTypes.push_back(
4777 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4778 } else {
4779 for (DeviceType dt : effectiveDeviceTypes) {
4780 dimValues.push_back(
4781 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4782 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4783 }
4784 }
4785 assert(dimValues.size() == deviceTypes.size());
4786
4787 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4788 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4789}
4790
4791void RoutineOp::addBindStrName(MLIRContext *context,
4792 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4793 mlir::StringAttr val) {
4794 unsigned before = getBindStrNameDeviceTypeAttr()
4795 ? getBindStrNameDeviceTypeAttr().size()
4796 : 0;
4797
4798 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4799 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4800 unsigned after = getBindStrNameDeviceTypeAttr().size();
4801
4803 if (getBindStrNameAttr())
4804 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4805 for (unsigned i = 0; i < after - before; ++i)
4806 vals.push_back(val);
4807
4808 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4809}
4810
4811void RoutineOp::addBindIDName(MLIRContext *context,
4812 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4813 mlir::SymbolRefAttr val) {
4814 unsigned before =
4815 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4816
4817 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4818 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4819 unsigned after = getBindIdNameDeviceTypeAttr().size();
4820
4822 if (getBindIdNameAttr())
4823 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4824 for (unsigned i = 0; i < after - before; ++i)
4825 vals.push_back(val);
4826
4827 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4828}
4829
4830//===----------------------------------------------------------------------===//
4831// InitOp
4832//===----------------------------------------------------------------------===//
4833
4834LogicalResult acc::InitOp::verify() {
4835 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4836 return emitOpError("cannot be nested in a compute operation");
4837 return success();
4838}
4839
4840void acc::InitOp::addDeviceType(MLIRContext *context,
4841 mlir::acc::DeviceType deviceType) {
4843 if (getDeviceTypesAttr())
4844 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4845
4846 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4847 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4848}
4849
4850//===----------------------------------------------------------------------===//
4851// ShutdownOp
4852//===----------------------------------------------------------------------===//
4853
4854LogicalResult acc::ShutdownOp::verify() {
4855 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4856 return emitOpError("cannot be nested in a compute operation");
4857 return success();
4858}
4859
4860void acc::ShutdownOp::addDeviceType(MLIRContext *context,
4861 mlir::acc::DeviceType deviceType) {
4863 if (getDeviceTypesAttr())
4864 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4865
4866 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4867 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4868}
4869
4870//===----------------------------------------------------------------------===//
4871// SetOp
4872//===----------------------------------------------------------------------===//
4873
4874LogicalResult acc::SetOp::verify() {
4875 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4876 return emitOpError("cannot be nested in a compute operation");
4877 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4878 return emitOpError("at least one default_async, device_num, or device_type "
4879 "operand must appear");
4880 return success();
4881}
4882
4883//===----------------------------------------------------------------------===//
4884// UpdateOp
4885//===----------------------------------------------------------------------===//
4886
4887LogicalResult acc::UpdateOp::verify() {
4888 // At least one of host or device should have a value.
4889 if (getDataClauseOperands().empty())
4890 return emitError("at least one value must be present in dataOperands");
4891
4893 getAsyncOperandsDeviceTypeAttr(),
4894 "async")))
4895 return failure();
4896
4898 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4899 getWaitOperandsDeviceTypeAttr(), "wait")))
4900 return failure();
4901
4903 return failure();
4904
4905 for (mlir::Value operand : getDataClauseOperands())
4906 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4907 operand.getDefiningOp()))
4908 return emitError("expect data entry/exit operation or acc.getdeviceptr "
4909 "as defining op");
4910
4911 return success();
4912}
4913
4914unsigned UpdateOp::getNumDataOperands() {
4915 return getDataClauseOperands().size();
4916}
4917
4918Value UpdateOp::getDataOperand(unsigned i) {
4919 unsigned numOptional = getAsyncOperands().size();
4920 numOptional += getIfCond() ? 1 : 0;
4921 return getOperand(getWaitOperands().size() + numOptional + i);
4922}
4923
4924void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
4925 MLIRContext *context) {
4926 results.add<RemoveConstantIfCondition<UpdateOp>>(context);
4927}
4928
4929bool UpdateOp::hasAsyncOnly() {
4930 return hasAsyncOnly(mlir::acc::DeviceType::None);
4931}
4932
4933bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4934 return hasDeviceType(getAsyncOnly(), deviceType);
4935}
4936
4937mlir::Value UpdateOp::getAsyncValue() {
4938 return getAsyncValue(mlir::acc::DeviceType::None);
4939}
4940
4941mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4943 return {};
4944
4945 if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
4946 return getAsyncOperands()[*pos];
4947
4948 return {};
4949}
4950
4951bool UpdateOp::hasWaitOnly() {
4952 return hasWaitOnly(mlir::acc::DeviceType::None);
4953}
4954
4955bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4956 return hasDeviceType(getWaitOnly(), deviceType);
4957}
4958
4959mlir::Operation::operand_range UpdateOp::getWaitValues() {
4960 return getWaitValues(mlir::acc::DeviceType::None);
4961}
4962
4964UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4966 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4967 getHasWaitDevnum(), deviceType);
4968}
4969
4970mlir::Value UpdateOp::getWaitDevnum() {
4971 return getWaitDevnum(mlir::acc::DeviceType::None);
4972}
4973
4974mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4975 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
4976 getWaitOperandsSegments(), getHasWaitDevnum(),
4977 deviceType);
4978}
4979
4980void UpdateOp::addAsyncOnly(MLIRContext *context,
4981 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4982 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4983 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4984}
4985
4986void UpdateOp::addAsyncOperand(
4987 MLIRContext *context, mlir::Value newValue,
4988 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4989 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4990 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4991 getAsyncOperandsMutable()));
4992}
4993
4994void UpdateOp::addWaitOnly(MLIRContext *context,
4995 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4996 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4997 effectiveDeviceTypes));
4998}
4999
5000void UpdateOp::addWaitOperands(
5001 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
5002 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
5003
5005 if (getWaitOperandsSegments())
5006 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
5007
5008 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
5009 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
5010 getWaitOperandsMutable(), segments));
5011 setWaitOperandsSegments(segments);
5012
5014 if (getHasWaitDevnumAttr())
5015 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
5016 hasDevnums.insert(
5017 hasDevnums.end(),
5018 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
5019 mlir::BoolAttr::get(context, hasDevnum));
5020 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
5021}
5022
5023//===----------------------------------------------------------------------===//
5024// WaitOp
5025//===----------------------------------------------------------------------===//
5026
5027LogicalResult acc::WaitOp::verify() {
5028 // The async attribute represent the async clause without value. Therefore the
5029 // attribute and operand cannot appear at the same time.
5030 if (getAsyncOperand() && getAsync())
5031 return emitError("async attribute cannot appear with asyncOperand");
5032
5033 if (getWaitDevnum() && getWaitOperands().empty())
5034 return emitError("wait_devnum cannot appear without waitOperands");
5035
5036 return success();
5037}
5038
5039#define GET_OP_CLASSES
5040#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
5041
5042#define GET_ATTRDEF_CLASSES
5043#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
5044
5045#define GET_TYPEDEF_CLASSES
5046#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
5047
5048//===----------------------------------------------------------------------===//
5049// acc dialect utilities
5050//===----------------------------------------------------------------------===//
5051
5054 auto varPtr{llvm::TypeSwitch<mlir::Operation *,
5056 accDataClauseOp)
5057 .Case<ACC_DATA_ENTRY_OPS>(
5058 [&](auto entry) { return entry.getVarPtr(); })
5059 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5060 [&](auto exit) { return exit.getVarPtr(); })
5061 .Default([&](mlir::Operation *) {
5063 })};
5064 return varPtr;
5065}
5066
5068 auto varPtr{
5070 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
5071 .Default([&](mlir::Operation *) { return mlir::Value(); })};
5072 return varPtr;
5073}
5074
5076 auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
5077 .Case<ACC_DATA_ENTRY_OPS>(
5078 [&](auto entry) { return entry.getVarType(); })
5079 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5080 [&](auto exit) { return exit.getVarType(); })
5081 .Default([&](mlir::Operation *) { return mlir::Type(); })};
5082 return varType;
5083}
5084
5087 auto accPtr{llvm::TypeSwitch<mlir::Operation *,
5089 accDataClauseOp)
5090 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
5091 [&](auto dataClause) { return dataClause.getAccPtr(); })
5092 .Default([&](mlir::Operation *) {
5094 })};
5095 return accPtr;
5096}
5097
5099 auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
5101 [&](auto dataClause) { return dataClause.getAccVar(); })
5102 .Default([&](mlir::Operation *) { return mlir::Value(); })};
5103 return accPtr;
5104}
5105
5107 auto varPtrPtr{
5109 .Case<ACC_DATA_ENTRY_OPS>(
5110 [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
5111 .Default([&](mlir::Operation *) { return mlir::Value(); })};
5112 return varPtrPtr;
5113}
5114
5119 accDataClauseOp)
5120 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
5122 dataClause.getBounds().begin(), dataClause.getBounds().end());
5123 })
5124 .Default([&](mlir::Operation *) {
5126 })};
5127 return bounds;
5128}
5129
5133 accDataClauseOp)
5134 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
5136 dataClause.getAsyncOperands().begin(),
5137 dataClause.getAsyncOperands().end());
5138 })
5139 .Default([&](mlir::Operation *) {
5141 });
5142}
5143
5144mlir::ArrayAttr
5147 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
5148 return dataClause.getAsyncOperandsDeviceTypeAttr();
5149 })
5150 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
5151}
5152
5153mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
5156 [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
5157 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
5158}
5159
5160std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
5161 auto name{
5163 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
5164 .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
5165 return {};
5166 })};
5167 return name;
5168}
5169
5170std::optional<mlir::acc::DataClause>
5172 auto dataClause{
5174 accDataEntryOp)
5175 .Case<ACC_DATA_ENTRY_OPS>(
5176 [&](auto entry) { return entry.getDataClause(); })
5177 .Default([&](mlir::Operation *) { return std::nullopt; })};
5178 return dataClause;
5179}
5180
5182 auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
5183 .Case<ACC_DATA_ENTRY_OPS>(
5184 [&](auto entry) { return entry.getImplicit(); })
5185 .Default([&](mlir::Operation *) { return false; })};
5186 return implicit;
5187}
5188
5190 auto dataOperands{
5193 [&](auto entry) { return entry.getDataClauseOperands(); })
5194 .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
5195 return dataOperands;
5196}
5197
5200 auto dataOperands{
5203 [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
5204 .Default([&](mlir::Operation *) { return nullptr; })};
5205 return dataOperands;
5206}
5207
5208mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) {
5209 auto recipe{
5211 .Case<ACC_DATA_ENTRY_OPS>(
5212 [&](auto entry) { return entry.getRecipeAttr(); })
5213 .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })};
5214 return recipe;
5215}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition OpenACC.cpp:4591
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition OpenACC.cpp:1459
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition OpenACC.cpp:3342
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition OpenACC.cpp:1960
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
Definition OpenACC.cpp:4446
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
Definition OpenACC.cpp:826
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:599
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition OpenACC.cpp:2457
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
Definition OpenACC.cpp:819
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition OpenACC.cpp:752
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:583
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition OpenACC.cpp:721
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition OpenACC.cpp:2468
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition OpenACC.cpp:2373
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition OpenACC.cpp:525
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:4648
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition OpenACC.cpp:3151
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition OpenACC.cpp:2708
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition OpenACC.cpp:3358
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition OpenACC.cpp:4353
static LogicalResult checkVarAndAccVar(Op op)
Definition OpenACC.cpp:659
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2662
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:543
static LogicalResult checkVarAndVarType(Op op)
Definition OpenACC.cpp:641
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
Definition OpenACC.cpp:675
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
Definition OpenACC.cpp:1224
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition OpenACC.cpp:3726
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition OpenACC.cpp:1915
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:2504
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:2043
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
Definition OpenACC.cpp:1234
static LogicalResult checkNoModifier(Op op)
Definition OpenACC.cpp:667
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition OpenACC.cpp:730
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:554
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:567
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2243
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
Definition OpenACC.cpp:422
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition OpenACC.cpp:706
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition OpenACC.cpp:3757
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
Definition OpenACC.cpp:433
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:4621
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition OpenACC.cpp:4530
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2356
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:2531
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
Definition OpenACC.cpp:2647
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2310
static bool isEnclosedIntoComputeOp(mlir::Operation *op)
Definition OpenACC.cpp:1218
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2623
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition OpenACC.cpp:796
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition OpenACC.cpp:3170
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition OpenACC.cpp:1689
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
Definition OpenACC.cpp:2692
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition OpenACC.cpp:2287
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
Definition OpenACC.cpp:685
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
Definition OpenACC.cpp:1929
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition OpenACC.cpp:2604
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:529
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition OpenACC.cpp:3297
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition OpenACC.cpp:2542
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition OpenACC.cpp:764
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition OpenACC.cpp:619
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition OpenACC.cpp:1971
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition OpenACC.cpp:4412
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2293
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition OpenACC.cpp:2728
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
Definition OpenACC.cpp:4500
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
ArrayAttr()
if(!isCopyOut)
b getContext())
false
Parses a map_entries map type from a string format back into its numeric value.
static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and return it.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:165
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
Definition Builders.h:56
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:119
unsigned size() const
Returns the current size of the range.
Definition ValueRange.h:157
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OperandRange operand_range
Definition Operation.h:397
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:608
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
result_range getResults()
Definition Operation.h:441
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
Base attribute class for language-specific variable information carried through the OpenACC type inte...
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
#define ACC_COMPUTE_CONSTRUCT_OPS
Definition OpenACC.h:60
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition OpenACC.h:71
#define ACC_DATA_ENTRY_OPS
Definition OpenACC.h:48
#define ACC_DATA_EXIT_OPS
Definition OpenACC.h:56
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:5098
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:5067
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:5086
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition OpenACC.cpp:5171
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition OpenACC.cpp:5199
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition OpenACC.cpp:5116
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition OpenACC.cpp:5189
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition OpenACC.cpp:5160
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition OpenACC.cpp:5181
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
Definition OpenACC.cpp:5208
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition OpenACC.cpp:5131
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
Definition OpenACC.h:169
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition OpenACC.cpp:5106
static constexpr StringLiteral getVarNameAttrName()
Definition OpenACC.h:206
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:5153
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
Definition OpenACC.cpp:5075
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:5053
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:5145
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.