MLIR 22.0.0git
OpenACC.cpp
Go to the documentation of this file.
1//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// =============================================================================
8
14#include "mlir/IR/Builders.h"
18#include "mlir/IR/IRMapping.h"
19#include "mlir/IR/Matchers.h"
21#include "mlir/IR/SymbolTable.h"
22#include "mlir/Support/LLVM.h"
24#include "llvm/ADT/SmallSet.h"
25#include "llvm/ADT/TypeSwitch.h"
26#include "llvm/Support/LogicalResult.h"
27#include <variant>
28
29using namespace mlir;
30using namespace acc;
31
32#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
33#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
36#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
37
38namespace {
39
40static bool isScalarLikeType(Type type) {
41 return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
42}
43
44/// Helper function to attach the `VarName` attribute to an operation
45/// if a variable name is provided.
46static void attachVarNameAttr(Operation *op, OpBuilder &builder,
47 StringRef varName) {
48 if (!varName.empty()) {
49 auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName);
50 op->setAttr(acc::getVarNameAttrName(), varNameAttr);
51 }
52}
53
54template <typename T>
55struct MemRefPointerLikeModel
56 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
57 Type getElementType(Type pointer) const {
58 return cast<T>(pointer).getElementType();
59 }
60
61 mlir::acc::VariableTypeCategory
62 getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
63 Type varType) const {
64 if (auto mappableTy = dyn_cast<MappableType>(varType)) {
65 return mappableTy.getTypeCategory(varPtr);
66 }
67 auto memrefTy = cast<T>(pointer);
68 if (!memrefTy.hasRank()) {
69 // This memref is unranked - aka it could have any rank, including a
70 // rank of 0 which could mean scalar. For now, return uncategorized.
71 return mlir::acc::VariableTypeCategory::uncategorized;
72 }
73
74 if (memrefTy.getRank() == 0) {
75 if (isScalarLikeType(memrefTy.getElementType())) {
76 return mlir::acc::VariableTypeCategory::scalar;
77 }
78 // Zero-rank non-scalar - need further analysis to determine the type
79 // category. For now, return uncategorized.
80 return mlir::acc::VariableTypeCategory::uncategorized;
81 }
82
83 // It has a rank - must be an array.
84 assert(memrefTy.getRank() > 0 && "rank expected to be positive");
85 return mlir::acc::VariableTypeCategory::array;
86 }
87
88 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
89 StringRef varName, Type varType, Value originalVar,
90 bool &needsFree) const {
91 auto memrefTy = cast<MemRefType>(pointer);
92
93 // Check if this is a static memref (all dimensions are known) - if yes
94 // then we can generate an alloca operation.
95 if (memrefTy.hasStaticShape()) {
96 needsFree = false; // alloca doesn't need deallocation
97 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
98 attachVarNameAttr(allocaOp, builder, varName);
99 return allocaOp.getResult();
100 }
101
102 // For dynamic memrefs, extract sizes from the original variable if
103 // provided. Otherwise they cannot be handled.
104 if (originalVar && originalVar.getType() == memrefTy &&
105 memrefTy.hasRank()) {
106 SmallVector<Value> dynamicSizes;
107 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
108 if (memrefTy.isDynamicDim(i)) {
109 // Extract the size of dimension i from the original variable
110 auto indexValue = arith::ConstantIndexOp::create(builder, loc, i);
111 auto dimSize =
112 memref::DimOp::create(builder, loc, originalVar, indexValue);
113 dynamicSizes.push_back(dimSize);
114 }
115 // Note: We only add dynamic sizes to the dynamicSizes array
116 // Static dimensions are handled automatically by AllocOp
117 }
118 needsFree = true; // alloc needs deallocation
119 auto allocOp =
120 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
121 attachVarNameAttr(allocOp, builder, varName);
122 return allocOp.getResult();
123 }
124
125 // TODO: Unranked not yet supported.
126 return {};
127 }
128
129 bool genFree(Type pointer, OpBuilder &builder, Location loc,
130 TypedValue<PointerLikeType> varToFree, Value allocRes,
131 Type varType) const {
132 if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
133 // Use allocRes if provided to determine the allocation type
134 Value valueToInspect = allocRes ? allocRes : memrefValue;
135
136 // Walk through casts to find the original allocation
137 Value currentValue = valueToInspect;
138 Operation *originalAlloc = nullptr;
139
140 // Follow the chain of operations to find the original allocation
141 // even if a casted result is provided.
142 while (currentValue) {
143 if (auto *definingOp = currentValue.getDefiningOp()) {
144 // Check if this is an allocation operation
145 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
146 originalAlloc = definingOp;
147 break;
148 }
149
150 // Check if this is a cast operation we can look through
151 if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
152 currentValue = castOp.getSource();
153 continue;
154 }
155
156 // Check for other cast-like operations
157 if (auto reinterpretCastOp =
158 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
159 currentValue = reinterpretCastOp.getSource();
160 continue;
161 }
162
163 // If we can't look through this operation, stop
164 break;
165 }
166 // This is a block argument or similar - can't trace further.
167 break;
168 }
169
170 if (originalAlloc) {
171 if (isa<memref::AllocaOp>(originalAlloc)) {
172 // This is an alloca - no dealloc needed, but return true (success)
173 return true;
174 }
175 if (isa<memref::AllocOp>(originalAlloc)) {
176 // This is an alloc - generate dealloc on varToFree
177 memref::DeallocOp::create(builder, loc, memrefValue);
178 return true;
179 }
180 }
181 }
182
183 return false;
184 }
185
186 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
187 TypedValue<PointerLikeType> destination,
188 TypedValue<PointerLikeType> source, Type varType) const {
189 // Generate a copy operation between two memrefs
190 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
191 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
192
193 // As per memref documentation, source and destination must have same
194 // element type and shape in order to be compatible. We do not want to fail
195 // with an IR verification error - thus check that before generating the
196 // copy operation.
197 if (destMemref && srcMemref &&
198 destMemref.getType().getElementType() ==
199 srcMemref.getType().getElementType() &&
200 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
201 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
202 return true;
203 }
204
205 return false;
206 }
207
208 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
210 Type valueType) const {
211 // Load from a memref - only valid for scalar memrefs (rank 0).
212 // This is because the address computation for memrefs is part of the load
213 // (and not computed separately), but the API does not have arguments for
214 // indexing.
215 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
216 if (!memrefValue)
217 return {};
218
219 auto memrefTy = memrefValue.getType();
220
221 // Only load from scalar memrefs (rank 0)
222 if (memrefTy.getRank() != 0)
223 return {};
224
225 return memref::LoadOp::create(builder, loc, memrefValue);
226 }
227
228 bool genStore(Type pointer, OpBuilder &builder, Location loc,
229 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
230 // Store to a memref - only valid for scalar memrefs (rank 0)
231 // This is because the address computation for memrefs is part of the store
232 // (and not computed separately), but the API does not have arguments for
233 // indexing.
234 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
235 if (!memrefValue)
236 return false;
237
238 auto memrefTy = memrefValue.getType();
239
240 // Only store to scalar memrefs (rank 0)
241 if (memrefTy.getRank() != 0)
242 return false;
243
244 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
245 return true;
246 }
247};
248
249struct LLVMPointerPointerLikeModel
250 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
251 LLVM::LLVMPointerType> {
252 Type getElementType(Type pointer) const { return Type(); }
253
254 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
256 Type valueType) const {
257 // For LLVM pointers, we need the valueType to determine what to load
258 if (!valueType)
259 return {};
260
261 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
262 }
263
264 bool genStore(Type pointer, OpBuilder &builder, Location loc,
265 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
266 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
267 return true;
268 }
269};
270
271struct MemrefAddressOfGlobalModel
272 : public AddressOfGlobalOpInterface::ExternalModel<
273 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
274 SymbolRefAttr getSymbol(Operation *op) const {
275 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
276 return getGlobalOp.getNameAttr();
277 }
278};
279
280struct MemrefGlobalVariableModel
281 : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
282 memref::GlobalOp> {
283 bool isConstant(Operation *op) const {
284 auto globalOp = cast<memref::GlobalOp>(op);
285 return globalOp.getConstant();
286 }
287
288 Region *getInitRegion(Operation *op) const {
289 // GlobalOp uses attributes for initialization, not regions
290 return nullptr;
291 }
292};
293
294/// Helper function for any of the times we need to modify an ArrayAttr based on
295/// a device type list. Returns a new ArrayAttr with all of the
296/// existingDeviceTypes, plus the effective new ones(or an added none if hte new
297/// list is empty).
298mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
299 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
300 llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
302 if (existingDeviceTypes)
303 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
304
305 if (newDeviceTypes.empty())
306 deviceTypes.push_back(
307 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
308
309 for (DeviceType dt : newDeviceTypes)
310 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
311
312 return mlir::ArrayAttr::get(context, deviceTypes);
313}
314
315/// Helper function for any of the times we need to add operands that are
316/// affected by a device type list. Returns a new ArrayAttr with all of the
317/// existingDeviceTypes, plus the effective new ones (or an added none, if the
318/// new list is empty). Additionally, adds the arguments to the argCollection
319/// the correct number of times. This will also update a 'segments' array, even
320/// if it won't be used.
321mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
322 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
323 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
324 mlir::MutableOperandRange argCollection,
325 llvm::SmallVector<int32_t> &segments) {
327 if (existingDeviceTypes)
328 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
329
330 if (newDeviceTypes.empty()) {
331 argCollection.append(arguments);
332 segments.push_back(arguments.size());
333 deviceTypes.push_back(
334 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
335 }
336
337 for (DeviceType dt : newDeviceTypes) {
338 argCollection.append(arguments);
339 segments.push_back(arguments.size());
340 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
341 }
342
343 return mlir::ArrayAttr::get(context, deviceTypes);
344}
345
346/// Overload for when the 'segments' aren't needed.
347mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
348 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
349 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
350 mlir::MutableOperandRange argCollection) {
352 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
353 newDeviceTypes, arguments,
354 argCollection, segments);
355}
356} // namespace
357
358//===----------------------------------------------------------------------===//
359// OpenACC operations
360//===----------------------------------------------------------------------===//
361
362void OpenACCDialect::initialize() {
363 addOperations<
364#define GET_OP_LIST
365#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
366 >();
367 addAttributes<
368#define GET_ATTRDEF_LIST
369#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
370 >();
371 addTypes<
372#define GET_TYPEDEF_LIST
373#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
374 >();
375
376 // By attaching interfaces here, we make the OpenACC dialect dependent on
377 // the other dialects. This is probably better than having dialects like LLVM
378 // and memref be dependent on OpenACC.
379 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
380 *getContext());
381 UnrankedMemRefType::attachInterface<
382 MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
383 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
384 *getContext());
385
386 // Attach operation interfaces
387 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
388 *getContext());
389 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
390}
391
392//===----------------------------------------------------------------------===//
393// RegionBranchOpInterface for acc.kernels / acc.parallel / acc.serial /
394// acc.kernel_environment / acc.data / acc.host_data
395//===----------------------------------------------------------------------===//
396
397/// Generic helper for single-region OpenACC ops that execute their body once
398/// and then return to the parent operation with their results (if any).
399static void
401 RegionBranchPoint point,
403 if (point.isParent()) {
404 regions.push_back(RegionSuccessor(&region));
405 return;
406 }
407
408 regions.push_back(RegionSuccessor(op, op->getResults()));
409}
410
411void KernelsOp::getSuccessorRegions(RegionBranchPoint point,
413 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
414 regions);
415}
416
417void ParallelOp::getSuccessorRegions(
419 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
420 regions);
421}
422
423void SerialOp::getSuccessorRegions(RegionBranchPoint point,
425 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
426 regions);
427}
428
429void KernelEnvironmentOp::getSuccessorRegions(
431 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
432 regions);
433}
434
435void DataOp::getSuccessorRegions(RegionBranchPoint point,
437 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
438 regions);
439}
440
441void HostDataOp::getSuccessorRegions(
443 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
444 regions);
445}
446
447//===----------------------------------------------------------------------===//
448// device_type support helpers
449//===----------------------------------------------------------------------===//
450
451static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
452 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
453}
454
455static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
456 mlir::acc::DeviceType deviceType) {
457 if (!hasDeviceTypeValues(arrayAttr))
458 return false;
459
460 for (auto attr : *arrayAttr) {
461 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
462 if (deviceTypeAttr.getValue() == deviceType)
463 return true;
464 }
465
466 return false;
467}
468
470 std::optional<mlir::ArrayAttr> deviceTypes) {
471 if (!hasDeviceTypeValues(deviceTypes))
472 return;
473
474 p << "[";
475 llvm::interleaveComma(*deviceTypes, p,
476 [&](mlir::Attribute attr) { p << attr; });
477 p << "]";
478}
479
480static std::optional<unsigned> findSegment(ArrayAttr segments,
481 mlir::acc::DeviceType deviceType) {
482 unsigned segmentIdx = 0;
483 for (auto attr : segments) {
484 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
485 if (deviceTypeAttr.getValue() == deviceType)
486 return std::make_optional(segmentIdx);
487 ++segmentIdx;
488 }
489 return std::nullopt;
490}
491
493getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
495 std::optional<llvm::ArrayRef<int32_t>> segments,
496 mlir::acc::DeviceType deviceType) {
497 if (!arrayAttr)
498 return range.take_front(0);
499 if (auto pos = findSegment(*arrayAttr, deviceType)) {
500 int32_t nbOperandsBefore = 0;
501 for (unsigned i = 0; i < *pos; ++i)
502 nbOperandsBefore += (*segments)[i];
503 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
504 }
505 return range.take_front(0);
506}
507
508static mlir::Value
509getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
511 std::optional<llvm::ArrayRef<int32_t>> segments,
512 std::optional<mlir::ArrayAttr> hasWaitDevnum,
513 mlir::acc::DeviceType deviceType) {
514 if (!hasDeviceTypeValues(deviceTypeAttr))
515 return {};
516 if (auto pos = findSegment(*deviceTypeAttr, deviceType))
517 if (hasWaitDevnum->getValue()[*pos])
518 return getValuesFromSegments(deviceTypeAttr, operands, segments,
519 deviceType)
520 .front();
521 return {};
522}
523
525getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
527 std::optional<llvm::ArrayRef<int32_t>> segments,
528 std::optional<mlir::ArrayAttr> hasWaitDevnum,
529 mlir::acc::DeviceType deviceType) {
530 auto range =
531 getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
532 if (range.empty())
533 return range;
534 if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
535 if (hasWaitDevnum && *hasWaitDevnum) {
536 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
537 if (boolAttr.getValue())
538 return range.drop_front(1); // first value is devnum
539 }
540 }
541 return range;
542}
543
544template <typename Op>
545static LogicalResult checkWaitAndAsyncConflict(Op op) {
546 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
547 ++dtypeInt) {
548 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
549
550 // The asyncOnly attribute represent the async clause without value.
551 // Therefore the attribute and operand cannot appear at the same time.
552 if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
553 op.hasAsyncOnly(dtype))
554 return op.emitError(
555 "asyncOnly attribute cannot appear with asyncOperand");
556
557 // The wait attribute represent the wait clause without values. Therefore
558 // the attribute and operands cannot appear at the same time.
559 if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
560 op.hasWaitOnly(dtype))
561 return op.emitError("wait attribute cannot appear with waitOperands");
562 }
563 return success();
564}
565
566template <typename Op>
567static LogicalResult checkVarAndVarType(Op op) {
568 if (!op.getVar())
569 return op.emitError("must have var operand");
570
571 // A variable must have a type that is either pointer-like or mappable.
572 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
573 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
574 return op.emitError("var must be mappable or pointer-like");
575
576 // When it is a pointer-like type, the varType must capture the target type.
577 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
578 op.getVarType() == op.getVar().getType())
579 return op.emitError("varType must capture the element type of var");
580
581 return success();
582}
583
584template <typename Op>
585static LogicalResult checkVarAndAccVar(Op op) {
586 if (op.getVar().getType() != op.getAccVar().getType())
587 return op.emitError("input and output types must match");
588
589 return success();
590}
591
592template <typename Op>
593static LogicalResult checkNoModifier(Op op) {
594 if (op.getModifiers() != acc::DataClauseModifier::none)
595 return op.emitError("no data clause modifiers are allowed");
596 return success();
597}
598
599template <typename Op>
600static LogicalResult
601checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
602 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
603 return op.emitError(
604 "invalid data clause modifiers: " +
605 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
606
607 return success();
608}
609
610template <typename OpT, typename RecipeOpT>
611static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) {
612 // Mappable types do not need a recipe because it is possible to generate one
613 // from its API. Reject reductions though because no API is available for them
614 // at this time.
615 if (mlir::acc::isMappableType(op.getVar().getType()) &&
616 !std::is_same_v<OpT, acc::ReductionOp>)
617 return success();
618
619 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
620 if (!operandRecipe)
621 return op->emitOpError() << "recipe expected for " << operandName;
622
623 auto decl =
625 if (!decl)
626 return op->emitOpError()
627 << "expected symbol reference " << operandRecipe << " to point to a "
628 << operandName << " declaration";
629 return success();
630}
631
632static ParseResult parseVar(mlir::OpAsmParser &parser,
634 // Either `var` or `varPtr` keyword is required.
635 if (failed(parser.parseOptionalKeyword("varPtr"))) {
636 if (failed(parser.parseKeyword("var")))
637 return failure();
638 }
639 if (failed(parser.parseLParen()))
640 return failure();
641 if (failed(parser.parseOperand(var)))
642 return failure();
643
644 return success();
645}
646
648 mlir::Value var) {
649 if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
650 p << "varPtr(";
651 else
652 p << "var(";
653 p.printOperand(var);
654}
655
656static ParseResult parseAccVar(mlir::OpAsmParser &parser,
658 mlir::Type &accVarType) {
659 // Either `accVar` or `accPtr` keyword is required.
660 if (failed(parser.parseOptionalKeyword("accPtr"))) {
661 if (failed(parser.parseKeyword("accVar")))
662 return failure();
663 }
664 if (failed(parser.parseLParen()))
665 return failure();
666 if (failed(parser.parseOperand(var)))
667 return failure();
668 if (failed(parser.parseColon()))
669 return failure();
670 if (failed(parser.parseType(accVarType)))
671 return failure();
672 if (failed(parser.parseRParen()))
673 return failure();
674
675 return success();
676}
677
679 mlir::Value accVar, mlir::Type accVarType) {
680 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
681 p << "accPtr(";
682 else
683 p << "accVar(";
684 p.printOperand(accVar);
685 p << " : ";
686 p.printType(accVarType);
687 p << ")";
688}
689
690static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
691 mlir::Type &varPtrType,
692 mlir::TypeAttr &varTypeAttr) {
693 if (failed(parser.parseType(varPtrType)))
694 return failure();
695 if (failed(parser.parseRParen()))
696 return failure();
697
698 if (succeeded(parser.parseOptionalKeyword("varType"))) {
699 if (failed(parser.parseLParen()))
700 return failure();
701 mlir::Type varType;
702 if (failed(parser.parseType(varType)))
703 return failure();
704 varTypeAttr = mlir::TypeAttr::get(varType);
705 if (failed(parser.parseRParen()))
706 return failure();
707 } else {
708 // Set `varType` from the element type of the type of `varPtr`.
709 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
710 varTypeAttr = mlir::TypeAttr::get(
711 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
712 else
713 varTypeAttr = mlir::TypeAttr::get(varPtrType);
714 }
715
716 return success();
717}
718
720 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
721 p.printType(varPtrType);
722 p << ")";
723
724 // Print the `varType` only if it differs from the element type of
725 // `varPtr`'s type.
726 mlir::Type varType = varTypeAttr.getValue();
727 mlir::Type typeToCheckAgainst =
728 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
729 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
730 : varPtrType;
731 if (typeToCheckAgainst != varType) {
732 p << " varType(";
733 p.printType(varType);
734 p << ")";
735 }
736}
737
738static ParseResult parseRecipeSym(mlir::OpAsmParser &parser,
739 mlir::SymbolRefAttr &recipeAttr) {
740 if (failed(parser.parseAttribute(recipeAttr)))
741 return failure();
742 return success();
743}
744
746 mlir::SymbolRefAttr recipeAttr) {
747 p << recipeAttr;
748}
749
750//===----------------------------------------------------------------------===//
751// DataBoundsOp
752//===----------------------------------------------------------------------===//
753LogicalResult acc::DataBoundsOp::verify() {
754 auto extent = getExtent();
755 auto upperbound = getUpperbound();
756 if (!extent && !upperbound)
757 return emitError("expected extent or upperbound.");
758 return success();
759}
760
761//===----------------------------------------------------------------------===//
762// PrivateOp
763//===----------------------------------------------------------------------===//
764LogicalResult acc::PrivateOp::verify() {
765 if (getDataClause() != acc::DataClause::acc_private)
766 return emitError(
767 "data clause associated with private operation must match its intent");
768 if (failed(checkVarAndVarType(*this)))
769 return failure();
770 if (failed(checkNoModifier(*this)))
771 return failure();
772 if (failed(
774 return failure();
775 return success();
776}
777
778//===----------------------------------------------------------------------===//
779// FirstprivateOp
780//===----------------------------------------------------------------------===//
781LogicalResult acc::FirstprivateOp::verify() {
782 if (getDataClause() != acc::DataClause::acc_firstprivate)
783 return emitError("data clause associated with firstprivate operation must "
784 "match its intent");
785 if (failed(checkVarAndVarType(*this)))
786 return failure();
787 if (failed(checkNoModifier(*this)))
788 return failure();
790 *this, "firstprivate")))
791 return failure();
792 return success();
793}
794
795//===----------------------------------------------------------------------===//
796// FirstprivateMapInitialOp
797//===----------------------------------------------------------------------===//
798LogicalResult acc::FirstprivateMapInitialOp::verify() {
799 if (getDataClause() != acc::DataClause::acc_firstprivate)
800 return emitError("data clause associated with firstprivate operation must "
801 "match its intent");
802 if (failed(checkVarAndVarType(*this)))
803 return failure();
804 if (failed(checkNoModifier(*this)))
805 return failure();
806 return success();
807}
808
809//===----------------------------------------------------------------------===//
810// ReductionOp
811//===----------------------------------------------------------------------===//
812LogicalResult acc::ReductionOp::verify() {
813 if (getDataClause() != acc::DataClause::acc_reduction)
814 return emitError("data clause associated with reduction operation must "
815 "match its intent");
816 if (failed(checkVarAndVarType(*this)))
817 return failure();
818 if (failed(checkNoModifier(*this)))
819 return failure();
821 *this, "reduction")))
822 return failure();
823 return success();
824}
825
826//===----------------------------------------------------------------------===//
827// DevicePtrOp
828//===----------------------------------------------------------------------===//
829LogicalResult acc::DevicePtrOp::verify() {
830 if (getDataClause() != acc::DataClause::acc_deviceptr)
831 return emitError("data clause associated with deviceptr operation must "
832 "match its intent");
833 if (failed(checkVarAndVarType(*this)))
834 return failure();
835 if (failed(checkVarAndAccVar(*this)))
836 return failure();
837 if (failed(checkNoModifier(*this)))
838 return failure();
839 return success();
840}
841
842//===----------------------------------------------------------------------===//
843// PresentOp
844//===----------------------------------------------------------------------===//
845LogicalResult acc::PresentOp::verify() {
846 if (getDataClause() != acc::DataClause::acc_present)
847 return emitError(
848 "data clause associated with present operation must match its intent");
849 if (failed(checkVarAndVarType(*this)))
850 return failure();
851 if (failed(checkVarAndAccVar(*this)))
852 return failure();
853 if (failed(checkNoModifier(*this)))
854 return failure();
855 return success();
856}
857
858//===----------------------------------------------------------------------===//
859// CopyinOp
860//===----------------------------------------------------------------------===//
861LogicalResult acc::CopyinOp::verify() {
862 // Test for all clauses this operation can be decomposed from:
863 if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
864 getDataClause() != acc::DataClause::acc_copyin_readonly &&
865 getDataClause() != acc::DataClause::acc_copy &&
866 getDataClause() != acc::DataClause::acc_reduction)
867 return emitError(
868 "data clause associated with copyin operation must match its intent"
869 " or specify original clause this operation was decomposed from");
870 if (failed(checkVarAndVarType(*this)))
871 return failure();
872 if (failed(checkVarAndAccVar(*this)))
873 return failure();
874 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly |
875 acc::DataClauseModifier::always |
876 acc::DataClauseModifier::capture)))
877 return failure();
878 return success();
879}
880
881bool acc::CopyinOp::isCopyinReadonly() {
882 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
883 acc::bitEnumContainsAny(getModifiers(),
884 acc::DataClauseModifier::readonly);
885}
886
887//===----------------------------------------------------------------------===//
888// CreateOp
889//===----------------------------------------------------------------------===//
890LogicalResult acc::CreateOp::verify() {
891 // Test for all clauses this operation can be decomposed from:
892 if (getDataClause() != acc::DataClause::acc_create &&
893 getDataClause() != acc::DataClause::acc_create_zero &&
894 getDataClause() != acc::DataClause::acc_copyout &&
895 getDataClause() != acc::DataClause::acc_copyout_zero)
896 return emitError(
897 "data clause associated with create operation must match its intent"
898 " or specify original clause this operation was decomposed from");
899 if (failed(checkVarAndVarType(*this)))
900 return failure();
901 if (failed(checkVarAndAccVar(*this)))
902 return failure();
903 // this op is the entry part of copyout, so it also needs to allow all
904 // modifiers allowed on copyout.
905 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
906 acc::DataClauseModifier::always |
907 acc::DataClauseModifier::capture)))
908 return failure();
909 return success();
910}
911
912bool acc::CreateOp::isCreateZero() {
913 // The zero modifier is encoded in the data clause.
914 return getDataClause() == acc::DataClause::acc_create_zero ||
915 getDataClause() == acc::DataClause::acc_copyout_zero ||
916 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
917}
918
919//===----------------------------------------------------------------------===//
920// NoCreateOp
921//===----------------------------------------------------------------------===//
922LogicalResult acc::NoCreateOp::verify() {
923 if (getDataClause() != acc::DataClause::acc_no_create)
924 return emitError("data clause associated with no_create operation must "
925 "match its intent");
926 if (failed(checkVarAndVarType(*this)))
927 return failure();
928 if (failed(checkVarAndAccVar(*this)))
929 return failure();
930 if (failed(checkNoModifier(*this)))
931 return failure();
932 return success();
933}
934
935//===----------------------------------------------------------------------===//
936// AttachOp
937//===----------------------------------------------------------------------===//
938LogicalResult acc::AttachOp::verify() {
939 if (getDataClause() != acc::DataClause::acc_attach)
940 return emitError(
941 "data clause associated with attach operation must match its intent");
942 if (failed(checkVarAndVarType(*this)))
943 return failure();
944 if (failed(checkVarAndAccVar(*this)))
945 return failure();
946 if (failed(checkNoModifier(*this)))
947 return failure();
948 return success();
949}
950
951//===----------------------------------------------------------------------===//
952// DeclareDeviceResidentOp
953//===----------------------------------------------------------------------===//
954
955LogicalResult acc::DeclareDeviceResidentOp::verify() {
956 if (getDataClause() != acc::DataClause::acc_declare_device_resident)
957 return emitError("data clause associated with device_resident operation "
958 "must match its intent");
959 if (failed(checkVarAndVarType(*this)))
960 return failure();
961 if (failed(checkVarAndAccVar(*this)))
962 return failure();
963 if (failed(checkNoModifier(*this)))
964 return failure();
965 return success();
966}
967
968//===----------------------------------------------------------------------===//
969// DeclareLinkOp
970//===----------------------------------------------------------------------===//
971
972LogicalResult acc::DeclareLinkOp::verify() {
973 if (getDataClause() != acc::DataClause::acc_declare_link)
974 return emitError(
975 "data clause associated with link operation must match its intent");
976 if (failed(checkVarAndVarType(*this)))
977 return failure();
978 if (failed(checkVarAndAccVar(*this)))
979 return failure();
980 if (failed(checkNoModifier(*this)))
981 return failure();
982 return success();
983}
984
985//===----------------------------------------------------------------------===//
986// CopyoutOp
987//===----------------------------------------------------------------------===//
988LogicalResult acc::CopyoutOp::verify() {
989 // Test for all clauses this operation can be decomposed from:
990 if (getDataClause() != acc::DataClause::acc_copyout &&
991 getDataClause() != acc::DataClause::acc_copyout_zero &&
992 getDataClause() != acc::DataClause::acc_copy &&
993 getDataClause() != acc::DataClause::acc_reduction)
994 return emitError(
995 "data clause associated with copyout operation must match its intent"
996 " or specify original clause this operation was decomposed from");
997 if (!getVar() || !getAccVar())
998 return emitError("must have both host and device pointers");
999 if (failed(checkVarAndVarType(*this)))
1000 return failure();
1001 if (failed(checkVarAndAccVar(*this)))
1002 return failure();
1003 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1004 acc::DataClauseModifier::always |
1005 acc::DataClauseModifier::capture)))
1006 return failure();
1007 return success();
1008}
1009
1010bool acc::CopyoutOp::isCopyoutZero() {
1011 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1012 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1013}
1014
1015//===----------------------------------------------------------------------===//
1016// DeleteOp
1017//===----------------------------------------------------------------------===//
1018LogicalResult acc::DeleteOp::verify() {
1019 // Test for all clauses this operation can be decomposed from:
1020 if (getDataClause() != acc::DataClause::acc_delete &&
1021 getDataClause() != acc::DataClause::acc_create &&
1022 getDataClause() != acc::DataClause::acc_create_zero &&
1023 getDataClause() != acc::DataClause::acc_copyin &&
1024 getDataClause() != acc::DataClause::acc_copyin_readonly &&
1025 getDataClause() != acc::DataClause::acc_present &&
1026 getDataClause() != acc::DataClause::acc_no_create &&
1027 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1028 getDataClause() != acc::DataClause::acc_declare_link)
1029 return emitError(
1030 "data clause associated with delete operation must match its intent"
1031 " or specify original clause this operation was decomposed from");
1032 if (!getAccVar())
1033 return emitError("must have device pointer");
1034 // This op is the exit part of copyin and create - thus allow all modifiers
1035 // allowed on either case.
1036 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1037 acc::DataClauseModifier::readonly |
1038 acc::DataClauseModifier::always |
1039 acc::DataClauseModifier::capture)))
1040 return failure();
1041 return success();
1042}
1043
1044//===----------------------------------------------------------------------===//
1045// DetachOp
1046//===----------------------------------------------------------------------===//
1047LogicalResult acc::DetachOp::verify() {
1048 // Test for all clauses this operation can be decomposed from:
1049 if (getDataClause() != acc::DataClause::acc_detach &&
1050 getDataClause() != acc::DataClause::acc_attach)
1051 return emitError(
1052 "data clause associated with detach operation must match its intent"
1053 " or specify original clause this operation was decomposed from");
1054 if (!getAccVar())
1055 return emitError("must have device pointer");
1056 if (failed(checkNoModifier(*this)))
1057 return failure();
1058 return success();
1059}
1060
1061//===----------------------------------------------------------------------===//
1062// HostOp
1063//===----------------------------------------------------------------------===//
1064LogicalResult acc::UpdateHostOp::verify() {
1065 // Test for all clauses this operation can be decomposed from:
1066 if (getDataClause() != acc::DataClause::acc_update_host &&
1067 getDataClause() != acc::DataClause::acc_update_self)
1068 return emitError(
1069 "data clause associated with host operation must match its intent"
1070 " or specify original clause this operation was decomposed from");
1071 if (!getVar() || !getAccVar())
1072 return emitError("must have both host and device pointers");
1073 if (failed(checkVarAndVarType(*this)))
1074 return failure();
1075 if (failed(checkVarAndAccVar(*this)))
1076 return failure();
1077 if (failed(checkNoModifier(*this)))
1078 return failure();
1079 return success();
1080}
1081
1082//===----------------------------------------------------------------------===//
1083// DeviceOp
1084//===----------------------------------------------------------------------===//
1085LogicalResult acc::UpdateDeviceOp::verify() {
1086 // Test for all clauses this operation can be decomposed from:
1087 if (getDataClause() != acc::DataClause::acc_update_device)
1088 return emitError(
1089 "data clause associated with device operation must match its intent"
1090 " or specify original clause this operation was decomposed from");
1091 if (failed(checkVarAndVarType(*this)))
1092 return failure();
1093 if (failed(checkVarAndAccVar(*this)))
1094 return failure();
1095 if (failed(checkNoModifier(*this)))
1096 return failure();
1097 return success();
1098}
1099
1100//===----------------------------------------------------------------------===//
1101// UseDeviceOp
1102//===----------------------------------------------------------------------===//
1103LogicalResult acc::UseDeviceOp::verify() {
1104 // Test for all clauses this operation can be decomposed from:
1105 if (getDataClause() != acc::DataClause::acc_use_device)
1106 return emitError(
1107 "data clause associated with use_device operation must match its intent"
1108 " or specify original clause this operation was decomposed from");
1109 if (failed(checkVarAndVarType(*this)))
1110 return failure();
1111 if (failed(checkVarAndAccVar(*this)))
1112 return failure();
1113 if (failed(checkNoModifier(*this)))
1114 return failure();
1115 return success();
1116}
1117
1118//===----------------------------------------------------------------------===//
1119// CacheOp
1120//===----------------------------------------------------------------------===//
1121LogicalResult acc::CacheOp::verify() {
1122 // Test for all clauses this operation can be decomposed from:
1123 if (getDataClause() != acc::DataClause::acc_cache &&
1124 getDataClause() != acc::DataClause::acc_cache_readonly)
1125 return emitError(
1126 "data clause associated with cache operation must match its intent"
1127 " or specify original clause this operation was decomposed from");
1128 if (failed(checkVarAndVarType(*this)))
1129 return failure();
1130 if (failed(checkVarAndAccVar(*this)))
1131 return failure();
1132 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly)))
1133 return failure();
1134 return success();
1135}
1136
1137bool acc::CacheOp::isCacheReadonly() {
1138 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1139 acc::bitEnumContainsAny(getModifiers(),
1140 acc::DataClauseModifier::readonly);
1141}
1142
1143template <typename StructureOp>
1144static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
1145 unsigned nRegions = 1) {
1146
1148 for (unsigned i = 0; i < nRegions; ++i)
1149 regions.push_back(state.addRegion());
1150
1151 for (Region *region : regions)
1152 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
1153 return failure();
1154
1155 return success();
1156}
1157
1159 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
1160}
1161
1162namespace {
1163/// Pattern to remove operation without region that have constant false `ifCond`
1164/// and remove the condition from the operation if the `ifCond` is a true
1165/// constant.
1166template <typename OpTy>
1167struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
1168 using OpRewritePattern<OpTy>::OpRewritePattern;
1169
1170 LogicalResult matchAndRewrite(OpTy op,
1171 PatternRewriter &rewriter) const override {
1172 // Early return if there is no condition.
1173 Value ifCond = op.getIfCond();
1174 if (!ifCond)
1175 return failure();
1176
1177 IntegerAttr constAttr;
1178 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1179 return failure();
1180 if (constAttr.getInt())
1181 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1182 else
1183 rewriter.eraseOp(op);
1184
1185 return success();
1186 }
1187};
1188
1189/// Replaces the given op with the contents of the given single-block region,
1190/// using the operands of the block terminator to replace operation results.
1191static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
1192 Region &region, ValueRange blockArgs = {}) {
1193 assert(region.hasOneBlock() && "expected single-block region");
1194 Block *block = &region.front();
1195 Operation *terminator = block->getTerminator();
1196 ValueRange results = terminator->getOperands();
1197 rewriter.inlineBlockBefore(block, op, blockArgs);
1198 rewriter.replaceOp(op, results);
1199 rewriter.eraseOp(terminator);
1200}
1201
1202/// Pattern to remove operation with region that have constant false `ifCond`
1203/// and remove the condition from the operation if the `ifCond` is constant
1204/// true.
1205template <typename OpTy>
1206struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
1207 using OpRewritePattern<OpTy>::OpRewritePattern;
1208
1209 LogicalResult matchAndRewrite(OpTy op,
1210 PatternRewriter &rewriter) const override {
1211 // Early return if there is no condition.
1212 Value ifCond = op.getIfCond();
1213 if (!ifCond)
1214 return failure();
1215
1216 IntegerAttr constAttr;
1217 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1218 return failure();
1219 if (constAttr.getInt())
1220 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1221 else
1222 replaceOpWithRegion(rewriter, op, op.getRegion());
1223
1224 return success();
1225 }
1226};
1227
1228/// Remove empty acc.kernel_environment operations. If the operation has wait
1229/// operands, create a acc.wait operation to preserve synchronization.
1230struct RemoveEmptyKernelEnvironment
1231 : public OpRewritePattern<acc::KernelEnvironmentOp> {
1232 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
1233
1234 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
1235 PatternRewriter &rewriter) const override {
1236 assert(op->getNumRegions() == 1 && "expected op to have one region");
1237
1238 Block &block = op.getRegion().front();
1239 if (!block.empty())
1240 return failure();
1241
1242 // Conservatively disable canonicalization of empty acc.kernel_environment
1243 // operations if the wait operands in the kernel_environment cannot be fully
1244 // represented by acc.wait operation.
1245
1246 // Disable canonicalization if device type is not the default
1247 if (auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
1248 for (auto attr : deviceTypeAttr) {
1249 if (auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
1250 if (dtAttr.getValue() != mlir::acc::DeviceType::None)
1251 return failure();
1252 }
1253 }
1254 }
1255
1256 // Disable canonicalization if any wait segment has a devnum
1257 if (auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
1258 for (auto attr : hasDevnumAttr) {
1259 if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
1260 if (boolAttr.getValue())
1261 return failure();
1262 }
1263 }
1264 }
1265
1266 // Disable canonicalization if there are multiple wait segments
1267 if (auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
1268 if (segmentsAttr.size() > 1)
1269 return failure();
1270 }
1271
1272 // Remove empty kernel environment.
1273 // Preserve synchronization by creating acc.wait operation if needed.
1274 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
1275 rewriter.replaceOpWithNewOp<acc::WaitOp>(op, op.getWaitOperands(),
1276 /*asyncOperand=*/Value(),
1277 /*waitDevnum=*/Value(),
1278 /*async=*/nullptr,
1279 /*ifCond=*/Value());
1280 else
1281 rewriter.eraseOp(op);
1282
1283 return success();
1284 }
1285};
1286
1287//===----------------------------------------------------------------------===//
1288// Recipe Region Helpers
1289//===----------------------------------------------------------------------===//
1290
1291/// Create and populate an init region for privatization recipes.
1292/// Returns success if the region is populated, failure otherwise.
1293/// Sets needsFree to indicate if the allocated memory requires deallocation.
1294static LogicalResult createInitRegion(OpBuilder &builder, Location loc,
1295 Region &initRegion, Type varType,
1296 StringRef varName, ValueRange bounds,
1297 bool &needsFree) {
1298 // Create init block with arguments: original value + bounds
1299 SmallVector<Type> argTypes{varType};
1300 SmallVector<Location> argLocs{loc};
1301 for (Value bound : bounds) {
1302 argTypes.push_back(bound.getType());
1303 argLocs.push_back(loc);
1304 }
1305
1306 Block *initBlock = builder.createBlock(&initRegion);
1307 initBlock->addArguments(argTypes, argLocs);
1308 builder.setInsertionPointToStart(initBlock);
1309
1310 Value privatizedValue;
1311
1312 // Get the block argument that represents the original variable
1313 Value blockArgVar = initBlock->getArgument(0);
1314
1315 // Generate init region body based on variable type
1316 if (isa<MappableType>(varType)) {
1317 auto mappableTy = cast<MappableType>(varType);
1318 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1319 privatizedValue = mappableTy.generatePrivateInit(
1320 builder, loc, typedVar, varName, bounds, {}, needsFree);
1321 if (!privatizedValue)
1322 return failure();
1323 } else {
1324 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1325 auto pointerLikeTy = cast<PointerLikeType>(varType);
1326 // Use PointerLikeType's allocation API with the block argument
1327 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1328 blockArgVar, needsFree);
1329 if (!privatizedValue)
1330 return failure();
1331 }
1332
1333 // Add yield operation to init block
1334 acc::YieldOp::create(builder, loc, privatizedValue);
1335
1336 return success();
1337}
1338
1339/// Create and populate a copy region for firstprivate recipes.
1340/// Returns success if the region is populated, failure otherwise.
1341/// TODO: Handle MappableType - it does not yet have a copy API.
1342static LogicalResult createCopyRegion(OpBuilder &builder, Location loc,
1343 Region &copyRegion, Type varType,
1344 ValueRange bounds) {
1345 // Create copy block with arguments: original value + privatized value +
1346 // bounds
1347 SmallVector<Type> copyArgTypes{varType, varType};
1348 SmallVector<Location> copyArgLocs{loc, loc};
1349 for (Value bound : bounds) {
1350 copyArgTypes.push_back(bound.getType());
1351 copyArgLocs.push_back(loc);
1352 }
1353
1354 Block *copyBlock = builder.createBlock(&copyRegion);
1355 copyBlock->addArguments(copyArgTypes, copyArgLocs);
1356 builder.setInsertionPointToStart(copyBlock);
1357
1358 bool isMappable = isa<MappableType>(varType);
1359 bool isPointerLike = isa<PointerLikeType>(varType);
1360 // TODO: Handle MappableType - it does not yet have a copy API.
1361 // Otherwise, for now just fallback to pointer-like behavior.
1362 if (isMappable && !isPointerLike)
1363 return failure();
1364
1365 // Generate copy region body based on variable type
1366 if (isPointerLike) {
1367 auto pointerLikeTy = cast<PointerLikeType>(varType);
1368 Value originalArg = copyBlock->getArgument(0);
1369 Value privatizedArg = copyBlock->getArgument(1);
1370
1371 // Generate copy operation using PointerLikeType interface
1372 if (!pointerLikeTy.genCopy(
1373 builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
1374 cast<TypedValue<PointerLikeType>>(originalArg), varType))
1375 return failure();
1376 }
1377
1378 // Add terminator to copy block
1379 acc::TerminatorOp::create(builder, loc);
1380
1381 return success();
1382}
1383
1384/// Create and populate a destroy region for privatization recipes.
1385/// Returns success if the region is populated, failure otherwise.
1386static LogicalResult createDestroyRegion(OpBuilder &builder, Location loc,
1387 Region &destroyRegion, Type varType,
1388 Value allocRes, ValueRange bounds) {
1389 // Create destroy block with arguments: original value + privatized value +
1390 // bounds
1391 SmallVector<Type> destroyArgTypes{varType, varType};
1392 SmallVector<Location> destroyArgLocs{loc, loc};
1393 for (Value bound : bounds) {
1394 destroyArgTypes.push_back(bound.getType());
1395 destroyArgLocs.push_back(loc);
1396 }
1397
1398 Block *destroyBlock = builder.createBlock(&destroyRegion);
1399 destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
1400 builder.setInsertionPointToStart(destroyBlock);
1401
1402 auto varToFree =
1403 cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
1404 if (isa<MappableType>(varType)) {
1405 auto mappableTy = cast<MappableType>(varType);
1406 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree))
1407 return failure();
1408 } else {
1409 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1410 auto pointerLikeTy = cast<PointerLikeType>(varType);
1411 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1412 return failure();
1413 }
1414
1415 acc::TerminatorOp::create(builder, loc);
1416 return success();
1417}
1418
1419} // namespace
1420
1421//===----------------------------------------------------------------------===//
1422// PrivateRecipeOp
1423//===----------------------------------------------------------------------===//
1424
1426 Operation *op, Region &region, StringRef regionType, StringRef regionName,
1427 Type type, bool verifyYield, bool optional = false) {
1428 if (optional && region.empty())
1429 return success();
1430
1431 if (region.empty())
1432 return op->emitOpError() << "expects non-empty " << regionName << " region";
1433 Block &firstBlock = region.front();
1434 if (firstBlock.getNumArguments() < 1 ||
1435 firstBlock.getArgument(0).getType() != type)
1436 return op->emitOpError() << "expects " << regionName
1437 << " region first "
1438 "argument of the "
1439 << regionType << " type";
1440
1441 if (verifyYield) {
1442 for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
1443 if (yieldOp.getOperands().size() != 1 ||
1444 yieldOp.getOperands().getTypes()[0] != type)
1445 return op->emitOpError() << "expects " << regionName
1446 << " region to "
1447 "yield a value of the "
1448 << regionType << " type";
1449 }
1450 }
1451 return success();
1452}
1453
1454LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1455 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1456 "privatization", "init", getType(),
1457 /*verifyYield=*/false)))
1458 return failure();
1460 *this, getDestroyRegion(), "privatization", "destroy", getType(),
1461 /*verifyYield=*/false, /*optional=*/true)))
1462 return failure();
1463 return success();
1464}
1465
1466std::optional<PrivateRecipeOp>
1467PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1468 StringRef recipeName, Type varType,
1469 StringRef varName, ValueRange bounds) {
1470 // First, validate that we can handle this variable type
1471 bool isMappable = isa<MappableType>(varType);
1472 bool isPointerLike = isa<PointerLikeType>(varType);
1473
1474 // Unsupported type
1475 if (!isMappable && !isPointerLike)
1476 return std::nullopt;
1477
1478 OpBuilder::InsertionGuard guard(builder);
1479
1480 // Create the recipe operation first so regions have proper parent context
1481 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1482
1483 // Populate the init region
1484 bool needsFree = false;
1485 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1486 varName, bounds, needsFree))) {
1487 recipe.erase();
1488 return std::nullopt;
1489 }
1490
1491 // Only create destroy region if the allocation needs deallocation
1492 if (needsFree) {
1493 // Extract the allocated value from the init block's yield operation
1494 auto yieldOp =
1495 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1496 Value allocRes = yieldOp.getOperand(0);
1497
1498 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1499 varType, allocRes, bounds))) {
1500 recipe.erase();
1501 return std::nullopt;
1502 }
1503 }
1504
1505 return recipe;
1506}
1507
1508std::optional<PrivateRecipeOp>
1509PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1510 StringRef recipeName,
1511 FirstprivateRecipeOp firstprivRecipe) {
1512 // Create the private.recipe op with the same type as the firstprivate.recipe.
1513 OpBuilder::InsertionGuard guard(builder);
1514 auto varType = firstprivRecipe.getType();
1515 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1516
1517 // Clone the init region
1518 IRMapping mapping;
1519 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1520
1521 // Clone destroy region if the firstprivate.recipe has one.
1522 if (!firstprivRecipe.getDestroyRegion().empty()) {
1523 IRMapping mapping;
1524 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1525 mapping);
1526 }
1527 return recipe;
1528}
1529
1530//===----------------------------------------------------------------------===//
1531// FirstprivateRecipeOp
1532//===----------------------------------------------------------------------===//
1533
1534LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1535 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1536 "privatization", "init", getType(),
1537 /*verifyYield=*/false)))
1538 return failure();
1539
1540 if (getCopyRegion().empty())
1541 return emitOpError() << "expects non-empty copy region";
1542
1543 Block &firstBlock = getCopyRegion().front();
1544 if (firstBlock.getNumArguments() < 2 ||
1545 firstBlock.getArgument(0).getType() != getType())
1546 return emitOpError() << "expects copy region with two arguments of the "
1547 "privatization type";
1548
1549 if (getDestroyRegion().empty())
1550 return success();
1551
1552 if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
1553 "privatization", "destroy",
1554 getType(), /*verifyYield=*/false)))
1555 return failure();
1556
1557 return success();
1558}
1559
1560std::optional<FirstprivateRecipeOp>
1561FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1562 StringRef recipeName, Type varType,
1563 StringRef varName, ValueRange bounds) {
1564 // First, validate that we can handle this variable type
1565 bool isMappable = isa<MappableType>(varType);
1566 bool isPointerLike = isa<PointerLikeType>(varType);
1567
1568 // Unsupported type
1569 if (!isMappable && !isPointerLike)
1570 return std::nullopt;
1571
1572 OpBuilder::InsertionGuard guard(builder);
1573
1574 // Create the recipe operation first so regions have proper parent context
1575 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1576
1577 // Populate the init region
1578 bool needsFree = false;
1579 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1580 varName, bounds, needsFree))) {
1581 recipe.erase();
1582 return std::nullopt;
1583 }
1584
1585 // Populate the copy region
1586 if (failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1587 bounds))) {
1588 recipe.erase();
1589 return std::nullopt;
1590 }
1591
1592 // Only create destroy region if the allocation needs deallocation
1593 if (needsFree) {
1594 // Extract the allocated value from the init block's yield operation
1595 auto yieldOp =
1596 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1597 Value allocRes = yieldOp.getOperand(0);
1598
1599 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1600 varType, allocRes, bounds))) {
1601 recipe.erase();
1602 return std::nullopt;
1603 }
1604 }
1605
1606 return recipe;
1607}
1608
1609//===----------------------------------------------------------------------===//
1610// ReductionRecipeOp
1611//===----------------------------------------------------------------------===//
1612
1613LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1614 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
1615 "init", getType(),
1616 /*verifyYield=*/false)))
1617 return failure();
1618
1619 if (getCombinerRegion().empty())
1620 return emitOpError() << "expects non-empty combiner region";
1621
1622 Block &reductionBlock = getCombinerRegion().front();
1623 if (reductionBlock.getNumArguments() < 2 ||
1624 reductionBlock.getArgument(0).getType() != getType() ||
1625 reductionBlock.getArgument(1).getType() != getType())
1626 return emitOpError() << "expects combiner region with the first two "
1627 << "arguments of the reduction type";
1628
1629 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1630 if (yieldOp.getOperands().size() != 1 ||
1631 yieldOp.getOperands().getTypes()[0] != getType())
1632 return emitOpError() << "expects combiner region to yield a value "
1633 "of the reduction type";
1634 }
1635
1636 return success();
1637}
1638
1639//===----------------------------------------------------------------------===//
1640// ParallelOp
1641//===----------------------------------------------------------------------===//
1642
1643/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
1644template <typename Op>
1645static LogicalResult checkDataOperands(Op op,
1646 const mlir::ValueRange &operands) {
1647 for (mlir::Value operand : operands)
1648 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1649 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1650 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1651 operand.getDefiningOp()))
1652 return op.emitError(
1653 "expect data entry/exit operation or acc.getdeviceptr "
1654 "as defining op");
1655 return success();
1656}
1657
1658template <typename OpT, typename RecipeOpT>
1659static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp,
1660 const mlir::ValueRange &operands,
1661 llvm::StringRef operandName) {
1663 for (mlir::Value operand : operands) {
1664 if (!mlir::isa<OpT>(operand.getDefiningOp()))
1665 return accConstructOp->emitOpError()
1666 << "expected " << operandName << " as defining op";
1667 if (!set.insert(operand).second)
1668 return accConstructOp->emitOpError()
1669 << operandName << " operand appears more than once";
1670 }
1671 return success();
1672}
1673
1674unsigned ParallelOp::getNumDataOperands() {
1675 return getReductionOperands().size() + getPrivateOperands().size() +
1676 getFirstprivateOperands().size() + getDataClauseOperands().size();
1677}
1678
1679Value ParallelOp::getDataOperand(unsigned i) {
1680 unsigned numOptional = getAsyncOperands().size();
1681 numOptional += getNumGangs().size();
1682 numOptional += getNumWorkers().size();
1683 numOptional += getVectorLength().size();
1684 numOptional += getIfCond() ? 1 : 0;
1685 numOptional += getSelfCond() ? 1 : 0;
1686 return getOperand(getWaitOperands().size() + numOptional + i);
1687}
1688
1689template <typename Op>
1690static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
1691 ArrayAttr deviceTypes,
1692 llvm::StringRef keyword) {
1693 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1694 return op.emitOpError() << keyword << " operands count must match "
1695 << keyword << " device_type count";
1696 return success();
1697}
1698
1699template <typename Op>
1701 Op op, OperandRange operands, DenseI32ArrayAttr segments,
1702 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1703 std::size_t numOperandsInSegments = 0;
1704 std::size_t nbOfSegments = 0;
1705
1706 if (segments) {
1707 for (auto segCount : segments.asArrayRef()) {
1708 if (maxInSegment != 0 && segCount > maxInSegment)
1709 return op.emitOpError() << keyword << " expects a maximum of "
1710 << maxInSegment << " values per segment";
1711 numOperandsInSegments += segCount;
1712 ++nbOfSegments;
1713 }
1714 }
1715
1716 if ((numOperandsInSegments != operands.size()) ||
1717 (!deviceTypes && !operands.empty()))
1718 return op.emitOpError()
1719 << keyword << " operand count does not match count in segments";
1720 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1721 return op.emitOpError()
1722 << keyword << " segment count does not match device_type count";
1723 return success();
1724}
1725
/// Verifier for `acc::ParallelOp`.
///
/// Runs a sequence of checks and returns failure as soon as one of them fails:
/// the private, firstprivate and reduction operand lists (through
/// `checkPrivateOperands`), the "num_gangs" and "wait" lists, the device-type
/// counts of "num_workers" and "vector_length", an "async" check, and finally
/// the data clause operands (through `checkDataOperands`).
///
/// NOTE(review): this listing has gaps (listing lines 1740, 1745, 1760 and
/// 1765 are absent), so the opening of four `if (failed(...))` checks is not
/// visible. Each gap is marked below.
1726LogicalResult acc::ParallelOp::verify() {
1727 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
1728 mlir::acc::PrivateRecipeOp>(
1729 *this, getPrivateOperands(), "private")))
1730 return failure();
1731 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
1732 mlir::acc::FirstprivateRecipeOp>(
1733 *this, getFirstprivateOperands(), "firstprivate")))
1734 return failure();
1735 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
1736 mlir::acc::ReductionRecipeOp>(
1737 *this, getReductionOperands(), "reduction")))
1738 return failure();
1739
// NOTE(review): call head missing here. The trailing 3 is presumably the
// maximum number of values allowed per num_gangs segment — confirm.
1741 *this, getNumGangs(), getNumGangsSegmentsAttr(),
1742 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1743 return failure();
1744
// NOTE(review): call head missing here; same argument shape as the num_gangs
// check above but without a per-segment maximum.
1746 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1747 getWaitOperandsDeviceTypeAttr(), "wait")))
1748 return failure();
1749
1750 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1751 getNumWorkersDeviceTypeAttr(),
1752 "num_workers")))
1753 return failure();
1754
1755 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1756 getVectorLengthDeviceTypeAttr(),
1757 "vector_length")))
1758 return failure();
1759
// NOTE(review): call head missing here; presumably another
// verifyDeviceTypeCountMatch call on the async operands — confirm.
1761 getAsyncOperandsDeviceTypeAttr(),
1762 "async")))
1763 return failure();
1764
// NOTE(review): the condition guarding this early return is missing from the
// listing.
1766 return failure();
1767
1768 return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
1769}
1770
/// Looks up the value associated with `deviceType`.
///
/// Returns a null `mlir::Value` when `arrayAttr` is not set or when
/// `findSegment` reports no position for `deviceType`; otherwise returns
/// `range[*pos]` for the position found.
///
/// NOTE(review): the parameter line declaring `range` (listing line 1773) is
/// missing from this listing; callers in this file pass an operand range as
/// the second argument.
1771static mlir::Value
1772getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
1774 mlir::acc::DeviceType deviceType) {
1775 if (!arrayAttr)
1776 return {};
1777 if (auto pos = findSegment(*arrayAttr, deviceType))
1778 return range[*pos];
1779 return {};
1780}
1781
1782bool acc::ParallelOp::hasAsyncOnly() {
1783 return hasAsyncOnly(mlir::acc::DeviceType::None);
1784}
1785
1786bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1787 return hasDeviceType(getAsyncOnly(), deviceType);
1788}
1789
/// Convenience overload: forwards to the device-type overload with
/// `DeviceType::None`.
1790mlir::Value acc::ParallelOp::getAsyncValue() {
1791 return getAsyncValue(mlir::acc::DeviceType::None);
1792}
1793
/// Returns the async value for `deviceType`.
///
/// NOTE(review): the start of the return statement (listing line 1795) is
/// missing; only its trailing arguments, the async operands and `deviceType`,
/// are visible. Presumably it calls getValueInDeviceTypeSegment like the
/// sibling accessors — confirm.
1794mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1796 getAsyncOperands(), deviceType);
1797}
1798
/// Convenience overload: forwards to the device-type overload with
/// `DeviceType::None`.
1799mlir::Value acc::ParallelOp::getNumWorkersValue() {
1800 return getNumWorkersValue(mlir::acc::DeviceType::None);
1801}
1802
/// Returns the num_workers operand selected for `deviceType` by
/// getValueInDeviceTypeSegment.
///
/// NOTE(review): the return-type line (listing line 1803) is missing from this
/// listing; presumably `mlir::Value`, matching the overload above.
1804acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1805 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1806 deviceType);
1807}
1808
/// Convenience overload: forwards to the device-type overload with
/// `DeviceType::None`.
1809mlir::Value acc::ParallelOp::getVectorLengthValue() {
1810 return getVectorLengthValue(mlir::acc::DeviceType::None);
1811}
1812
/// Returns the vector_length operand selected for `deviceType` by
/// getValueInDeviceTypeSegment.
///
/// NOTE(review): the return-type line (listing line 1813) is missing from this
/// listing; presumably `mlir::Value`, matching the overload above.
1814acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1815 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1816 getVectorLength(), deviceType);
1817}
1818
/// Convenience overload: forwards to the device-type overload with
/// `DeviceType::None`.
1819mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
1820 return getNumGangsValues(mlir::acc::DeviceType::None);
1821}
1822
/// Returns what getValuesFromSegments yields for the num_gangs device types,
/// operands and segments, given `deviceType`.
///
/// NOTE(review): the return-type line (listing line 1823) is missing from this
/// listing; presumably `mlir::Operation::operand_range`, matching the overload
/// above.
1824ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1825 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1826 getNumGangsSegments(), deviceType);
1827}
1828
1829bool acc::ParallelOp::hasWaitOnly() {
1830 return hasWaitOnly(mlir::acc::DeviceType::None);
1831}
1832
1833bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1834 return hasDeviceType(getWaitOnly(), deviceType);
1835}
1836
/// Convenience overload: forwards to the device-type overload with
/// `DeviceType::None`.
1837mlir::Operation::operand_range ParallelOp::getWaitValues() {
1838 return getWaitValues(mlir::acc::DeviceType::None);
1839}
1840
/// Returns the wait values for `deviceType`.
///
/// NOTE(review): both the return-type line (listing line 1841) and the start
/// of the return statement (listing line 1843) are missing; only the forwarded
/// arguments (wait device types, operands, segments, has-devnum flags and
/// `deviceType`) are visible. The helper being called cannot be determined
/// from here.
1842ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1844 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1845 getHasWaitDevnum(), deviceType);
1846}
1847
1848mlir::Value ParallelOp::getWaitDevnum() {
1849 return getWaitDevnum(mlir::acc::DeviceType::None);
1850}
1851
1852mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1853 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1854 getWaitOperandsSegments(), getHasWaitDevnum(),
1855 deviceType);
1856}
1857
/// Builder taking only operand ranges and the two optional conditions.
///
/// Forwards to the longer `ParallelOp::build` overload, passing `nullptr` for
/// every attribute argument (each one is labelled inline). Note that the
/// forwarded operand order (async, wait, numGangs, numWorkers, vectorLength,
/// ifCond, selfCond, reduction, private, firstprivate, data clause) differs
/// from this function's own parameter order, so the positional arguments must
/// not be reordered.
1858void ParallelOp::build(mlir::OpBuilder &odsBuilder,
1859 mlir::OperationState &odsState,
1860 mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
1861 mlir::ValueRange vectorLength,
1862 mlir::ValueRange asyncOperands,
1863 mlir::ValueRange waitOperands, mlir::Value ifCond,
1864 mlir::Value selfCond, mlir::ValueRange reductionOperands,
1865 mlir::ValueRange gangPrivateOperands,
1866 mlir::ValueRange gangFirstPrivateOperands,
1867 mlir::ValueRange dataClauseOperands) {
1868 ParallelOp::build(
1869 odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
1870 /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
1871 /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
1872 /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
1873 /*numGangsDeviceType=*/nullptr, numWorkers,
1874 /*numWorkersDeviceType=*/nullptr, vectorLength,
1875 /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
1876 /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands,
1877 gangFirstPrivateOperands, dataClauseOperands,
1878 /*defaultAttr=*/nullptr, /*combined=*/nullptr);
1879}
1880
1881void acc::ParallelOp::addNumWorkersOperand(
1882 MLIRContext *context, mlir::Value newValue,
1883 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1884 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1885 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1886 getNumWorkersMutable()));
1887}
1888void acc::ParallelOp::addVectorLengthOperand(
1889 MLIRContext *context, mlir::Value newValue,
1890 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1891 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1892 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1893 getVectorLengthMutable()));
1894}
1895
1896void acc::ParallelOp::addAsyncOnly(
1897 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1898 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1899 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1900}
1901
1902void acc::ParallelOp::addAsyncOperand(
1903 MLIRContext *context, mlir::Value newValue,
1904 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1905 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1906 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1907 getAsyncOperandsMutable()));
1908}
1909
/// Adds `newValues` as num_gangs operands for `effectiveDeviceTypes`.
///
/// Copies the existing num_gangs segments (if any) into a local `segments`
/// container, lets `addDeviceTypeAffectedOperandHelper` update the operands,
/// device types and `segments`, then stores both the new device-type attribute
/// and the segments back on the op.
1910void acc::ParallelOp::addNumGangsOperands(
1911 MLIRContext *context, mlir::ValueRange newValues,
1912 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
// NOTE(review): the declaration of the local `segments` (listing line 1913) is
// missing from this listing.
1914 if (getNumGangsSegments())
1915 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1916
1917 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1918 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1919 getNumGangsMutable(), segments));
1920
1921 setNumGangsSegments(segments);
1922}
1923void acc::ParallelOp::addWaitOnly(
1924 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1925 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1926 effectiveDeviceTypes));
1927}
/// Adds `newValues` as wait operands for `effectiveDeviceTypes`.
///
/// Updates three pieces of state: the wait operands together with their
/// device-type attribute and segments (through
/// `addDeviceTypeAffectedOperandHelper`), and the has-devnum flag list, which
/// is extended with `hasDevnum`.
1928void acc::ParallelOp::addWaitOperands(
1929 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
1930 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1931
// NOTE(review): the declaration of the local `segments` (listing line 1932) is
// missing from this listing.
1933 if (getWaitOperandsSegments())
1934 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1935
1936 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1937 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1938 getWaitOperandsMutable(), segments));
1939 setWaitOperandsSegments(segments);
1940
// NOTE(review): the declaration of the local `hasDevnums` (listing line 1941)
// is missing from this listing.
1942 if (getHasWaitDevnumAttr())
1943 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
// Append one flag per effective device type, or a single flag when the list
// of device types is empty.
1944 hasDevnums.insert(
1945 hasDevnums.end(),
1946 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
1947 mlir::BoolAttr::get(context, hasDevnum));
1948 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
1949}
1950
1951void acc::ParallelOp::addPrivatization(MLIRContext *context,
1952 mlir::acc::PrivateOp op,
1953 mlir::acc::PrivateRecipeOp recipe) {
1954 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
1955 getPrivateOperandsMutable().append(op.getResult());
1956}
1957
1958void acc::ParallelOp::addFirstPrivatization(
1959 MLIRContext *context, mlir::acc::FirstprivateOp op,
1960 mlir::acc::FirstprivateRecipeOp recipe) {
1961 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
1962 getFirstprivateOperandsMutable().append(op.getResult());
1963}
1964
1965void acc::ParallelOp::addReduction(MLIRContext *context,
1966 mlir::acc::ReductionOp op,
1967 mlir::acc::ReductionRecipeOp recipe) {
1968 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
1969 getReductionOperandsMutable().append(op.getResult());
1970}
1971
/// Custom parser for a comma-separated list of brace-delimited operand
/// groups.
///
/// Each group has the form `{ operand : type, ... }` and may be followed by
/// `[ attribute ]`; when the bracket is absent a `DeviceType::None` attribute
/// is recorded for the group. On success `deviceTypes` holds one attribute per
/// group and `segments` holds the number of operands parsed in each group.
///
/// NOTE(review): this listing has gaps — the `operands` parameter line
/// (listing line 1974), the declarations of the locals `attributes` and `seg`
/// (listing lines 1977-1978) and the head of the parseCommaSeparatedList call
/// (listing line 1986) are missing.
1972static ParseResult parseNumGangs(
1973 mlir::OpAsmParser &parser,
1975 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1976 mlir::DenseI32ArrayAttr &segments) {
1979
1980 do {
1981 if (failed(parser.parseLBrace()))
1982 return failure();
1983
// Remember how many operands were parsed before this group so the group size
// can be computed afterwards.
1984 int32_t crtOperandsSize = operands.size();
1985 if (failed(parser.parseCommaSeparatedList(
1987 if (parser.parseOperand(operands.emplace_back()) ||
1988 parser.parseColonType(types.emplace_back()))
1989 return failure();
1990 return success();
1991 })))
1992 return failure();
1993 seg.push_back(operands.size() - crtOperandsSize);
1994
1995 if (failed(parser.parseRBrace()))
1996 return failure();
1997
// Optional `[attr]` suffix; default to DeviceType::None when absent.
1998 if (succeeded(parser.parseOptionalLSquare())) {
1999 if (parser.parseAttribute(attributes.emplace_back()) ||
2000 parser.parseRSquare())
2001 return failure();
2002 } else {
2003 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2004 parser.getContext(), mlir::acc::DeviceType::None));
2005 }
2006 } while (succeeded(parser.parseOptionalComma()));
2007
2008 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2009 attributes.end());
2010 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2011 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2012
2013 return success();
2014}
2015
// Prints ` [attr]` when `attr` is a device-type attribute whose value is not
// `DeviceType::None`; prints nothing otherwise.
//
// NOTE(review): the signature line (listing line 2016) is missing from this
// listing; `p` and `attr` are its parameters. The printers further down in
// this file call this helper as `printSingleDeviceType(p, attr)`.
// NOTE(review): the `dyn_cast` result is used without a null check, so `attr`
// must always be a `DeviceTypeAttr` here — confirm, or use `cast`.
2017 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2018 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2019 p << " [" << attr << "]";
2020}
2021
// Printer for brace-delimited operand groups: for every entry of
// `deviceTypes` prints `{operand : type, ...}` using the matching entry of
// `segments` as the group size, followed by the device-type suffix emitted by
// printSingleDeviceType. Groups are separated by commas.
//
// NOTE(review): the first signature line (listing line 2022), which carries
// the function name and the `p` parameter, is missing from this listing.
// NOTE(review): `deviceTypes` and `segments` are dereferenced without checking
// that they hold a value; callers presumably guarantee both are set — confirm.
2023 mlir::OperandRange operands, mlir::TypeRange types,
2024 std::optional<mlir::ArrayAttr> deviceTypes,
2025 std::optional<mlir::DenseI32ArrayAttr> segments) {
// Running index into `operands`, shared across all groups.
2026 unsigned opIdx = 0;
2027 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2028 p << "{";
2029 llvm::interleaveComma(
2030 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2031 p << operands[opIdx] << " : " << operands[opIdx].getType();
2032 ++opIdx;
2033 });
2034 p << "}";
2035 printSingleDeviceType(p, it.value());
2036 });
2037}
2038
// Custom parser for a comma-separated list of brace-delimited operand groups.
//
// Each group has the form `{ operand : type, ... }` and may be followed by
// `[ attribute ]`; when the bracket is absent a `DeviceType::None` attribute
// is recorded for the group. On success `deviceTypes` holds one attribute per
// group and `segments` holds the number of operands parsed in each group.
//
// NOTE(review): this listing has gaps — the line with the function name
// (listing line 2039), the `operands` parameter line (2041), the declarations
// of the locals `attributes` and `seg` (2044-2045) and the head of the
// parseCommaSeparatedList call (2054) are missing.
2040 mlir::OpAsmParser &parser,
2042 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2043 mlir::DenseI32ArrayAttr &segments) {
2046
2047 do {
2048 if (failed(parser.parseLBrace()))
2049 return failure();
2050
// Remember how many operands were parsed before this group so the group size
// can be computed afterwards.
2051 int32_t crtOperandsSize = operands.size();
2052
2053 if (failed(parser.parseCommaSeparatedList(
2055 if (parser.parseOperand(operands.emplace_back()) ||
2056 parser.parseColonType(types.emplace_back()))
2057 return failure();
2058 return success();
2059 })))
2060 return failure();
2061
2062 seg.push_back(operands.size() - crtOperandsSize);
2063
2064 if (failed(parser.parseRBrace()))
2065 return failure();
2066
// Optional `[attr]` suffix; default to DeviceType::None when absent.
2067 if (succeeded(parser.parseOptionalLSquare())) {
2068 if (parser.parseAttribute(attributes.emplace_back()) ||
2069 parser.parseRSquare())
2070 return failure();
2071 } else {
2072 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2073 parser.getContext(), mlir::acc::DeviceType::None));
2074 }
2075 } while (succeeded(parser.parseOptionalComma()));
2076
2077 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2078 attributes.end());
2079 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2080 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2081
2082 return success();
2083}
2084
// Printer for brace-delimited operand groups: for every entry of
// `deviceTypes` prints `{operand : type, ...}` using the matching entry of
// `segments` as the group size, followed by the device-type suffix emitted by
// printSingleDeviceType. Groups are separated by commas.
//
// NOTE(review): the first two signature lines (listing lines 2085-2086), which
// carry the function name and the `p` and `operands` parameters, are missing
// from this listing.
// NOTE(review): `deviceTypes` and `segments` are dereferenced without checking
// that they hold a value; callers presumably guarantee both are set — confirm.
2087 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2088 std::optional<mlir::DenseI32ArrayAttr> segments) {
// Running index into `operands`, shared across all groups.
2089 unsigned opIdx = 0;
2090 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2091 p << "{";
2092 llvm::interleaveComma(
2093 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2094 p << operands[opIdx] << " : " << operands[opIdx].getType();
2095 ++opIdx;
2096 });
2097 p << "}";
2098 printSingleDeviceType(p, it.value());
2099 });
2100}
2101
/// Custom parser for the wait clause.
///
/// Accepted forms, as implemented below:
///  - nothing (no `(`): keyword-only form; `keywordOnly` is set to a single
///    `DeviceType::None` attribute and the other outputs are left untouched.
///  - `( [attr, ...] , group, ... )` or `( group, ... )` where each group is
///    `{ [devnum:] operand : type, ... }` optionally followed by `[ attr ]`.
///
/// On the parenthesized form, `deviceTypes` receives one attribute per group
/// (`DeviceType::None` when the suffix is absent), `segments` the operand
/// count per group, `hasDevNum` one bool per group telling whether `devnum`
/// was present, and `keywordOnly` the attributes parsed from the leading
/// bracket list.
///
/// NOTE(review): this listing has gaps — the `operands` parameter line
/// (listing line 2104), the declaration of the local `seg` (2109) and the head
/// of the inner parseCommaSeparatedList call (2152) are missing.
/// NOTE(review): once the leading `[...]` list is parsed a comma and at least
/// one `{...}` group are required; confirm this matches what the printer can
/// emit.
2102static ParseResult parseWaitClause(
2103 mlir::OpAsmParser &parser,
2105 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2106 mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
2107 mlir::ArrayAttr &keywordOnly) {
2108 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
2110
2111 bool needCommaBeforeOperands = false;
2112
2113 // Keyword only
2114 if (failed(parser.parseOptionalLParen())) {
2115 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2116 parser.getContext(), mlir::acc::DeviceType::None));
2117 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2118 return success();
2119 }
2120
2121 // Parse keyword only attributes
2122 if (succeeded(parser.parseOptionalLSquare())) {
2123 if (failed(parser.parseCommaSeparatedList([&]() {
2124 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2125 return failure();
2126 return success();
2127 })))
2128 return failure();
2129 if (parser.parseRSquare())
2130 return failure();
2131 needCommaBeforeOperands = true;
2132 }
2133
2134 if (needCommaBeforeOperands && failed(parser.parseComma()))
2135 return failure();
2136
2137 do {
2138 if (failed(parser.parseLBrace()))
2139 return failure();
2140
// Remember how many operands were parsed before this group so the group size
// can be computed afterwards.
2141 int32_t crtOperandsSize = operands.size();
2142
// Record, per group, whether the optional `devnum:` prefix was present.
2143 if (succeeded(parser.parseOptionalKeyword("devnum"))) {
2144 if (failed(parser.parseColon()))
2145 return failure();
2146 devnum.push_back(BoolAttr::get(parser.getContext(), true));
2147 } else {
2148 devnum.push_back(BoolAttr::get(parser.getContext(), false));
2149 }
2150
2151 if (failed(parser.parseCommaSeparatedList(
2153 if (parser.parseOperand(operands.emplace_back()) ||
2154 parser.parseColonType(types.emplace_back()))
2155 return failure();
2156 return success();
2157 })))
2158 return failure();
2159
2160 seg.push_back(operands.size() - crtOperandsSize);
2161
2162 if (failed(parser.parseRBrace()))
2163 return failure();
2164
// Optional `[attr]` suffix; default to DeviceType::None when absent.
2165 if (succeeded(parser.parseOptionalLSquare())) {
2166 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2167 parser.parseRSquare())
2168 return failure();
2169 } else {
2170 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2171 parser.getContext(), mlir::acc::DeviceType::None));
2172 }
2173 } while (succeeded(parser.parseOptionalComma()));
2174
2175 if (failed(parser.parseRParen()))
2176 return failure();
2177
2178 deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2179 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2180 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2181 hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
2182
2183 return success();
2184}
2185
2186static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
2187 if (!hasDeviceTypeValues(attrs))
2188 return false;
2189 if (attrs->size() != 1)
2190 return false;
2191 if (auto deviceTypeAttr =
2192 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2193 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2194 return false;
2195}
2196
// Printer for the wait clause.
//
// Prints nothing when there are no operands and `keywordOnly` consists solely
// of `DeviceType::None` (see hasOnlyDeviceTypeNone). Otherwise prints, inside
// parentheses: the keyword-only device types (via printDeviceTypes), a ", "
// separator when both lists have values, and then one
// `{[devnum: ]operand : type, ...}` group per entry of `deviceTypes`, each
// followed by the suffix emitted by printSingleDeviceType.
//
// NOTE(review): the first signature line (listing line 2197), which carries
// the function name and the `p` parameter, is missing from this listing.
// NOTE(review): `hasDevNum` and `segments` are dereferenced without checking
// that they hold a value whenever `deviceTypes` has values; presumably these
// are always set together — confirm.
2198 mlir::OperandRange operands, mlir::TypeRange types,
2199 std::optional<mlir::ArrayAttr> deviceTypes,
2200 std::optional<mlir::DenseI32ArrayAttr> segments,
2201 std::optional<mlir::ArrayAttr> hasDevNum,
2202 std::optional<mlir::ArrayAttr> keywordOnly) {
2203
2204 if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
2205 return;
2206
2207 p << "(";
2208
2209 printDeviceTypes(p, keywordOnly);
2210 if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
2211 p << ", ";
2212
2213 if (hasDeviceTypeValues(deviceTypes)) {
// Running index into `operands`, shared across all groups.
2214 unsigned opIdx = 0;
2215 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2216 p << "{";
2217 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2218 if (boolAttr && boolAttr.getValue())
2219 p << "devnum: ";
2220 llvm::interleaveComma(
2221 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2222 p << operands[opIdx] << " : " << operands[opIdx].getType();
2223 ++opIdx;
2224 });
2225 p << "}";
2226 printSingleDeviceType(p, it.value());
2227 });
2228 }
2229
2230 p << ")";
2231}
2232
/// Custom parser for a comma-separated list of `operand : type` entries, each
/// optionally followed by `[ attribute ]`.
///
/// When the bracket is absent a `DeviceType::None` attribute is recorded for
/// that operand. On success `deviceTypes` holds one attribute per operand.
///
/// NOTE(review): this listing has gaps — the `operands` parameter line
/// (listing line 2235) and the declaration of the local `attributes` (2237)
/// are missing.
2233static ParseResult parseDeviceTypeOperands(
2234 mlir::OpAsmParser &parser,
2236 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
2238 if (failed(parser.parseCommaSeparatedList([&]() {
2239 if (parser.parseOperand(operands.emplace_back()) ||
2240 parser.parseColonType(types.emplace_back()))
2241 return failure();
2242 if (succeeded(parser.parseOptionalLSquare())) {
2243 if (parser.parseAttribute(attributes.emplace_back()) ||
2244 parser.parseRSquare())
2245 return failure();
2246 } else {
2247 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2248 parser.getContext(), mlir::acc::DeviceType::None));
2249 }
2250 return success();
2251 })))
2252 return failure();
2253 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2254 attributes.end());
2255 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2256 return success();
2257}
2258
/// Printer for a comma-separated list of `operand : type` entries, each
/// followed by the device-type suffix emitted by printSingleDeviceType.
///
/// Prints nothing when `deviceTypes` has no values. Operands and device types
/// are paired positionally with `llvm::zip`.
///
/// NOTE(review): the line carrying the function name and the `p` parameter
/// (listing line 2260) is missing from this listing.
2259static void
2261 mlir::OperandRange operands, mlir::TypeRange types,
2262 std::optional<mlir::ArrayAttr> deviceTypes) {
2263 if (!hasDeviceTypeValues(deviceTypes))
2264 return;
2265 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
2266 p << std::get<1>(it) << " : " << std::get<1>(it).getType();
2267 printSingleDeviceType(p, std::get<0>(it));
2268 });
2269}
2270
// Custom parser for a clause that may appear as a bare keyword or with a
// parenthesized operand list.
//
// Accepted forms, as implemented below:
//  - nothing (no `(`): keyword-only form; `keywordOnlyDeviceType` is set to a
//    single `DeviceType::None` attribute and the function returns.
//  - `( [attr, ...] , operand : type [attr], ... )` or
//    `( operand : type [attr], ... )`; an operand without a `[attr]` suffix is
//    recorded with `DeviceType::None`. `deviceTypes` receives one attribute
//    per operand.
//
// NOTE(review): this listing has gaps — the line with the function name
// (listing line 2271), the `operands` parameter line (2273) and the
// declaration of the local `attributes` (2307) are missing.
// NOTE(review): on the parenthesized path the attributes collected from the
// leading `[...]` list are not visibly stored into `keywordOnlyDeviceType`;
// confirm whether that is intended.
2272 mlir::OpAsmParser &parser,
2274 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2275 mlir::ArrayAttr &keywordOnlyDeviceType) {
2276
2277 llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
2278 bool needCommaBeforeOperands = false;
2279
2280 if (failed(parser.parseOptionalLParen())) {
2281 // Keyword only
2282 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2283 parser.getContext(), mlir::acc::DeviceType::None));
2284 keywordOnlyDeviceType =
2285 ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
2286 return success();
2287 }
2288
2289 // Parse keyword only attributes
2290 if (succeeded(parser.parseOptionalLSquare())) {
2291 // Parse keyword only attributes
2292 if (failed(parser.parseCommaSeparatedList([&]() {
2293 if (parser.parseAttribute(
2294 keywordOnlyDeviceTypeAttributes.emplace_back()))
2295 return failure();
2296 return success();
2297 })))
2298 return failure();
2299 if (parser.parseRSquare())
2300 return failure();
2301 needCommaBeforeOperands = true;
2302 }
2303
2304 if (needCommaBeforeOperands && failed(parser.parseComma()))
2305 return failure();
2306
2308 if (failed(parser.parseCommaSeparatedList([&]() {
2309 if (parser.parseOperand(operands.emplace_back()) ||
2310 parser.parseColonType(types.emplace_back()))
2311 return failure();
2312 if (succeeded(parser.parseOptionalLSquare())) {
2313 if (parser.parseAttribute(attributes.emplace_back()) ||
2314 parser.parseRSquare())
2315 return failure();
2316 } else {
2317 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2318 parser.getContext(), mlir::acc::DeviceType::None));
2319 }
2320 return success();
2321 })))
2322 return failure();
2323
2324 if (failed(parser.parseRParen()))
2325 return failure();
2326
2327 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2328 attributes.end());
2329 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2330 return success();
2331}
2332
// Printer for a clause that may appear as a bare keyword or with a
// parenthesized operand list.
//
// Prints nothing when there are no operands and `keywordOnlyDeviceTypes`
// consists solely of `DeviceType::None`. Otherwise prints, inside parentheses,
// the keyword-only device types (via printDeviceTypes), a ", " separator when
// both lists have values, and the operands (via printDeviceTypeOperands).
//
// NOTE(review): the first two signature lines (listing lines 2333-2334), which
// carry the function name and the `p`, `op` and `operands` parameters, are
// missing from this listing.
2335 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2336 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2337
2338 if (operands.begin() == operands.end() &&
2339 hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
2340 return;
2341 }
2342
2343 p << "(";
2344 printDeviceTypes(p, keywordOnlyDeviceTypes);
2345 if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
2346 hasDeviceTypeValues(deviceTypes))
2347 p << ", ";
2348 printDeviceTypeOperands(p, op, operands, types, deviceTypes);
2349 p << ")";
2350}
2351
// Custom parser for a clause that is either a bare keyword or
// `( operand : type )`.
//
// Without an opening parenthesis, `attr` is set to a `UnitAttr` and `operand`
// is left untouched. Otherwise a single operand and its type are parsed into
// `operand` and `operandType`, and `attr` is left untouched.
//
// NOTE(review): this listing has gaps — the line with the function name
// (listing line 2352) and the declaration of the local `op` (2362) are
// missing.
2353 mlir::OpAsmParser &parser,
2354 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2355 mlir::Type &operandType, mlir::UnitAttr &attr) {
2356 // Keyword only
2357 if (failed(parser.parseOptionalLParen())) {
2358 attr = mlir::UnitAttr::get(parser.getContext());
2359 return success();
2360 }
2361
2363 if (failed(parser.parseOperand(op)))
2364 return failure();
2365 operand = op;
2366 if (failed(parser.parseColon()))
2367 return failure();
2368 if (failed(parser.parseType(operandType)))
2369 return failure();
2370 if (failed(parser.parseRParen()))
2371 return failure();
2372
2373 return success();
2374}
2375
// Printer for a clause that is either a bare keyword or `(operand : type)`.
//
// Prints nothing when `attr` is set (keyword-only form); otherwise prints the
// operand and its type inside parentheses.
//
// NOTE(review): the first signature line (listing line 2376), which carries
// the function name and the `p` parameter, is missing from this listing.
// NOTE(review): `operand` is dereferenced without checking that it holds a
// value whenever `attr` is null — presumably exactly one of the two is set;
// confirm.
2377 mlir::Operation *op,
2378 std::optional<mlir::Value> operand,
2379 mlir::Type operandType,
2380 mlir::UnitAttr attr) {
2381 if (attr)
2382 return;
2383
2384 p << "(";
2385 p.printOperand(*operand);
2386 p << " : ";
2387 p.printType(operandType);
2388 p << ")";
2389}
2390
// Custom parser for a clause that is either a bare keyword or
// `( operand, ... : type, ... )`.
//
// Without an opening parenthesis, `attr` is set to a `UnitAttr` and nothing
// else is parsed. Otherwise a comma-separated operand list, a colon and a
// comma-separated type list are parsed into `operands` and `types`.
//
// NOTE(review): this listing has gaps — the line with the function name
// (listing line 2391) and the `operands` parameter line (2393) are missing.
// NOTE(review): nothing here checks that the number of types matches the
// number of operands; presumably that is validated elsewhere — confirm.
2392 mlir::OpAsmParser &parser,
2394 llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
2395 // Keyword only
2396 if (failed(parser.parseOptionalLParen())) {
2397 attr = mlir::UnitAttr::get(parser.getContext());
2398 return success();
2399 }
2400
2401 if (failed(parser.parseCommaSeparatedList([&]() {
2402 if (parser.parseOperand(operands.emplace_back()))
2403 return failure();
2404 return success();
2405 })))
2406 return failure();
2407 if (failed(parser.parseColon()))
2408 return failure();
2409 if (failed(parser.parseCommaSeparatedList([&]() {
2410 if (parser.parseType(types.emplace_back()))
2411 return failure();
2412 return success();
2413 })))
2414 return failure();
2415 if (failed(parser.parseRParen()))
2416 return failure();
2417
2418 return success();
2419}
2420
// Printer for a clause that is either a bare keyword or
// `(operand, ... : type, ...)`.
//
// Prints nothing when `attr` is set (keyword-only form); otherwise prints the
// comma-separated operands, " : ", and the comma-separated types inside
// parentheses.
//
// NOTE(review): the first signature line (listing line 2421), which carries
// the function name and the `p` parameter, is missing from this listing.
2422 mlir::Operation *op,
2423 mlir::OperandRange operands,
2424 mlir::TypeRange types,
2425 mlir::UnitAttr attr) {
2426 if (attr)
2427 return;
2428
2429 p << "(";
2430 llvm::interleaveComma(operands, p, [&](auto it) { p << it; });
2431 p << " : ";
2432 llvm::interleaveComma(types, p, [&](auto it) { p << it; });
2433 p << ")";
2434}
2435
/// Custom parser mapping one of the keywords `kernels`, `parallel` or `serial`
/// to the corresponding `CombinedConstructsTypeAttr` (KernelsLoop,
/// ParallelLoop, SerialLoop). Any other input emits
/// "expected compute construct name" and fails.
///
/// NOTE(review): the line carrying the function name and the `parser`
/// parameter (listing line 2437) is missing from this listing.
2436static ParseResult
2438 mlir::acc::CombinedConstructsTypeAttr &attr) {
2439 if (succeeded(parser.parseOptionalKeyword("kernels"))) {
2440 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2441 parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2442 } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
2443 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2444 parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2445 } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
2446 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2447 parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2448 } else {
2449 parser.emitError(parser.getCurrentLocation(),
2450 "expected compute construct name");
2451 return failure();
2452 }
2453 return success();
2454}
2455
/// Printer for the combined-constructs attribute: prints `kernels`,
/// `parallel` or `serial` for KernelsLoop, ParallelLoop and SerialLoop
/// respectively, and nothing when `attr` is null.
///
/// NOTE(review): the line carrying the function name and the `p` parameter
/// (listing line 2457) is missing from this listing.
2456static void
2458 mlir::acc::CombinedConstructsTypeAttr attr) {
2459 if (attr) {
2460 switch (attr.getValue()) {
2461 case mlir::acc::CombinedConstructsType::KernelsLoop:
2462 p << "kernels";
2463 break;
2464 case mlir::acc::CombinedConstructsType::ParallelLoop:
2465 p << "parallel";
2466 break;
2467 case mlir::acc::CombinedConstructsType::SerialLoop:
2468 p << "serial";
2469 break;
2470 };
2471 }
2472}
2473
2474//===----------------------------------------------------------------------===//
2475// SerialOp
2476//===----------------------------------------------------------------------===//
2477
2478unsigned SerialOp::getNumDataOperands() {
2479 return getReductionOperands().size() + getPrivateOperands().size() +
2480 getFirstprivateOperands().size() + getDataClauseOperands().size();
2481}
2482
2483Value SerialOp::getDataOperand(unsigned i) {
2484 unsigned numOptional = getAsyncOperands().size();
2485 numOptional += getIfCond() ? 1 : 0;
2486 numOptional += getSelfCond() ? 1 : 0;
2487 return getOperand(getWaitOperands().size() + numOptional + i);
2488}
2489
2490bool acc::SerialOp::hasAsyncOnly() {
2491 return hasAsyncOnly(mlir::acc::DeviceType::None);
2492}
2493
2494bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2495 return hasDeviceType(getAsyncOnly(), deviceType);
2496}
2497
/// Convenience overload: forwards to the device-type overload with
/// `DeviceType::None`.
2498mlir::Value acc::SerialOp::getAsyncValue() {
2499 return getAsyncValue(mlir::acc::DeviceType::None);
2500}
2501
/// Returns the async value for `deviceType`.
///
/// NOTE(review): the start of the return statement (listing line 2503) is
/// missing; only its trailing arguments, the async operands and `deviceType`,
/// are visible. Presumably it calls getValueInDeviceTypeSegment — confirm.
2502mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2504 getAsyncOperands(), deviceType);
2505}
2506
2507bool acc::SerialOp::hasWaitOnly() {
2508 return hasWaitOnly(mlir::acc::DeviceType::None);
2509}
2510
2511bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2512 return hasDeviceType(getWaitOnly(), deviceType);
2513}
2514
/// Convenience overload: forwards to the device-type overload with
/// `DeviceType::None`.
2515mlir::Operation::operand_range SerialOp::getWaitValues() {
2516 return getWaitValues(mlir::acc::DeviceType::None);
2517}
2518
/// Returns the wait values for `deviceType`.
///
/// NOTE(review): both the return-type line (listing line 2519) and the start
/// of the return statement (listing line 2521) are missing; only the forwarded
/// arguments (wait device types, operands, segments, has-devnum flags and
/// `deviceType`) are visible. The helper being called cannot be determined
/// from here.
2520SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2522 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2523 getHasWaitDevnum(), deviceType);
2524}
2525
2526mlir::Value SerialOp::getWaitDevnum() {
2527 return getWaitDevnum(mlir::acc::DeviceType::None);
2528}
2529
2530mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2531 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2532 getWaitOperandsSegments(), getHasWaitDevnum(),
2533 deviceType);
2534}
2535
/// Verifier for `acc::SerialOp`.
///
/// Runs a sequence of checks and returns failure as soon as one of them fails:
/// the private, firstprivate and reduction operand lists (through
/// `checkPrivateOperands`), the "wait" list, an "async" check, and finally the
/// data clause operands (through `checkDataOperands`).
///
/// NOTE(review): this listing has gaps (listing lines 2550, 2555 and 2560 are
/// absent), so the opening of three `if (failed(...))` checks is not visible.
/// Each gap is marked below.
2536LogicalResult acc::SerialOp::verify() {
2537 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
2538 mlir::acc::PrivateRecipeOp>(
2539 *this, getPrivateOperands(), "private")))
2540 return failure();
2541 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
2542 mlir::acc::FirstprivateRecipeOp>(
2543 *this, getFirstprivateOperands(), "firstprivate")))
2544 return failure();
2545 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
2546 mlir::acc::ReductionRecipeOp>(
2547 *this, getReductionOperands(), "reduction")))
2548 return failure();
2549
// NOTE(review): call head missing here; the visible arguments are the wait
// operands, their segments and their device types.
2551 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2552 getWaitOperandsDeviceTypeAttr(), "wait")))
2553 return failure();
2554
// NOTE(review): call head missing here; presumably a
// verifyDeviceTypeCountMatch call on the async operands — confirm.
2556 getAsyncOperandsDeviceTypeAttr(),
2557 "async")))
2558 return failure();
2559
// NOTE(review): the condition guarding this early return is missing from the
// listing.
2561 return failure();
2562
2563 return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
2564}
2565
2566void acc::SerialOp::addAsyncOnly(
2567 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2568 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2569 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2570}
2571
2572void acc::SerialOp::addAsyncOperand(
2573 MLIRContext *context, mlir::Value newValue,
2574 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2575 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2576 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2577 getAsyncOperandsMutable()));
2578}
2579
2580void acc::SerialOp::addWaitOnly(
2581 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2582 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2583 effectiveDeviceTypes));
2584}
/// Adds `newValues` as wait operands for `effectiveDeviceTypes`.
///
/// Updates three pieces of state: the wait operands together with their
/// device-type attribute and segments (through
/// `addDeviceTypeAffectedOperandHelper`), and the has-devnum flag list, which
/// is extended with `hasDevnum`.
2585void acc::SerialOp::addWaitOperands(
2586 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2587 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2588
// NOTE(review): the declaration of the local `segments` (listing line 2589) is
// missing from this listing.
2590 if (getWaitOperandsSegments())
2591 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2592
2593 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2594 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2595 getWaitOperandsMutable(), segments));
2596 setWaitOperandsSegments(segments);
2597
// NOTE(review): the declaration of the local `hasDevnums` (listing line 2598)
// is missing from this listing.
2599 if (getHasWaitDevnumAttr())
2600 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
// Append one flag per effective device type, or a single flag when the list
// of device types is empty.
2601 hasDevnums.insert(
2602 hasDevnums.end(),
2603 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2604 mlir::BoolAttr::get(context, hasDevnum));
2605 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2606}
2607
2608void acc::SerialOp::addPrivatization(MLIRContext *context,
2609 mlir::acc::PrivateOp op,
2610 mlir::acc::PrivateRecipeOp recipe) {
2611 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2612 getPrivateOperandsMutable().append(op.getResult());
2613}
2614
2615void acc::SerialOp::addFirstPrivatization(
2616 MLIRContext *context, mlir::acc::FirstprivateOp op,
2617 mlir::acc::FirstprivateRecipeOp recipe) {
2618 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2619 getFirstprivateOperandsMutable().append(op.getResult());
2620}
2621
2622void acc::SerialOp::addReduction(MLIRContext *context,
2623 mlir::acc::ReductionOp op,
2624 mlir::acc::ReductionRecipeOp recipe) {
2625 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2626 getReductionOperandsMutable().append(op.getResult());
2627}
2628
2629//===----------------------------------------------------------------------===//
2630// KernelsOp
2631//===----------------------------------------------------------------------===//
2632
2633unsigned KernelsOp::getNumDataOperands() {
2634 return getDataClauseOperands().size();
2635}
2636
2637Value KernelsOp::getDataOperand(unsigned i) {
2638 unsigned numOptional = getAsyncOperands().size();
2639 numOptional += getWaitOperands().size();
2640 numOptional += getNumGangs().size();
2641 numOptional += getNumWorkers().size();
2642 numOptional += getVectorLength().size();
2643 numOptional += getIfCond() ? 1 : 0;
2644 numOptional += getSelfCond() ? 1 : 0;
2645 return getOperand(numOptional + i);
2646}
2647
2648bool acc::KernelsOp::hasAsyncOnly() {
2649 return hasAsyncOnly(mlir::acc::DeviceType::None);
2650}
2651
2652bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2653 return hasDeviceType(getAsyncOnly(), deviceType);
2654}
2655
/// Convenience overload: forwards to the device-type overload with
/// `DeviceType::None`.
2656mlir::Value acc::KernelsOp::getAsyncValue() {
2657 return getAsyncValue(mlir::acc::DeviceType::None);
2658}
2659
/// Returns the async value for `deviceType`.
///
/// NOTE(review): the start of the return statement (listing line 2661) is
/// missing; only its trailing arguments, the async operands and `deviceType`,
/// are visible. Presumably it calls getValueInDeviceTypeSegment — confirm.
2660mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2662 getAsyncOperands(), deviceType);
2663}
2664
/// Convenience overload: forwards to the device-type overload with
/// `DeviceType::None`.
2665mlir::Value acc::KernelsOp::getNumWorkersValue() {
2666 return getNumWorkersValue(mlir::acc::DeviceType::None);
2667}
2668
/// Returns the num_workers operand selected for `deviceType` by
/// getValueInDeviceTypeSegment.
///
/// NOTE(review): the return-type line (listing line 2669) is missing from this
/// listing; presumably `mlir::Value`, matching the overload above.
2670acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2671 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
2672 deviceType);
2673}
2674
/// Convenience overload: forwards to the device-type overload with
/// `DeviceType::None`.
2675mlir::Value acc::KernelsOp::getVectorLengthValue() {
2676 return getVectorLengthValue(mlir::acc::DeviceType::None);
2677}
2678
/// Returns the vector_length operand selected for `deviceType` by
/// getValueInDeviceTypeSegment.
///
/// NOTE(review): the return-type line (listing line 2679) is missing from this
/// listing; presumably `mlir::Value`, matching the overload above.
2680acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2681 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
2682 getVectorLength(), deviceType);
2683}
2684
/// Convenience overload: forwards to the device-type overload with
/// `DeviceType::None`.
2685mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
2686 return getNumGangsValues(mlir::acc::DeviceType::None);
2687}
2688
/// Returns what getValuesFromSegments yields for the num_gangs device types,
/// operands and segments, given `deviceType`.
///
/// NOTE(review): the return-type line (listing line 2689) is missing from this
/// listing; presumably `mlir::Operation::operand_range`, matching the overload
/// above.
2690KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2691 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2692 getNumGangsSegments(), deviceType);
2693}
2694
2695bool acc::KernelsOp::hasWaitOnly() {
2696 return hasWaitOnly(mlir::acc::DeviceType::None);
2697}
2698
2699bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2700 return hasDeviceType(getWaitOnly(), deviceType);
2701}
2702
/// Convenience overload: forwards to the device-type overload with
/// `DeviceType::None`.
2703mlir::Operation::operand_range KernelsOp::getWaitValues() {
2704 return getWaitValues(mlir::acc::DeviceType::None);
2705}
2706
/// Returns the wait values for `deviceType`.
///
/// NOTE(review): both the return-type line (listing line 2707) and the start
/// of the return statement (listing line 2709) are missing; only the forwarded
/// arguments (wait device types, operands, segments, has-devnum flags and
/// `deviceType`) are visible. The helper being called cannot be determined
/// from here.
2708KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2710 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2711 getHasWaitDevnum(), deviceType);
2712}
2713
2714mlir::Value KernelsOp::getWaitDevnum() {
2715 return getWaitDevnum(mlir::acc::DeviceType::None);
2716}
2717
2718mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2719 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2720 getWaitOperandsSegments(), getHasWaitDevnum(),
2721 deviceType);
2722}
2723
/// Verifier for `acc::KernelsOp`.
///
/// Runs a sequence of checks and returns failure as soon as one of them fails:
/// the "num_gangs" and "wait" lists, the device-type counts of "num_workers"
/// and "vector_length", an "async" check, and finally the data clause operands
/// (through `checkDataOperands`).
///
/// NOTE(review): this listing has gaps (listing lines 2725, 2730, 2745 and
/// 2750 are absent), so the opening of four `if (failed(...))` checks is not
/// visible. Each gap is marked below.
/// NOTE(review): within the visible lines there are no private, firstprivate
/// or reduction operand checks here, unlike the other two verifiers in this
/// file — confirm whether that is intended.
2724LogicalResult acc::KernelsOp::verify() {
// NOTE(review): call head missing here. The trailing 3 is presumably the
// maximum number of values allowed per num_gangs segment — confirm.
2726 *this, getNumGangs(), getNumGangsSegmentsAttr(),
2727 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
2728 return failure();
2729
// NOTE(review): call head missing here; same argument shape as the num_gangs
// check above but without a per-segment maximum.
2731 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2732 getWaitOperandsDeviceTypeAttr(), "wait")))
2733 return failure();
2734
2735 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
2736 getNumWorkersDeviceTypeAttr(),
2737 "num_workers")))
2738 return failure();
2739
2740 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
2741 getVectorLengthDeviceTypeAttr(),
2742 "vector_length")))
2743 return failure();
2744
// NOTE(review): call head missing here; presumably another
// verifyDeviceTypeCountMatch call on the async operands — confirm.
2746 getAsyncOperandsDeviceTypeAttr(),
2747 "async")))
2748 return failure();
2749
// NOTE(review): the condition guarding this early return is missing from the
// listing.
2751 return failure();
2752
2753 return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
2754}
2755
2756void acc::KernelsOp::addPrivatization(MLIRContext *context,
2757 mlir::acc::PrivateOp op,
2758 mlir::acc::PrivateRecipeOp recipe) {
2759 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2760 getPrivateOperandsMutable().append(op.getResult());
2761}
2762
2763void acc::KernelsOp::addFirstPrivatization(
2764 MLIRContext *context, mlir::acc::FirstprivateOp op,
2765 mlir::acc::FirstprivateRecipeOp recipe) {
2766 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2767 getFirstprivateOperandsMutable().append(op.getResult());
2768}
2769
2770void acc::KernelsOp::addReduction(MLIRContext *context,
2771 mlir::acc::ReductionOp op,
2772 mlir::acc::ReductionRecipeOp recipe) {
2773 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2774 getReductionOperandsMutable().append(op.getResult());
2775}
2776
2777void acc::KernelsOp::addNumWorkersOperand(
2778 MLIRContext *context, mlir::Value newValue,
2779 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2780 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2781 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2782 getNumWorkersMutable()));
2783}
2784
2785void acc::KernelsOp::addVectorLengthOperand(
2786 MLIRContext *context, mlir::Value newValue,
2787 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2788 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2789 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2790 getVectorLengthMutable()));
2791}
2792void acc::KernelsOp::addAsyncOnly(
2793 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2794 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2795 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2796}
2797
2798void acc::KernelsOp::addAsyncOperand(
2799 MLIRContext *context, mlir::Value newValue,
2800 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2801 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2802 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2803 getAsyncOperandsMutable()));
2804}
2805
2806void acc::KernelsOp::addNumGangsOperands(
2807 MLIRContext *context, mlir::ValueRange newValues,
2808 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2810 if (getNumGangsSegmentsAttr())
2811 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2812
2813 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2814 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2815 getNumGangsMutable(), segments));
2816
2817 setNumGangsSegments(segments);
2818}
2819
2820void acc::KernelsOp::addWaitOnly(
2821 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2822 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2823 effectiveDeviceTypes));
2824}
2825void acc::KernelsOp::addWaitOperands(
2826 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2827 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2828
2830 if (getWaitOperandsSegments())
2831 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2832
2833 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2834 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2835 getWaitOperandsMutable(), segments));
2836 setWaitOperandsSegments(segments);
2837
2839 if (getHasWaitDevnumAttr())
2840 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2841 hasDevnums.insert(
2842 hasDevnums.end(),
2843 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2844 mlir::BoolAttr::get(context, hasDevnum));
2845 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2846}
2847
2848//===----------------------------------------------------------------------===//
2849// HostDataOp
2850//===----------------------------------------------------------------------===//
2851
2852LogicalResult acc::HostDataOp::verify() {
2853 if (getDataClauseOperands().empty())
2854 return emitError("at least one operand must appear on the host_data "
2855 "operation");
2856
2857 for (mlir::Value operand : getDataClauseOperands())
2858 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2859 return emitError("expect data entry operation as defining op");
2860 return success();
2861}
2862
2863void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2864 MLIRContext *context) {
2865 results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2866}
2867
2868//===----------------------------------------------------------------------===//
2869// KernelEnvironmentOp
2870//===----------------------------------------------------------------------===//
2871
2872void acc::KernelEnvironmentOp::getCanonicalizationPatterns(
2873 RewritePatternSet &results, MLIRContext *context) {
2874 results.add<RemoveEmptyKernelEnvironment>(context);
2875}
2876
2877//===----------------------------------------------------------------------===//
2878// LoopOp
2879//===----------------------------------------------------------------------===//
2880
2881static ParseResult parseGangValue(
2882 OpAsmParser &parser, llvm::StringRef keyword,
2885 llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
2886 bool &needCommaBetweenValues, bool &newValue) {
2887 if (succeeded(parser.parseOptionalKeyword(keyword))) {
2888 if (parser.parseEqual())
2889 return failure();
2890 if (parser.parseOperand(operands.emplace_back()) ||
2891 parser.parseColonType(types.emplace_back()))
2892 return failure();
2893 attributes.push_back(gangArgType);
2894 needCommaBetweenValues = true;
2895 newValue = true;
2896 }
2897 return success();
2898}
2899
2900static ParseResult parseGangClause(
2901 OpAsmParser &parser,
2903 llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
2904 mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
2905 mlir::ArrayAttr &gangOnlyDeviceType) {
2906 llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
2907 llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
2908 llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
2910 bool needCommaBetweenValues = false;
2911 bool needCommaBeforeOperands = false;
2912
2913 if (failed(parser.parseOptionalLParen())) {
2914 // Gang only keyword
2915 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2916 parser.getContext(), mlir::acc::DeviceType::None));
2917 gangOnlyDeviceType =
2918 ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
2919 return success();
2920 }
2921
2922 // Parse gang only attributes
2923 if (succeeded(parser.parseOptionalLSquare())) {
2924 // Parse gang only attributes
2925 if (failed(parser.parseCommaSeparatedList([&]() {
2926 if (parser.parseAttribute(
2927 gangOnlyDeviceTypeAttributes.emplace_back()))
2928 return failure();
2929 return success();
2930 })))
2931 return failure();
2932 if (parser.parseRSquare())
2933 return failure();
2934 needCommaBeforeOperands = true;
2935 }
2936
2937 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2938 mlir::acc::GangArgType::Num);
2939 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2940 mlir::acc::GangArgType::Dim);
2941 auto argStatic = mlir::acc::GangArgTypeAttr::get(
2942 parser.getContext(), mlir::acc::GangArgType::Static);
2943
2944 do {
2945 if (needCommaBeforeOperands) {
2946 needCommaBeforeOperands = false;
2947 continue;
2948 }
2949
2950 if (failed(parser.parseLBrace()))
2951 return failure();
2952
2953 int32_t crtOperandsSize = gangOperands.size();
2954 while (true) {
2955 bool newValue = false;
2956 bool needValue = false;
2957 if (needCommaBetweenValues) {
2958 if (succeeded(parser.parseOptionalComma()))
2959 needValue = true; // expect a new value after comma.
2960 else
2961 break;
2962 }
2963
2964 if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
2965 gangOperands, gangOperandsType,
2966 gangArgTypeAttributes, argNum,
2967 needCommaBetweenValues, newValue)))
2968 return failure();
2969 if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
2970 gangOperands, gangOperandsType,
2971 gangArgTypeAttributes, argDim,
2972 needCommaBetweenValues, newValue)))
2973 return failure();
2974 if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
2975 gangOperands, gangOperandsType,
2976 gangArgTypeAttributes, argStatic,
2977 needCommaBetweenValues, newValue)))
2978 return failure();
2979
2980 if (!newValue && needValue) {
2981 parser.emitError(parser.getCurrentLocation(),
2982 "new value expected after comma");
2983 return failure();
2984 }
2985
2986 if (!newValue)
2987 break;
2988 }
2989
2990 if (gangOperands.empty())
2991 return parser.emitError(
2992 parser.getCurrentLocation(),
2993 "expect at least one of num, dim or static values");
2994
2995 if (failed(parser.parseRBrace()))
2996 return failure();
2997
2998 if (succeeded(parser.parseOptionalLSquare())) {
2999 if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
3000 parser.parseRSquare())
3001 return failure();
3002 } else {
3003 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3004 parser.getContext(), mlir::acc::DeviceType::None));
3005 }
3006
3007 seg.push_back(gangOperands.size() - crtOperandsSize);
3008
3009 } while (succeeded(parser.parseOptionalComma()));
3010
3011 if (failed(parser.parseRParen()))
3012 return failure();
3013
3014 llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
3015 gangArgTypeAttributes.end());
3016 gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
3017 deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
3018
3020 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3021 gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
3022
3023 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
3024 return success();
3025}
3026
// Printer for the gang clause; presumably the counterpart of parseGangClause.
// NOTE(review): the listing jumps from 3026 to 3028, so the first line of the
// signature (return type, function name and the parameters that introduce `p`)
// is missing from this copy. Confirm against the original file.
//
// Output shape, as written by the code below:
//   * nothing at all when there are no operands and
//     hasOnlyDeviceTypeNone(gangOnlyDeviceTypes) holds;
//   * otherwise `(` gang-only device types, then `, ` if both the gang-only
//     and the operand device types have values, then one `{...}` group per
//     entry of `deviceTypes` separated by commas, then `)`.
3028 mlir::OperandRange operands, mlir::TypeRange types,
3029 std::optional<mlir::ArrayAttr> gangArgTypes,
3030 std::optional<mlir::ArrayAttr> deviceTypes,
3031 std::optional<mlir::DenseI32ArrayAttr> segments,
3032 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3033
3034 if (operands.begin() == operands.end() &&
3035 hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
3036 return;
3037 }
3038
3039 p << "(";
3040
3041 printDeviceTypes(p, gangOnlyDeviceTypes);
3042
3043 if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
3044 hasDeviceTypeValues(deviceTypes))
3045 p << ", ";
3046
3047 if (hasDeviceTypeValues(deviceTypes)) {
// opIdx is a running index into `operands` / `gangArgTypes`; it is shared by
// all groups and advanced once per printed value.
3048 unsigned opIdx = 0;
3049 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
3050 p << "{";
// The segment size at this group's index gives how many values it prints.
3051 llvm::interleaveComma(
3052 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
// NOTE(review): the dyn_cast result is used without a null check.
3053 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3054 (*gangArgTypes)[opIdx]);
3055 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3056 p << LoopOp::getGangNumKeyword();
3057 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3058 p << LoopOp::getGangDimKeyword();
3059 else if (gangArgTypeAttr.getValue() ==
3060 mlir::acc::GangArgType::Static)
3061 p << LoopOp::getGangStaticKeyword();
3062 p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
3063 ++opIdx;
3064 });
3065 p << "}";
3066 printSingleDeviceType(p, it.value());
3067 });
3068 }
3069 p << ")";
3070}
3071
3073 std::optional<mlir::ArrayAttr> segments,
3074 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3075 if (!segments)
3076 return false;
3077 for (auto attr : *segments) {
3078 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3079 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3080 return true;
3081 }
3082 return false;
3083}
3084
3085/// Check for duplicates in the DeviceType array attribute.
3086/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found.
3087static std::optional<mlir::acc::DeviceType>
3088checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
3089 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3090 if (!deviceTypes)
3091 return std::nullopt;
3092 for (auto attr : deviceTypes) {
3093 auto deviceTypeAttr =
3094 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3095 if (!deviceTypeAttr)
3096 return mlir::acc::DeviceType::None;
3097 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3098 return deviceTypeAttr.getValue();
3099 }
3100 return std::nullopt;
3101}
3102
// Verifier for acc.loop. The checks run in the order written and the first
// failing one produces the diagnostic: loop-bound operand counts, collapse,
// gang, worker, vector, tile, the auto/independent/seq parallelism flags,
// private/firstprivate/reduction operands, the combined-construct attribute,
// and finally the body structure.
// NOTE(review): this copy of the listing has gaps (3150, 3185, 3298,
// 3306-3307). The heads of the gang and tile segment checks and the return
// statements inside the walk lambda are therefore not visible here; confirm
// them against the original file.
3103LogicalResult acc::LoopOp::verify() {
// Lower bounds, upper bounds and steps must come in equal numbers.
3104 if (getUpperbound().size() != getStep().size())
3105 return emitError() << "number of upperbounds expected to be the same as "
3106 "number of steps";
3107
3108 if (getUpperbound().size() != getLowerbound().size())
3109 return emitError() << "number of upperbounds expected to be the same as "
3110 "number of lowerbounds";
3111
3112 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3113 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3114 return emitError() << "inclusiveUpperbound size is expected to be the same"
3115 << " as upperbound size";
3116
3117 // Check collapse
3118 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3119 return emitOpError() << "collapse device_type attr must be define when"
3120 << " collapse attr is present";
3121
3122 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3123 getCollapseAttr().getValue().size() !=
3124 getCollapseDeviceTypeAttr().getValue().size())
3125 return emitOpError() << "collapse attribute count must match collapse"
3126 << " device_type count";
3127 if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr()))
3128 return emitOpError() << "duplicate device_type `"
3129 << acc::stringifyDeviceType(*duplicateDeviceType)
3130 << "` found in collapseDeviceType attribute";
3131
3132 // Check gang
3133 if (!getGangOperands().empty()) {
3134 if (!getGangOperandsArgType())
3135 return emitOpError() << "gangOperandsArgType attribute must be defined"
3136 << " when gang operands are present";
3137
3138 if (getGangOperands().size() !=
3139 getGangOperandsArgTypeAttr().getValue().size())
3140 return emitOpError() << "gangOperandsArgType attribute count must match"
3141 << " gangOperands count";
3142 }
3143 if (getGangAttr()) {
3144 if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr()))
3145 return emitOpError() << "duplicate device_type `"
3146 << acc::stringifyDeviceType(*duplicateDeviceType)
3147 << "` found in gang attribute";
3148 }
3149
// NOTE(review): head of this check (listing line 3150) is missing.
3151 *this, getGangOperands(), getGangOperandsSegmentsAttr(),
3152 getGangOperandsDeviceTypeAttr(), "gang")))
3153 return failure();
3154
3155 // Check worker
3156 if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr()))
3157 return emitOpError() << "duplicate device_type `"
3158 << acc::stringifyDeviceType(*duplicateDeviceType)
3159 << "` found in worker attribute";
3160 if (auto duplicateDeviceType =
3161 checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))
3162 return emitOpError() << "duplicate device_type `"
3163 << acc::stringifyDeviceType(*duplicateDeviceType)
3164 << "` found in workerNumOperandsDeviceType attribute";
3165 if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
3166 getWorkerNumOperandsDeviceTypeAttr(),
3167 "worker")))
3168 return failure();
3169
3170 // Check vector
3171 if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr()))
3172 return emitOpError() << "duplicate device_type `"
3173 << acc::stringifyDeviceType(*duplicateDeviceType)
3174 << "` found in vector attribute";
3175 if (auto duplicateDeviceType =
3176 checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))
3177 return emitOpError() << "duplicate device_type `"
3178 << acc::stringifyDeviceType(*duplicateDeviceType)
3179 << "` found in vectorOperandsDeviceType attribute";
3180 if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
3181 getVectorOperandsDeviceTypeAttr(),
3182 "vector")))
3183 return failure();
3184
// NOTE(review): head of this check (listing line 3185) is missing.
3186 *this, getTileOperands(), getTileOperandsSegmentsAttr(),
3187 getTileOperandsDeviceTypeAttr(), "tile")))
3188 return failure();
3189
3190 // auto, independent and seq attribute are mutually exclusive.
// One set is shared by the three calls, so a device type listed under two
// different flags is reported as a duplicate.
3191 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3192 if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
3193 hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
3194 hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
3195 return emitError() << "only one of auto, independent, seq can be present "
3196 "at the same time";
3197 }
3198
3199 // Check that at least one of auto, independent, or seq is present
3200 // for the device-independent default clauses.
3201 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
3202 return attr.getValue() == mlir::acc::DeviceType::None;
3203 };
3204 bool hasDefaultSeq =
3205 getSeqAttr()
3206 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3207 hasDeviceNone)
3208 : false;
3209 bool hasDefaultIndependent =
3210 getIndependentAttr()
3211 ? llvm::any_of(
3212 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3213 hasDeviceNone)
3214 : false;
3215 bool hasDefaultAuto =
3216 getAuto_Attr()
3217 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3218 hasDeviceNone)
3219 : false;
3220 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3221 return emitError()
3222 << "at least one of auto, independent, seq must be present";
3223 }
3224
3225 // Gang, worker and vector are incompatible with seq.
3226 if (getSeqAttr()) {
3227 for (auto attr : getSeqAttr()) {
// NOTE(review): the dyn_cast result is used without a null check.
3228 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3229 if (hasVector(deviceTypeAttr.getValue()) ||
3230 getVectorValue(deviceTypeAttr.getValue()) ||
3231 hasWorker(deviceTypeAttr.getValue()) ||
3232 getWorkerValue(deviceTypeAttr.getValue()) ||
3233 hasGang(deviceTypeAttr.getValue()) ||
3234 getGangValue(mlir::acc::GangArgType::Num,
3235 deviceTypeAttr.getValue()) ||
3236 getGangValue(mlir::acc::GangArgType::Dim,
3237 deviceTypeAttr.getValue()) ||
3238 getGangValue(mlir::acc::GangArgType::Static,
3239 deviceTypeAttr.getValue()))
3240 return emitError() << "gang, worker or vector cannot appear with seq";
3241 }
3242 }
3243
3244 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
3245 mlir::acc::PrivateRecipeOp>(
3246 *this, getPrivateOperands(), "private")))
3247 return failure();
3248
3249 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
3250 mlir::acc::FirstprivateRecipeOp>(
3251 *this, getFirstprivateOperands(), "firstprivate")))
3252 return failure();
3253
3254 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
3255 mlir::acc::ReductionRecipeOp>(
3256 *this, getReductionOperands(), "reduction")))
3257 return failure();
3258
// Only the three *Loop combined-construct kinds are accepted when set.
3259 if (getCombined().has_value() &&
3260 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3261 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3262 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3263 return emitError("unexpected combined constructs attribute");
3264 }
3265
3266 // Check non-empty body().
3267 if (getRegion().empty())
3268 return emitError("expected non-empty body.");
3269
3270 if (getUnstructured()) {
3271 if (!isContainerLike())
3272 return emitError(
3273 "unstructured acc.loop must not have induction variables");
3274 } else if (isContainerLike()) {
3275 // When it is container-like - it is expected to hold a loop-like operation.
3276 // Obtain the maximum collapse count - we use this to check that there
3277 // are enough loops contained.
3278 uint64_t collapseCount = getCollapseValue().value_or(1);
3279 if (getCollapseAttr()) {
3280 for (auto collapseEntry : getCollapseAttr()) {
3281 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3282 if (intAttr.getValue().getZExtValue() > collapseCount)
3283 collapseCount = intAttr.getValue().getZExtValue();
3284 }
3285 }
3286
3287 // We want to check that we find enough loop-like operations inside.
3288 // PreOrder walk allows us to walk in a breadth-first manner at each nesting
3289 // level.
3290 mlir::Operation *expectedParent = this->getOperation();
3291 bool foundSibling = false;
3292 getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
3293 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3294 // This effectively checks that we are not looking at a sibling loop.
3295 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3296 expectedParent) {
3297 foundSibling = true;
// NOTE(review): listing line 3298 is missing; presumably the walk stops here.
3299 }
3300
// Each loop found consumes one collapse level and becomes the parent the
// next loop is expected to be nested in.
3301 collapseCount--;
3302 expectedParent = op;
3303 }
3304 // We found enough contained loops.
3305 if (collapseCount == 0)
// NOTE(review): listing lines 3306-3307 (the lambda's return statements) are
// missing.
3308 });
3309
3310 if (foundSibling)
3311 return emitError("found sibling loops inside container-like acc.loop");
3312 if (collapseCount != 0)
3313 return emitError("failed to find enough loop-like operations inside "
3314 "container-like acc.loop");
3315 }
3316
3317 return success();
3318}
3319
3320unsigned LoopOp::getNumDataOperands() {
3321 return getReductionOperands().size() + getPrivateOperands().size() +
3322 getFirstprivateOperands().size();
3323}
3324
3325Value LoopOp::getDataOperand(unsigned i) {
3326 unsigned numOptional =
3327 getLowerbound().size() + getUpperbound().size() + getStep().size();
3328 numOptional += getGangOperands().size();
3329 numOptional += getVectorOperands().size();
3330 numOptional += getWorkerNumOperands().size();
3331 numOptional += getTileOperands().size();
3332 numOptional += getCacheOperands().size();
3333 return getOperand(numOptional + i);
3334}
3335
3336bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
3337
3338bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3339 return hasDeviceType(getAuto_(), deviceType);
3340}
3341
3342bool LoopOp::hasIndependent() {
3343 return hasIndependent(mlir::acc::DeviceType::None);
3344}
3345
3346bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3347 return hasDeviceType(getIndependent(), deviceType);
3348}
3349
3350bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
3351
3352bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3353 return hasDeviceType(getSeq(), deviceType);
3354}
3355
3356mlir::Value LoopOp::getVectorValue() {
3357 return getVectorValue(mlir::acc::DeviceType::None);
3358}
3359
3360mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3361 return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
3362 getVectorOperands(), deviceType);
3363}
3364
3365bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
3366
3367bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3368 return hasDeviceType(getVector(), deviceType);
3369}
3370
3371mlir::Value LoopOp::getWorkerValue() {
3372 return getWorkerValue(mlir::acc::DeviceType::None);
3373}
3374
3375mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3376 return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
3377 getWorkerNumOperands(), deviceType);
3378}
3379
3380bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
3381
3382bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3383 return hasDeviceType(getWorker(), deviceType);
3384}
3385
3386mlir::Operation::operand_range LoopOp::getTileValues() {
3387 return getTileValues(mlir::acc::DeviceType::None);
3388}
3389
3391LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3392 return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
3393 getTileOperandsSegments(), deviceType);
3394}
3395
3396std::optional<int64_t> LoopOp::getCollapseValue() {
3397 return getCollapseValue(mlir::acc::DeviceType::None);
3398}
3399
3400std::optional<int64_t>
3401LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3402 if (!getCollapseAttr())
3403 return std::nullopt;
3404 if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3405 auto intAttr =
3406 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3407 return intAttr.getValue().getZExtValue();
3408 }
3409 return std::nullopt;
3410}
3411
3412mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3413 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3414}
3415
3416mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3417 mlir::acc::DeviceType deviceType) {
3418 if (getGangOperands().empty())
3419 return {};
3420 if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
3421 int32_t nbOperandsBefore = 0;
3422 for (unsigned i = 0; i < *pos; ++i)
3423 nbOperandsBefore += (*getGangOperandsSegments())[i];
3425 getGangOperands()
3426 .drop_front(nbOperandsBefore)
3427 .take_front((*getGangOperandsSegments())[*pos]);
3428
3429 int32_t argTypeIdx = nbOperandsBefore;
3430 for (auto value : values) {
3431 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3432 (*getGangOperandsArgType())[argTypeIdx]);
3433 if (gangArgTypeAttr.getValue() == gangArgType)
3434 return value;
3435 ++argTypeIdx;
3436 }
3437 }
3438 return {};
3439}
3440
3441bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
3442
3443bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3444 return hasDeviceType(getGang(), deviceType);
3445}
3446
3447llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
3448 return {&getRegion()};
3449}
3450
3451/// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
3452/// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
3453/// `(` ssa-id-and-type-list `)`
3454/// region
// Parser for the loop-control grammar above.
// When the `control` keyword is present, the induction variables, lower
// bounds, upper bounds and steps are parsed, each bound list being asked for
// exactly `inductionVars.size()` operands followed by `: type-list`. The
// region is parsed in either case, with `inductionVars` as its arguments.
// NOTE(review): this copy of the listing has gaps (3456-3457, 3459, 3461,
// 3464, 3472, 3476, 3480): the function name, the parser/region/bound
// parameters, the declaration of `inductionVars` and the last argument of
// each parseOperandList call are missing. Confirm against the original file.
3455ParseResult
3458 SmallVectorImpl<Type> &lowerboundType,
3460 SmallVectorImpl<Type> &upperboundType,
3462 SmallVectorImpl<Type> &stepType) {
3463
3465 if (succeeded(
3466 parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
3467 if (parser.parseLParen() ||
3468 parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
3469 /*allowType=*/true) ||
3470 parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
3471 parser.parseOperandList(lowerbound, inductionVars.size(),
3473 parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
3474 parser.parseKeyword("to") || parser.parseLParen() ||
3475 parser.parseOperandList(upperbound, inductionVars.size(),
3477 parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
3478 parser.parseKeyword("step") || parser.parseLParen() ||
3479 parser.parseOperandList(step, inductionVars.size(),
3481 parser.parseColonTypeList(stepType) || parser.parseRParen())
3482 return failure();
3483 }
3484 return parser.parseRegion(region, inductionVars);
3485}
3486
// Printer for the loop-control part of acc.loop.
// NOTE(review): the listing jumps from 3486 to 3488, so the first line of the
// signature (return type, function name and the parameters that introduce `p`
// and `region`) is missing from this copy. Confirm against the original file.
//
// When the entry block of `region` has arguments, prints
//   control(<arg : type, ...>) = (<lb> : <types>) to (<ub> : <types>)  step
//   (<steps> : <types>)
// and then the region without its entry block arguments. With no block
// arguments only the region is printed.
3488 ValueRange lowerbound, TypeRange lowerboundType,
3489 ValueRange upperbound, TypeRange upperboundType,
3490 ValueRange steps, TypeRange stepType) {
3491 ValueRange regionArgs = region.front().getArguments();
3492 if (!regionArgs.empty()) {
3493 p << acc::LoopOp::getControlKeyword() << "(";
3494 llvm::interleaveComma(regionArgs, p,
3495 [&p](Value v) { p << v << " : " << v.getType(); });
// The two adjacent literals below put two spaces before `step`.
3496 p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
3497 << upperbound << " : " << upperboundType << ") " << " step (" << steps
3498 << " : " << stepType << ") ";
3499 }
3500 p.printRegion(region, /*printEntryBlockArgs=*/false);
3501}
3502
3503void acc::LoopOp::addSeq(MLIRContext *context,
3504 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3505 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3506 effectiveDeviceTypes));
3507}
3508
3509void acc::LoopOp::addIndependent(
3510 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3511 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3512 context, getIndependentAttr(), effectiveDeviceTypes));
3513}
3514
3515void acc::LoopOp::addAuto(MLIRContext *context,
3516 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3517 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3518 effectiveDeviceTypes));
3519}
3520
3521void acc::LoopOp::setCollapseForDeviceTypes(
3522 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3523 llvm::APInt value) {
3526
3527 assert((getCollapseAttr() == nullptr) ==
3528 (getCollapseDeviceTypeAttr() == nullptr));
3529 assert(value.getBitWidth() == 64);
3530
3531 if (getCollapseAttr()) {
3532 for (const auto &existing :
3533 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3534 newValues.push_back(std::get<0>(existing));
3535 newDeviceTypes.push_back(std::get<1>(existing));
3536 }
3537 }
3538
3539 if (effectiveDeviceTypes.empty()) {
3540 // If the effective device-types list is empty, this is before there are any
3541 // being applied by device_type, so this should be added as a 'none'.
3542 newValues.push_back(
3543 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3544 newDeviceTypes.push_back(
3545 acc::DeviceTypeAttr::get(context, DeviceType::None));
3546 } else {
3547 for (DeviceType dt : effectiveDeviceTypes) {
3548 newValues.push_back(
3549 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3550 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3551 }
3552 }
3553
3554 setCollapseAttr(ArrayAttr::get(context, newValues));
3555 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3556}
3557
3558void acc::LoopOp::setTileForDeviceTypes(
3559 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3560 ValueRange values) {
3562 if (getTileOperandsSegments())
3563 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3564
3565 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3566 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3567 getTileOperandsMutable(), segments));
3568
3569 setTileOperandsSegments(segments);
3570}
3571
3572void acc::LoopOp::addVectorOperand(
3573 MLIRContext *context, mlir::Value newValue,
3574 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3575 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3576 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3577 newValue, getVectorOperandsMutable()));
3578}
3579
3580void acc::LoopOp::addEmptyVector(
3581 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3582 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3583 effectiveDeviceTypes));
3584}
3585
3586void acc::LoopOp::addWorkerNumOperand(
3587 MLIRContext *context, mlir::Value newValue,
3588 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3589 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3590 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3591 newValue, getWorkerNumOperandsMutable()));
3592}
3593
3594void acc::LoopOp::addEmptyWorker(
3595 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3596 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3597 effectiveDeviceTypes));
3598}
3599
3600void acc::LoopOp::addEmptyGang(
3601 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3602 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3603 effectiveDeviceTypes));
3604}
3605
3606bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3607 auto hasDevice = [=](DeviceTypeAttr attr) -> bool {
3608 return attr.getValue() == dt;
3609 };
3610 auto testFromArr = [=](ArrayAttr arr) -> bool {
3611 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3612 };
3613
3614 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3615 return true;
3616 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3617 return true;
3618 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3619 return true;
3620
3621 return false;
3622}
3623
3624bool acc::LoopOp::hasDefaultGangWorkerVector() {
3625 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3626 hasGang() || getGangValue(GangArgType::Num) ||
3627 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3628}
3629
3630acc::LoopParMode
3631acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3632 if (hasSeq(deviceType))
3633 return LoopParMode::loop_seq;
3634 if (hasAuto(deviceType))
3635 return LoopParMode::loop_auto;
3636 if (hasIndependent(deviceType))
3637 return LoopParMode::loop_independent;
3638 if (hasSeq())
3639 return LoopParMode::loop_seq;
3640 if (hasAuto())
3641 return LoopParMode::loop_auto;
3642 assert(hasIndependent() &&
3643 "loop must have default auto, seq, or independent");
3644 return LoopParMode::loop_independent;
3645}
3646
3647void acc::LoopOp::addGangOperands(
3648 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3651 if (std::optional<ArrayRef<int32_t>> existingSegments =
3652 getGangOperandsSegments())
3653 llvm::copy(*existingSegments, std::back_inserter(segments));
3654
3655 unsigned beforeCount = segments.size();
3656
3657 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3658 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3659 getGangOperandsMutable(), segments));
3660
3661 setGangOperandsSegments(segments);
3662
3663 // This is a bit of extra work to make sure we update the 'types' correctly by
3664 // adding to the types collection the correct number of times. We could
3665 // potentially add something similar to the
3666 // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
3667 // excessive for a one-off case.
3668 unsigned numAdded = segments.size() - beforeCount;
3669
3670 if (numAdded > 0) {
3672 if (getGangOperandsArgTypeAttr())
3673 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3674
3675 for (auto i : llvm::index_range(0u, numAdded)) {
3676 llvm::transform(argTypes, std::back_inserter(gangTypes),
3677 [=](mlir::acc::GangArgType gangTy) {
3678 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3679 });
3680 (void)i;
3681 }
3682
3683 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3684 }
3685}
3686
3687void acc::LoopOp::addPrivatization(MLIRContext *context,
3688 mlir::acc::PrivateOp op,
3689 mlir::acc::PrivateRecipeOp recipe) {
3690 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3691 getPrivateOperandsMutable().append(op.getResult());
3692}
3693
3694void acc::LoopOp::addFirstPrivatization(
3695 MLIRContext *context, mlir::acc::FirstprivateOp op,
3696 mlir::acc::FirstprivateRecipeOp recipe) {
3697 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3698 getFirstprivateOperandsMutable().append(op.getResult());
3699}
3700
3701void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
3702 mlir::acc::ReductionRecipeOp recipe) {
3703 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3704 getReductionOperandsMutable().append(op.getResult());
3705}
3706
3707//===----------------------------------------------------------------------===//
3708// DataOp
3709//===----------------------------------------------------------------------===//
3710
3711LogicalResult acc::DataOp::verify() {
3712 // 2.6.5. Data Construct restriction
3713 // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
3714 // attach, or default clause must appear on a data construct.
3715 if (getOperands().empty() && !getDefaultAttr())
3716 return emitError("at least one operand or the default attribute "
3717 "must appear on the data operation");
3718
3719 for (mlir::Value operand : getDataClauseOperands())
3720 if (isa<BlockArgument>(operand) ||
3721 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3722 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3723 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3724 operand.getDefiningOp()))
3725 return emitError("expect data entry/exit operation or acc.getdeviceptr "
3726 "as defining op");
3727
3729 return failure();
3730
3731 return success();
3732}
3733
3734unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
3735
3736Value DataOp::getDataOperand(unsigned i) {
3737 unsigned numOptional = getIfCond() ? 1 : 0;
3738 numOptional += getAsyncOperands().size() ? 1 : 0;
3739 numOptional += getWaitOperands().size();
3740 return getOperand(numOptional + i);
3741}
3742
3743bool acc::DataOp::hasAsyncOnly() {
3744 return hasAsyncOnly(mlir::acc::DeviceType::None);
3745}
3746
3747bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3748 return hasDeviceType(getAsyncOnly(), deviceType);
3749}
3750
3751mlir::Value DataOp::getAsyncValue() {
3752 return getAsyncValue(mlir::acc::DeviceType::None);
3753}
3754
3755mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3757 getAsyncOperands(), deviceType);
3758}
3759
3760bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
3761
3762bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3763 return hasDeviceType(getWaitOnly(), deviceType);
3764}
3765
3766mlir::Operation::operand_range DataOp::getWaitValues() {
3767 return getWaitValues(mlir::acc::DeviceType::None);
3768}
3769
3771DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3773 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3774 getHasWaitDevnum(), deviceType);
3775}
3776
3777mlir::Value DataOp::getWaitDevnum() {
3778 return getWaitDevnum(mlir::acc::DeviceType::None);
3779}
3780
3781mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3782 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3783 getWaitOperandsSegments(), getHasWaitDevnum(),
3784 deviceType);
3785}
3786
3787void acc::DataOp::addAsyncOnly(
3788 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3789 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3790 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3791}
3792
3793void acc::DataOp::addAsyncOperand(
3794 MLIRContext *context, mlir::Value newValue,
3795 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3796 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3797 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3798 getAsyncOperandsMutable()));
3799}
3800
3801void acc::DataOp::addWaitOnly(MLIRContext *context,
3802 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3803 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3804 effectiveDeviceTypes));
3805}
3806
3807void acc::DataOp::addWaitOperands(
3808 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3809 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3810
3812 if (getWaitOperandsSegments())
3813 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3814
3815 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3816 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3817 getWaitOperandsMutable(), segments));
3818 setWaitOperandsSegments(segments);
3819
3821 if (getHasWaitDevnumAttr())
3822 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3823 hasDevnums.insert(
3824 hasDevnums.end(),
3825 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
3826 mlir::BoolAttr::get(context, hasDevnum));
3827 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3828}
3829
3830//===----------------------------------------------------------------------===//
3831// ExitDataOp
3832//===----------------------------------------------------------------------===//
3833
3834LogicalResult acc::ExitDataOp::verify() {
3835 // 2.6.6. Data Exit Directive restriction
3836 // At least one copyout, delete, or detach clause must appear on an exit data
3837 // directive.
3838 if (getDataClauseOperands().empty())
3839 return emitError("at least one operand must be present in dataOperands on "
3840 "the exit data operation");
3841
3842 // The async attribute represent the async clause without value. Therefore the
3843 // attribute and operand cannot appear at the same time.
3844 if (getAsyncOperand() && getAsync())
3845 return emitError("async attribute cannot appear with asyncOperand");
3846
3847 // The wait attribute represent the wait clause without values. Therefore the
3848 // attribute and operands cannot appear at the same time.
3849 if (!getWaitOperands().empty() && getWait())
3850 return emitError("wait attribute cannot appear with waitOperands");
3851
3852 if (getWaitDevnum() && getWaitOperands().empty())
3853 return emitError("wait_devnum cannot appear without waitOperands");
3854
3855 return success();
3856}
3857
3858unsigned ExitDataOp::getNumDataOperands() {
3859 return getDataClauseOperands().size();
3860}
3861
3862Value ExitDataOp::getDataOperand(unsigned i) {
3863 unsigned numOptional = getIfCond() ? 1 : 0;
3864 numOptional += getAsyncOperand() ? 1 : 0;
3865 numOptional += getWaitDevnum() ? 1 : 0;
3866 return getOperand(getWaitOperands().size() + numOptional + i);
3867}
3868
3869void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3870 MLIRContext *context) {
3871 results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
3872}
3873
3874void ExitDataOp::addAsyncOnly(MLIRContext *context,
3875 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3876 assert(effectiveDeviceTypes.empty());
3877 assert(!getAsyncAttr());
3878 assert(!getAsyncOperand());
3879
3880 setAsyncAttr(mlir::UnitAttr::get(context));
3881}
3882
3883void ExitDataOp::addAsyncOperand(
3884 MLIRContext *context, mlir::Value newValue,
3885 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3886 assert(effectiveDeviceTypes.empty());
3887 assert(!getAsyncAttr());
3888 assert(!getAsyncOperand());
3889
3890 getAsyncOperandMutable().append(newValue);
3891}
3892
3893void ExitDataOp::addWaitOnly(MLIRContext *context,
3894 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3895 assert(effectiveDeviceTypes.empty());
3896 assert(!getWaitAttr());
3897 assert(getWaitOperands().empty());
3898 assert(!getWaitDevnum());
3899
3900 setWaitAttr(mlir::UnitAttr::get(context));
3901}
3902
3903void ExitDataOp::addWaitOperands(
3904 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3905 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3906 assert(effectiveDeviceTypes.empty());
3907 assert(!getWaitAttr());
3908 assert(getWaitOperands().empty());
3909 assert(!getWaitDevnum());
3910
3911 // if hasDevnum, the first value is the devnum. The 'rest' go into the
3912 // operands list.
3913 if (hasDevnum) {
3914 getWaitDevnumMutable().append(newValues.front());
3915 newValues = newValues.drop_front();
3916 }
3917
3918 getWaitOperandsMutable().append(newValues);
3919}
3920
3921//===----------------------------------------------------------------------===//
3922// EnterDataOp
3923//===----------------------------------------------------------------------===//
3924
3925LogicalResult acc::EnterDataOp::verify() {
3926 // 2.6.6. Data Enter Directive restriction
3927 // At least one copyin, create, or attach clause must appear on an enter data
3928 // directive.
3929 if (getDataClauseOperands().empty())
3930 return emitError("at least one operand must be present in dataOperands on "
3931 "the enter data operation");
3932
3933 // The async attribute represent the async clause without value. Therefore the
3934 // attribute and operand cannot appear at the same time.
3935 if (getAsyncOperand() && getAsync())
3936 return emitError("async attribute cannot appear with asyncOperand");
3937
3938 // The wait attribute represent the wait clause without values. Therefore the
3939 // attribute and operands cannot appear at the same time.
3940 if (!getWaitOperands().empty() && getWait())
3941 return emitError("wait attribute cannot appear with waitOperands");
3942
3943 if (getWaitDevnum() && getWaitOperands().empty())
3944 return emitError("wait_devnum cannot appear without waitOperands");
3945
3946 for (mlir::Value operand : getDataClauseOperands())
3947 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3948 operand.getDefiningOp()))
3949 return emitError("expect data entry operation as defining op");
3950
3951 return success();
3952}
3953
3954unsigned EnterDataOp::getNumDataOperands() {
3955 return getDataClauseOperands().size();
3956}
3957
3958Value EnterDataOp::getDataOperand(unsigned i) {
3959 unsigned numOptional = getIfCond() ? 1 : 0;
3960 numOptional += getAsyncOperand() ? 1 : 0;
3961 numOptional += getWaitDevnum() ? 1 : 0;
3962 return getOperand(getWaitOperands().size() + numOptional + i);
3963}
3964
3965void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3966 MLIRContext *context) {
3967 results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
3968}
3969
3970void EnterDataOp::addAsyncOnly(
3971 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3972 assert(effectiveDeviceTypes.empty());
3973 assert(!getAsyncAttr());
3974 assert(!getAsyncOperand());
3975
3976 setAsyncAttr(mlir::UnitAttr::get(context));
3977}
3978
3979void EnterDataOp::addAsyncOperand(
3980 MLIRContext *context, mlir::Value newValue,
3981 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3982 assert(effectiveDeviceTypes.empty());
3983 assert(!getAsyncAttr());
3984 assert(!getAsyncOperand());
3985
3986 getAsyncOperandMutable().append(newValue);
3987}
3988
3989void EnterDataOp::addWaitOnly(MLIRContext *context,
3990 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3991 assert(effectiveDeviceTypes.empty());
3992 assert(!getWaitAttr());
3993 assert(getWaitOperands().empty());
3994 assert(!getWaitDevnum());
3995
3996 setWaitAttr(mlir::UnitAttr::get(context));
3997}
3998
3999void EnterDataOp::addWaitOperands(
4000 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4001 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4002 assert(effectiveDeviceTypes.empty());
4003 assert(!getWaitAttr());
4004 assert(getWaitOperands().empty());
4005 assert(!getWaitDevnum());
4006
4007 // if hasDevnum, the first value is the devnum. The 'rest' go into the
4008 // operands list.
4009 if (hasDevnum) {
4010 getWaitDevnumMutable().append(newValues.front());
4011 newValues = newValues.drop_front();
4012 }
4013
4014 getWaitOperandsMutable().append(newValues);
4015}
4016
4017//===----------------------------------------------------------------------===//
4018// AtomicReadOp
4019//===----------------------------------------------------------------------===//
4020
4021LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
4022
4023//===----------------------------------------------------------------------===//
4024// AtomicWriteOp
4025//===----------------------------------------------------------------------===//
4026
4027LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
4028
4029//===----------------------------------------------------------------------===//
4030// AtomicUpdateOp
4031//===----------------------------------------------------------------------===//
4032
4033LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4034 PatternRewriter &rewriter) {
4035 if (op.isNoOp()) {
4036 rewriter.eraseOp(op);
4037 return success();
4038 }
4039
4040 if (Value writeVal = op.getWriteOpVal()) {
4041 rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
4042 op.getIfCond());
4043 return success();
4044 }
4045
4046 return failure();
4047}
4048
4049LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
4050
4051LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4052
4053//===----------------------------------------------------------------------===//
4054// AtomicCaptureOp
4055//===----------------------------------------------------------------------===//
4056
4057AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4058 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4059 return op;
4060 return dyn_cast<AtomicReadOp>(getSecondOp());
4061}
4062
4063AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4064 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4065 return op;
4066 return dyn_cast<AtomicWriteOp>(getSecondOp());
4067}
4068
4069AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4070 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4071 return op;
4072 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4073}
4074
4075LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
4076
4077//===----------------------------------------------------------------------===//
4078// DeclareEnterOp
4079//===----------------------------------------------------------------------===//
4080
4081template <typename Op>
4082static LogicalResult
4084 bool requireAtLeastOneOperand = true) {
4085 if (operands.empty() && requireAtLeastOneOperand)
4086 return emitError(
4087 op->getLoc(),
4088 "at least one operand must appear on the declare operation");
4089
4090 for (mlir::Value operand : operands) {
4091 if (isa<BlockArgument>(operand) ||
4092 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4093 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4094 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4095 operand.getDefiningOp()))
4096 return op.emitError(
4097 "expect valid declare data entry operation or acc.getdeviceptr "
4098 "as defining op");
4099
4100 mlir::Value var{getVar(operand.getDefiningOp())};
4101 assert(var && "declare operands can only be data entry operations which "
4102 "must have var");
4103 (void)var;
4104 std::optional<mlir::acc::DataClause> dataClauseOptional{
4105 getDataClause(operand.getDefiningOp())};
4106 assert(dataClauseOptional.has_value() &&
4107 "declare operands can only be data entry operations which must have "
4108 "dataClause");
4109 (void)dataClauseOptional;
4110 }
4111
4112 return success();
4113}
4114
4115LogicalResult acc::DeclareEnterOp::verify() {
4116 return checkDeclareOperands(*this, this->getDataClauseOperands());
4117}
4118
4119//===----------------------------------------------------------------------===//
4120// DeclareExitOp
4121//===----------------------------------------------------------------------===//
4122
4123LogicalResult acc::DeclareExitOp::verify() {
4124 if (getToken())
4125 return checkDeclareOperands(*this, this->getDataClauseOperands(),
4126 /*requireAtLeastOneOperand=*/false);
4127 return checkDeclareOperands(*this, this->getDataClauseOperands());
4128}
4129
4130//===----------------------------------------------------------------------===//
4131// DeclareOp
4132//===----------------------------------------------------------------------===//
4133
4134LogicalResult acc::DeclareOp::verify() {
4135 return checkDeclareOperands(*this, this->getDataClauseOperands());
4136}
4137
4138//===----------------------------------------------------------------------===//
4139// RoutineOp
4140//===----------------------------------------------------------------------===//
4141
4142static unsigned getParallelismForDeviceType(acc::RoutineOp op,
4143 acc::DeviceType dtype) {
4144 unsigned parallelism = 0;
4145 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4146 parallelism += op.hasWorker(dtype) ? 1 : 0;
4147 parallelism += op.hasVector(dtype) ? 1 : 0;
4148 parallelism += op.hasSeq(dtype) ? 1 : 0;
4149 return parallelism;
4150}
4151
4152LogicalResult acc::RoutineOp::verify() {
4153 unsigned baseParallelism =
4154 getParallelismForDeviceType(*this, acc::DeviceType::None);
4155
4156 if (baseParallelism > 1)
4157 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4158 "be present at the same time";
4159
4160 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4161 ++dtypeInt) {
4162 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
4163 if (dtype == acc::DeviceType::None)
4164 continue;
4165 unsigned parallelism = getParallelismForDeviceType(*this, dtype);
4166
4167 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4168 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4169 "be present at the same time for device_type `"
4170 << acc::stringifyDeviceType(dtype) << "`";
4171 }
4172
4173 return success();
4174}
4175
4176static ParseResult parseBindName(OpAsmParser &parser,
4177 mlir::ArrayAttr &bindIdName,
4178 mlir::ArrayAttr &bindStrName,
4179 mlir::ArrayAttr &deviceIdTypes,
4180 mlir::ArrayAttr &deviceStrTypes) {
4181 llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
4182 llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
4183 llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
4184 llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
4185
4186 if (failed(parser.parseCommaSeparatedList([&]() {
4187 mlir::Attribute newAttr;
4188 bool isSymbolRefAttr;
4189 auto parseResult = parser.parseAttribute(newAttr);
4190 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4191 bindIdNameAttrs.push_back(symbolRefAttr);
4192 isSymbolRefAttr = true;
4193 } else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4194 bindStrNameAttrs.push_back(stringAttr);
4195 isSymbolRefAttr = false;
4196 }
4197 if (parseResult)
4198 return failure();
4199 if (failed(parser.parseOptionalLSquare())) {
4200 if (isSymbolRefAttr) {
4201 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4202 parser.getContext(), mlir::acc::DeviceType::None));
4203 } else {
4204 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4205 parser.getContext(), mlir::acc::DeviceType::None));
4206 }
4207 } else {
4208 if (isSymbolRefAttr) {
4209 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4210 parser.parseRSquare())
4211 return failure();
4212 } else {
4213 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4214 parser.parseRSquare())
4215 return failure();
4216 }
4217 }
4218 return success();
4219 })))
4220 return failure();
4221
4222 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4223 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4224 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4225 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4226
4227 return success();
4228}
4229
4231 std::optional<mlir::ArrayAttr> bindIdName,
4232 std::optional<mlir::ArrayAttr> bindStrName,
4233 std::optional<mlir::ArrayAttr> deviceIdTypes,
4234 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4235 // Create combined vectors for all bind names and device types
4238
4239 // Append bindIdName and deviceIdTypes
4240 if (hasDeviceTypeValues(deviceIdTypes)) {
4241 allBindNames.append(bindIdName->begin(), bindIdName->end());
4242 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4243 }
4244
4245 // Append bindStrName and deviceStrTypes
4246 if (hasDeviceTypeValues(deviceStrTypes)) {
4247 allBindNames.append(bindStrName->begin(), bindStrName->end());
4248 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4249 }
4250
4251 // Print the combined sequence
4252 if (!allBindNames.empty())
4253 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4254 [&](const auto &pair) {
4255 p << std::get<0>(pair);
4256 printSingleDeviceType(p, std::get<1>(pair));
4257 });
4258}
4259
4260static ParseResult parseRoutineGangClause(OpAsmParser &parser,
4261 mlir::ArrayAttr &gang,
4262 mlir::ArrayAttr &gangDim,
4263 mlir::ArrayAttr &gangDimDeviceTypes) {
4264
4265 llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
4266 gangDimDeviceTypeAttrs;
4267 bool needCommaBeforeOperands = false;
4268
4269 // Gang keyword only
4270 if (failed(parser.parseOptionalLParen())) {
4271 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4272 parser.getContext(), mlir::acc::DeviceType::None));
4273 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4274 return success();
4275 }
4276
4277 // Parse keyword only attributes
4278 if (succeeded(parser.parseOptionalLSquare())) {
4279 if (failed(parser.parseCommaSeparatedList([&]() {
4280 if (parser.parseAttribute(gangAttrs.emplace_back()))
4281 return failure();
4282 return success();
4283 })))
4284 return failure();
4285 if (parser.parseRSquare())
4286 return failure();
4287 needCommaBeforeOperands = true;
4288 }
4289
4290 if (needCommaBeforeOperands && failed(parser.parseComma()))
4291 return failure();
4292
4293 if (failed(parser.parseCommaSeparatedList([&]() {
4294 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4295 parser.parseColon() ||
4296 parser.parseAttribute(gangDimAttrs.emplace_back()))
4297 return failure();
4298 if (succeeded(parser.parseOptionalLSquare())) {
4299 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4300 parser.parseRSquare())
4301 return failure();
4302 } else {
4303 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4304 parser.getContext(), mlir::acc::DeviceType::None));
4305 }
4306 return success();
4307 })))
4308 return failure();
4309
4310 if (failed(parser.parseRParen()))
4311 return failure();
4312
4313 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4314 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4315 gangDimDeviceTypes =
4316 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4317
4318 return success();
4319}
4320
4322 std::optional<mlir::ArrayAttr> gang,
4323 std::optional<mlir::ArrayAttr> gangDim,
4324 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4325
4326 if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
4327 gang->size() == 1) {
4328 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4329 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4330 return;
4331 }
4332
4333 p << "(";
4334
4335 printDeviceTypes(p, gang);
4336
4337 if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
4338 p << ", ";
4339
4340 if (hasDeviceTypeValues(gangDimDeviceTypes))
4341 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4342 [&](const auto &pair) {
4343 p << acc::RoutineOp::getGangDimKeyword() << ": ";
4344 p << std::get<0>(pair);
4345 printSingleDeviceType(p, std::get<1>(pair));
4346 });
4347
4348 p << ")";
4349}
4350
4351static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
4352 mlir::ArrayAttr &deviceTypes) {
4354 // Keyword only
4355 if (failed(parser.parseOptionalLParen())) {
4356 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4357 parser.getContext(), mlir::acc::DeviceType::None));
4358 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4359 return success();
4360 }
4361
4362 // Parse device type attributes
4363 if (succeeded(parser.parseOptionalLSquare())) {
4364 if (failed(parser.parseCommaSeparatedList([&]() {
4365 if (parser.parseAttribute(attributes.emplace_back()))
4366 return failure();
4367 return success();
4368 })))
4369 return failure();
4370 if (parser.parseRSquare() || parser.parseRParen())
4371 return failure();
4372 }
4373 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4374 return success();
4375}
4376
4377static void
4379 std::optional<mlir::ArrayAttr> deviceTypes) {
4380
4381 if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
4382 auto deviceTypeAttr =
4383 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4384 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4385 return;
4386 }
4387
4388 if (!hasDeviceTypeValues(deviceTypes))
4389 return;
4390
4391 p << "([";
4392 llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
4393 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4394 p << dTypeAttr;
4395 });
4396 p << "])";
4397}
4398
4399bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
4400
4401bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4402 return hasDeviceType(getWorker(), deviceType);
4403}
4404
4405bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
4406
4407bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4408 return hasDeviceType(getVector(), deviceType);
4409}
4410
4411bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
4412
4413bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4414 return hasDeviceType(getSeq(), deviceType);
4415}
4416
4417std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4418RoutineOp::getBindNameValue() {
4419 return getBindNameValue(mlir::acc::DeviceType::None);
4420}
4421
4422std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4423RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4424 if (!hasDeviceTypeValues(getBindIdNameDeviceType()) &&
4425 !hasDeviceTypeValues(getBindStrNameDeviceType())) {
4426 return std::nullopt;
4427 }
4428
4429 if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
4430 auto attr = (*getBindIdName())[*pos];
4431 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4432 assert(symbolRefAttr && "expected SymbolRef");
4433 return symbolRefAttr;
4434 }
4435
4436 if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
4437 auto attr = (*getBindStrName())[*pos];
4438 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4439 assert(stringAttr && "expected String");
4440 return stringAttr;
4441 }
4442
4443 return std::nullopt;
4444}
4445
4446bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
4447
4448bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4449 return hasDeviceType(getGang(), deviceType);
4450}
4451
4452std::optional<int64_t> RoutineOp::getGangDimValue() {
4453 return getGangDimValue(mlir::acc::DeviceType::None);
4454}
4455
4456std::optional<int64_t>
4457RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4458 if (!hasDeviceTypeValues(getGangDimDeviceType()))
4459 return std::nullopt;
4460 if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
4461 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4462 return intAttr.getInt();
4463 }
4464 return std::nullopt;
4465}
4466
4467void RoutineOp::addSeq(MLIRContext *context,
4468 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4469 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4470 effectiveDeviceTypes));
4471}
4472
4473void RoutineOp::addVector(MLIRContext *context,
4474 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4475 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4476 effectiveDeviceTypes));
4477}
4478
4479void RoutineOp::addWorker(MLIRContext *context,
4480 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4481 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4482 effectiveDeviceTypes));
4483}
4484
4485void RoutineOp::addGang(MLIRContext *context,
4486 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4487 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4488 effectiveDeviceTypes));
4489}
4490
4491void RoutineOp::addGang(MLIRContext *context,
4492 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4493 uint64_t val) {
4496
4497 if (getGangDimAttr())
4498 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4499 if (getGangDimDeviceTypeAttr())
4500 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4501
4502 assert(dimValues.size() == deviceTypes.size());
4503
4504 if (effectiveDeviceTypes.empty()) {
4505 dimValues.push_back(
4506 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4507 deviceTypes.push_back(
4508 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4509 } else {
4510 for (DeviceType dt : effectiveDeviceTypes) {
4511 dimValues.push_back(
4512 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4513 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4514 }
4515 }
4516 assert(dimValues.size() == deviceTypes.size());
4517
4518 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4519 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4520}
4521
4522void RoutineOp::addBindStrName(MLIRContext *context,
4523 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4524 mlir::StringAttr val) {
4525 unsigned before = getBindStrNameDeviceTypeAttr()
4526 ? getBindStrNameDeviceTypeAttr().size()
4527 : 0;
4528
4529 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4530 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4531 unsigned after = getBindStrNameDeviceTypeAttr().size();
4532
4534 if (getBindStrNameAttr())
4535 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4536 for (unsigned i = 0; i < after - before; ++i)
4537 vals.push_back(val);
4538
4539 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4540}
4541
4542void RoutineOp::addBindIDName(MLIRContext *context,
4543 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4544 mlir::SymbolRefAttr val) {
4545 unsigned before =
4546 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4547
4548 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4549 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4550 unsigned after = getBindIdNameDeviceTypeAttr().size();
4551
4553 if (getBindIdNameAttr())
4554 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4555 for (unsigned i = 0; i < after - before; ++i)
4556 vals.push_back(val);
4557
4558 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4559}
4560
4561//===----------------------------------------------------------------------===//
4562// InitOp
4563//===----------------------------------------------------------------------===//
4564
4565LogicalResult acc::InitOp::verify() {
4566 Operation *currOp = *this;
4567 while ((currOp = currOp->getParentOp()))
4568 if (isComputeOperation(currOp))
4569 return emitOpError("cannot be nested in a compute operation");
4570 return success();
4571}
4572
4573void acc::InitOp::addDeviceType(MLIRContext *context,
4574 mlir::acc::DeviceType deviceType) {
4576 if (getDeviceTypesAttr())
4577 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4578
4579 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4580 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4581}
4582
4583//===----------------------------------------------------------------------===//
4584// ShutdownOp
4585//===----------------------------------------------------------------------===//
4586
4587LogicalResult acc::ShutdownOp::verify() {
4588 Operation *currOp = *this;
4589 while ((currOp = currOp->getParentOp()))
4590 if (isComputeOperation(currOp))
4591 return emitOpError("cannot be nested in a compute operation");
4592 return success();
4593}
4594
4595void acc::ShutdownOp::addDeviceType(MLIRContext *context,
4596 mlir::acc::DeviceType deviceType) {
4598 if (getDeviceTypesAttr())
4599 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4600
4601 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4602 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4603}
4604
4605//===----------------------------------------------------------------------===//
4606// SetOp
4607//===----------------------------------------------------------------------===//
4608
4609LogicalResult acc::SetOp::verify() {
4610 Operation *currOp = *this;
4611 while ((currOp = currOp->getParentOp()))
4612 if (isComputeOperation(currOp))
4613 return emitOpError("cannot be nested in a compute operation");
4614 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4615 return emitOpError("at least one default_async, device_num, or device_type "
4616 "operand must appear");
4617 return success();
4618}
4619
4620//===----------------------------------------------------------------------===//
4621// UpdateOp
4622//===----------------------------------------------------------------------===//
4623
4624LogicalResult acc::UpdateOp::verify() {
4625 // At least one of host or device should have a value.
4626 if (getDataClauseOperands().empty())
4627 return emitError("at least one value must be present in dataOperands");
4628
4630 getAsyncOperandsDeviceTypeAttr(),
4631 "async")))
4632 return failure();
4633
4635 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4636 getWaitOperandsDeviceTypeAttr(), "wait")))
4637 return failure();
4638
4640 return failure();
4641
4642 for (mlir::Value operand : getDataClauseOperands())
4643 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4644 operand.getDefiningOp()))
4645 return emitError("expect data entry/exit operation or acc.getdeviceptr "
4646 "as defining op");
4647
4648 return success();
4649}
4650
4651unsigned UpdateOp::getNumDataOperands() {
4652 return getDataClauseOperands().size();
4653}
4654
4655Value UpdateOp::getDataOperand(unsigned i) {
4656 unsigned numOptional = getAsyncOperands().size();
4657 numOptional += getIfCond() ? 1 : 0;
4658 return getOperand(getWaitOperands().size() + numOptional + i);
4659}
4660
4661void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
4662 MLIRContext *context) {
4663 results.add<RemoveConstantIfCondition<UpdateOp>>(context);
4664}
4665
4666bool UpdateOp::hasAsyncOnly() {
4667 return hasAsyncOnly(mlir::acc::DeviceType::None);
4668}
4669
4670bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4671 return hasDeviceType(getAsyncOnly(), deviceType);
4672}
4673
4674mlir::Value UpdateOp::getAsyncValue() {
4675 return getAsyncValue(mlir::acc::DeviceType::None);
4676}
4677
4678mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4680 return {};
4681
4682 if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
4683 return getAsyncOperands()[*pos];
4684
4685 return {};
4686}
4687
4688bool UpdateOp::hasWaitOnly() {
4689 return hasWaitOnly(mlir::acc::DeviceType::None);
4690}
4691
4692bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4693 return hasDeviceType(getWaitOnly(), deviceType);
4694}
4695
4696mlir::Operation::operand_range UpdateOp::getWaitValues() {
4697 return getWaitValues(mlir::acc::DeviceType::None);
4698}
4699
4701UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4703 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4704 getHasWaitDevnum(), deviceType);
4705}
4706
4707mlir::Value UpdateOp::getWaitDevnum() {
4708 return getWaitDevnum(mlir::acc::DeviceType::None);
4709}
4710
4711mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4712 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
4713 getWaitOperandsSegments(), getHasWaitDevnum(),
4714 deviceType);
4715}
4716
4717void UpdateOp::addAsyncOnly(MLIRContext *context,
4718 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4719 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4720 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4721}
4722
4723void UpdateOp::addAsyncOperand(
4724 MLIRContext *context, mlir::Value newValue,
4725 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4726 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4727 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4728 getAsyncOperandsMutable()));
4729}
4730
4731void UpdateOp::addWaitOnly(MLIRContext *context,
4732 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4733 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4734 effectiveDeviceTypes));
4735}
4736
4737void UpdateOp::addWaitOperands(
4738 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4739 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4740
4742 if (getWaitOperandsSegments())
4743 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4744
4745 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4746 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4747 getWaitOperandsMutable(), segments));
4748 setWaitOperandsSegments(segments);
4749
4751 if (getHasWaitDevnumAttr())
4752 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4753 hasDevnums.insert(
4754 hasDevnums.end(),
4755 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
4756 mlir::BoolAttr::get(context, hasDevnum));
4757 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4758}
4759
4760//===----------------------------------------------------------------------===//
4761// WaitOp
4762//===----------------------------------------------------------------------===//
4763
4764LogicalResult acc::WaitOp::verify() {
4765 // The async attribute represent the async clause without value. Therefore the
4766 // attribute and operand cannot appear at the same time.
4767 if (getAsyncOperand() && getAsync())
4768 return emitError("async attribute cannot appear with asyncOperand");
4769
4770 if (getWaitDevnum() && getWaitOperands().empty())
4771 return emitError("wait_devnum cannot appear without waitOperands");
4772
4773 return success();
4774}
4775
4776#define GET_OP_CLASSES
4777#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
4778
4779#define GET_ATTRDEF_CLASSES
4780#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
4781
4782#define GET_TYPEDEF_CLASSES
4783#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
4784
4785//===----------------------------------------------------------------------===//
4786// acc dialect utilities
4787//===----------------------------------------------------------------------===//
4788
4791 auto varPtr{llvm::TypeSwitch<mlir::Operation *,
4793 accDataClauseOp)
4794 .Case<ACC_DATA_ENTRY_OPS>(
4795 [&](auto entry) { return entry.getVarPtr(); })
4796 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4797 [&](auto exit) { return exit.getVarPtr(); })
4798 .Default([&](mlir::Operation *) {
4800 })};
4801 return varPtr;
4802}
4803
4805 auto varPtr{
4807 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
4808 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4809 return varPtr;
4810}
4811
4813 auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
4814 .Case<ACC_DATA_ENTRY_OPS>(
4815 [&](auto entry) { return entry.getVarType(); })
4816 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4817 [&](auto exit) { return exit.getVarType(); })
4818 .Default([&](mlir::Operation *) { return mlir::Type(); })};
4819 return varType;
4820}
4821
4824 auto accPtr{llvm::TypeSwitch<mlir::Operation *,
4826 accDataClauseOp)
4827 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4828 [&](auto dataClause) { return dataClause.getAccPtr(); })
4829 .Default([&](mlir::Operation *) {
4831 })};
4832 return accPtr;
4833}
4834
4836 auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
4838 [&](auto dataClause) { return dataClause.getAccVar(); })
4839 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4840 return accPtr;
4841}
4842
4844 auto varPtrPtr{
4846 .Case<ACC_DATA_ENTRY_OPS>(
4847 [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
4848 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4849 return varPtrPtr;
4850}
4851
4856 accDataClauseOp)
4857 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4859 dataClause.getBounds().begin(), dataClause.getBounds().end());
4860 })
4861 .Default([&](mlir::Operation *) {
4863 })};
4864 return bounds;
4865}
4866
4870 accDataClauseOp)
4871 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4873 dataClause.getAsyncOperands().begin(),
4874 dataClause.getAsyncOperands().end());
4875 })
4876 .Default([&](mlir::Operation *) {
4878 });
4879}
4880
4881mlir::ArrayAttr
4884 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4885 return dataClause.getAsyncOperandsDeviceTypeAttr();
4886 })
4887 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4888}
4889
4890mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
4893 [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
4894 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4895}
4896
4897std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
4898 auto name{
4900 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
4901 .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
4902 return {};
4903 })};
4904 return name;
4905}
4906
4907std::optional<mlir::acc::DataClause>
4909 auto dataClause{
4911 accDataEntryOp)
4912 .Case<ACC_DATA_ENTRY_OPS>(
4913 [&](auto entry) { return entry.getDataClause(); })
4914 .Default([&](mlir::Operation *) { return std::nullopt; })};
4915 return dataClause;
4916}
4917
4919 auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
4920 .Case<ACC_DATA_ENTRY_OPS>(
4921 [&](auto entry) { return entry.getImplicit(); })
4922 .Default([&](mlir::Operation *) { return false; })};
4923 return implicit;
4924}
4925
4927 auto dataOperands{
4930 [&](auto entry) { return entry.getDataClauseOperands(); })
4931 .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
4932 return dataOperands;
4933}
4934
4937 auto dataOperands{
4940 [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
4941 .Default([&](mlir::Operation *) { return nullptr; })};
4942 return dataOperands;
4943}
4944
4945mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) {
4946 auto recipe{
4948 .Case<ACC_DATA_ENTRY_OPS>(
4949 [&](auto entry) { return entry.getRecipeAttr(); })
4950 .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })};
4951 return recipe;
4952}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition SCF.cpp:137
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
ArrayAttr()
if(!isCopyOut)
b getContext())
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition OpenACC.cpp:4321
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition OpenACC.cpp:1144
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition OpenACC.cpp:3072
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition OpenACC.cpp:1690
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
Definition OpenACC.cpp:4176
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
Definition OpenACC.cpp:745
static bool isComputeOperation(Operation *op)
Definition OpenACC.cpp:1158
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:525
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition OpenACC.cpp:2186
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
Definition OpenACC.cpp:738
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition OpenACC.cpp:678
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:509
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition OpenACC.cpp:647
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition OpenACC.cpp:2197
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition OpenACC.cpp:2102
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition OpenACC.cpp:451
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:4378
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition OpenACC.cpp:2881
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition OpenACC.cpp:2437
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition OpenACC.cpp:3088
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition OpenACC.cpp:4083
static LogicalResult checkVarAndAccVar(Op op)
Definition OpenACC.cpp:585
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2391
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:469
static LogicalResult checkVarAndVarType(Op op)
Definition OpenACC.cpp:567
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
Definition OpenACC.cpp:601
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition OpenACC.cpp:3456
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition OpenACC.cpp:1645
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:2233
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:1772
static LogicalResult checkNoModifier(Op op)
Definition OpenACC.cpp:593
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition OpenACC.cpp:656
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:480
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:493
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:1972
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
Definition OpenACC.cpp:400
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition OpenACC.cpp:632
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition OpenACC.cpp:3487
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:4351
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition OpenACC.cpp:4260
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2085
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:2260
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
Definition OpenACC.cpp:2376
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2039
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2352
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition OpenACC.cpp:719
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition OpenACC.cpp:2900
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition OpenACC.cpp:1425
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
Definition OpenACC.cpp:2421
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition OpenACC.cpp:2016
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
Definition OpenACC.cpp:611
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
Definition OpenACC.cpp:1659
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition OpenACC.cpp:2333
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:455
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition OpenACC.cpp:3027
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition OpenACC.cpp:2271
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition OpenACC.cpp:690
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition OpenACC.cpp:545
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition OpenACC.cpp:1700
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition OpenACC.cpp:4142
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2022
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition OpenACC.cpp:2457
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
Definition OpenACC.cpp:4230
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition OpenACC.h:68
#define ACC_DATA_ENTRY_OPS
Definition OpenACC.h:45
#define ACC_DATA_EXIT_OPS
Definition OpenACC.h:53
false
Parses a map_entries map type from a string format back into its numeric value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:148
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:160
Operation & front()
Definition Block.h:153
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgListType getArguments()
Definition Block.h:87
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
Definition Builders.h:56
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperandRange operand_range
Definition Operation.h:371
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:120
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:4835
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:4804
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:4823
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition OpenACC.cpp:4908
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition OpenACC.cpp:4936
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition OpenACC.cpp:4853
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition OpenACC.cpp:4926
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition OpenACC.cpp:4897
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether a data operation is implicit.
Definition OpenACC.cpp:4918
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
Definition OpenACC.cpp:4945
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition OpenACC.cpp:4868
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
Definition OpenACC.h:166
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition OpenACC.cpp:4843
static constexpr StringLiteral getVarNameAttrName()
Definition OpenACC.h:203
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:4890
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
Definition OpenACC.cpp:4812
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:4790
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:4882
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.