MLIR 23.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the OpenMP dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Attributes.h"
23#include "mlir/IR/SymbolTable.h"
25
26#include "llvm/ADT/ArrayRef.h"
27#include "llvm/ADT/PostOrderIterator.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/ADT/STLForwardCompat.h"
30#include "llvm/ADT/SmallString.h"
31#include "llvm/ADT/StringExtras.h"
32#include "llvm/ADT/StringRef.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/ADT/bit.h"
35#include "llvm/Support/InterleavedRange.h"
36#include <cstddef>
37#include <iterator>
38#include <optional>
39#include <variant>
40
41#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
42#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
43#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
45
46using namespace mlir;
47using namespace mlir::omp;
48
51 return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
52}
53
56 return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
57}
58
61 return intArray.empty() ? nullptr : DenseI64ArrayAttr::get(ctx, intArray);
62}
63
64namespace {
65struct MemRefPointerLikeModel
66 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
67 MemRefType> {
68 Type getElementType(Type pointer) const {
69 return llvm::cast<MemRefType>(pointer).getElementType();
70 }
71};
72
73struct LLVMPointerPointerLikeModel
74 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
75 LLVM::LLVMPointerType> {
76 Type getElementType(Type pointer) const { return Type(); }
77};
78} // namespace
79
80/// Generate a name of a canonical loop nest of the format
81/// `<prefix>(_r<idx>_s<idx>)*`. Hereby, `_r<idx>` identifies the region
82/// argument index of an operation that has multiple regions, if the operation
83/// has multiple regions.
84/// `_s<idx>` identifies the position of an operation within a region, where
85/// only operations that may potentially contain loops ("container operations"
86/// i.e. have region arguments) are counted. Again, it is omitted if there is
87/// only one such operation in a region. If there are canonical loops nested
88/// inside each other, the name may also use the format `_d<num>` where <num> is
89/// the nesting depth of the loop.
90///
91/// The generated name is a best-effort to make canonical loop unique within an
92/// SSA namespace. This also means that regions with IsolatedFromAbove property
93/// do not consider any parents or siblings.
///
/// \param prefix Text the returned name starts with.
/// \param op The canonical loop whose enclosing regions and operations are
///           walked to build the suffixes.
/// \returns `prefix` followed by the `_r`/`_d`/`_s` suffixes of every component
///          that was not marked as skipped, emitted from outermost to
///          innermost.
94static std::string generateLoopNestingName(StringRef prefix,
95 CanonicalLoopOp op) {
96 struct Component {
97 /// If true, this component describes a region operand of an operation (the
98 /// operand's owner). If false, this component describes an operation located
99 /// in a parent region
100 bool isRegionArgOfOp;
101 bool skip = false;
102 bool isUnique = false;
103
// `idx` and `op` are shared storage: the accessors below expose them as
// (argument index, owner op) or (op position, container op) depending on
// `isRegionArgOfOp`.
104 size_t idx;
105 Operation *op;
106 Region *parentRegion;
107 size_t loopDepth;
108
109 Operation *&getOwnerOp() {
110 assert(isRegionArgOfOp && "Must describe a region operand");
111 return op;
112 }
113 size_t &getArgIdx() {
114 assert(isRegionArgOfOp && "Must describe a region operand");
115 return idx;
116 }
117
118 Operation *&getContainerOp() {
119 assert(!isRegionArgOfOp && "Must describe a operation of a region");
120 return op;
121 }
122 size_t &getOpPos() {
123 assert(!isRegionArgOfOp && "Must describe a operation of a region");
124 return idx;
125 }
126 bool isLoopOp() const {
127 assert(!isRegionArgOfOp && "Must describe a operation of a region");
128 return isa<CanonicalLoopOp>(op);
129 }
130 Region *&getParentRegion() {
131 assert(!isRegionArgOfOp && "Must describe a operation of a region");
132 return parentRegion;
133 }
134 size_t &getLoopDepth() {
135 assert(!isRegionArgOfOp && "Must describe a operation of a region");
136 return loopDepth;
137 }
138
// Marks the component as skipped when `v` is true; once set, `skip` is never
// cleared again.
139 void skipIf(bool v = true) { skip = skip || v; }
140 };
141
142 // List of ancestors, from inner to outer.
143 // Alternates between
144 // * region argument of an operation
145 // * operation within a region
146 SmallVector<Component> components;
147
148 // Gather a list of parent regions and operations, and the position within
149 // their parent
150 Operation *o = op.getOperation();
151 while (o) {
152 // Operation within a region
153 Region *r = o->getParentRegion();
154 if (!r)
155 break;
156
157 llvm::ReversePostOrderTraversal<Block *> traversal(&r->getBlocks().front());
158 size_t idx = 0;
159 bool found = false;
// Starts out as size_t(-1) and is overwritten once `o` is encountered in the
// traversal below.
160 size_t sequentialIdx = -1;
161 bool isOnlyContainerOp = true;
// Count the operations that carry regions, in traversal order. When `o` is
// reached, `sequentialIdx` records how many such operations preceded it.
162 for (Block *b : traversal) {
163 for (Operation &op : *b) {
164 if (&op == o && !found) {
165 sequentialIdx = idx;
166 found = true;
167 }
168 if (op.getNumRegions()) {
169 idx += 1;
170 if (idx > 1)
171 isOnlyContainerOp = false;
172 }
// Note: this `break` only leaves the per-block loop; the traversal over the
// remaining blocks continues.
173 if (found && !isOnlyContainerOp)
174 break;
175 }
176 }
177
178 Component &containerOpInRegion = components.emplace_back();
179 containerOpInRegion.isRegionArgOfOp = false;
180 containerOpInRegion.isUnique = isOnlyContainerOp;
181 containerOpInRegion.getContainerOp() = o;
182 containerOpInRegion.getOpPos() = sequentialIdx;
183 containerOpInRegion.getParentRegion() = r;
184
185 Operation *parent = r->getParentOp();
186
187 // Region argument of an operation
188 Component &regionArgOfOperation = components.emplace_back();
189 regionArgOfOperation.isRegionArgOfOp = true;
190 regionArgOfOperation.isUnique = true;
191 regionArgOfOperation.getArgIdx() = 0;
192 regionArgOfOperation.getOwnerOp() = parent;
193
194 // The IsolatedFromAbove trait of the parent operation implies that each
195 // individual region argument has its own separate namespace, so no
196 // ambiguity.
197 if (!parent || parent->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
198 break;
199
200 // Component only needed if operation has multiple region operands. Region
201 // arguments may be optional, but we currently do not consider this.
202 if (parent->getRegions().size() > 1) {
203 auto getRegionIndex = [](Operation *o, Region *r) {
204 for (auto [idx, region] : llvm::enumerate(o->getRegions())) {
205 if (&region == r)
206 return idx;
207 }
208 llvm_unreachable("Region not child of its parent operation");
209 };
210 regionArgOfOperation.isUnique = false;
211 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
212 }
213
214 // next parent
215 o = parent;
216 }
217
218 // Determine whether a region-argument component is not needed
219 for (Component &c : components)
220 c.skipIf(c.isRegionArgOfOp && c.isUnique);
221
222 // Find runs of nested loops and determine each loop's depth in the loop nest
223 size_t numSurroundingLoops = 0;
224 for (Component &c : llvm::reverse(components)) {
225 if (c.skip)
226 continue;
227
228 // non-skipped multi-argument operands interrupt the loop nest
229 if (c.isRegionArgOfOp) {
230 numSurroundingLoops = 0;
231 continue;
232 }
233
234 // Multiple loops in a region means each of them is the outermost loop of a
235 // new loop nest
236 if (!c.isUnique)
237 numSurroundingLoops = 0;
238
239 c.getLoopDepth() = numSurroundingLoops;
240
241 // Next loop is surrounded by one more loop
242 if (isa<CanonicalLoopOp>(c.getContainerOp()))
243 numSurroundingLoops += 1;
244 }
245
246 // In loop nests, skip all but the innermost loop that contains the depth
247 // number
248 bool isLoopNest = false;
249 for (Component &c : components) {
250 if (c.skip || c.isRegionArgOfOp)
251 continue;
252
253 if (!isLoopNest && c.getLoopDepth() >= 1) {
254 // Innermost loop of a loop nest of at least two loops
255 isLoopNest = true;
256 } else if (isLoopNest) {
257 // Non-innermost loop of a loop nest
258 c.skipIf(c.isUnique);
259
260 // If there is no surrounding loop left, this must have been the outermost
261 // loop; leave loop-nest mode for the next iteration
262 if (c.getLoopDepth() == 0)
263 isLoopNest = false;
264 }
265 }
266
267 // Skip non-loop unambiguous regions (but they should interrupt loop nests, so
268 // we mark them as skipped only after computing loop nests)
269 for (Component &c : components)
270 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
271 !isa<CanonicalLoopOp>(c.getContainerOp()));
272
273 // Components can be skipped if they are already disambiguated by their parent
274 // (or do not have a parent)
275 bool newRegion = true;
276 for (Component &c : llvm::reverse(components)) {
277 c.skipIf(newRegion && c.isUnique);
278
279 // non-skipped components disambiguate unique children
280 if (!c.skip)
281 newRegion = true;
282
283 // ...except canonical loops that need a suffix for each nest
// NOTE(review): the condition below only tests that the container op pointer
// is non-null; it does not check isa<CanonicalLoopOp>. Confirm this matches
// the intent stated in the comment above.
284 if (!c.isRegionArgOfOp && c.getContainerOp())
285 newRegion = false;
286 }
287
288 // Compile the nesting name string
289 SmallString<64> Name{prefix};
290 llvm::raw_svector_ostream NameOS(Name);
// Emit suffixes from the outermost component to the innermost one.
291 for (auto &c : llvm::reverse(components)) {
292 if (c.skip)
293 continue;
294
295 if (c.isRegionArgOfOp)
296 NameOS << "_r" << c.getArgIdx();
297 else if (c.getLoopDepth() >= 1)
298 NameOS << "_d" << c.getLoopDepth();
299 else
300 NameOS << "_s" << c.getOpPos();
301 }
302
303 return NameOS.str().str();
304}
305
306void OpenMPDialect::initialize() {
307 addOperations<
308#define GET_OP_LIST
309#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
310 >();
311 addAttributes<
312#define GET_ATTRDEF_LIST
313#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
314 >();
315 addTypes<
316#define GET_TYPEDEF_LIST
317#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
318 >();
319
320 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
321
322 MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
323 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
324 *getContext());
325
326 // Attach default offload module interface to module op to access
327 // offload functionality through
328 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
329 *getContext());
330
331 // Attach default declare target interfaces to operations which can be marked
332 // as declare target (global operations and functions/subroutines in dialects
333 // that Fortran, or other languages that lower to MLIR, translates to).
334 mlir::LLVM::GlobalOp::attachInterface<
336 *getContext());
337 mlir::LLVM::LLVMFuncOp::attachInterface<
339 *getContext());
340 mlir::func::FuncOp::attachInterface<
342}
343
344//===----------------------------------------------------------------------===//
345// Parser and printer for Allocate Clause
346//===----------------------------------------------------------------------===//
347
348/// Parse an allocate clause with allocators and a list of operands with types.
349///
350/// allocate-operand-list ::= allocate-operand |
351/// allocate-operand `,` allocate-operand-list
352/// allocate-operand ::= ssa-id-and-type `->` ssa-id-and-type
353/// ssa-id-and-type ::= ssa-id `:` type
354static ParseResult parseAllocateAndAllocator(
355 OpAsmParser &parser,
357 SmallVectorImpl<Type> &allocateTypes,
359 SmallVectorImpl<Type> &allocatorTypes) {
360
361 return parser.parseCommaSeparatedList([&]() {
363 Type type;
364 if (parser.parseOperand(operand) || parser.parseColonType(type))
365 return failure();
366 allocatorVars.push_back(operand);
367 allocatorTypes.push_back(type);
368 if (parser.parseArrow())
369 return failure();
370 if (parser.parseOperand(operand) || parser.parseColonType(type))
371 return failure();
372
373 allocateVars.push_back(operand);
374 allocateTypes.push_back(type);
375 return success();
376 });
377}
378
379/// Print allocate clause
381 OperandRange allocateVars,
382 TypeRange allocateTypes,
383 OperandRange allocatorVars,
384 TypeRange allocatorTypes) {
385 for (unsigned i = 0; i < allocateVars.size(); ++i) {
386 std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
387 p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
388 p << allocateVars[i] << " : " << allocateTypes[i] << separator;
389 }
390}
391
392//===----------------------------------------------------------------------===//
393// Parser and printer for a clause attribute (StringEnumAttr)
394//===----------------------------------------------------------------------===//
395
396template <typename ClauseAttr>
397static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
398 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
399 StringRef enumStr;
400 SMLoc loc = parser.getCurrentLocation();
401 if (parser.parseKeyword(&enumStr))
402 return failure();
403 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
404 attr = ClauseAttr::get(parser.getContext(), *enumValue);
405 return success();
406 }
407 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
408}
409
410template <typename ClauseAttr>
411static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
412 p << stringifyEnum(attr.getValue());
413}
414
415//===----------------------------------------------------------------------===//
416// Parser and printer for Linear Clause
417//===----------------------------------------------------------------------===//
418
419/// linear ::= `linear` `(` linear-list `)`
420/// linear-list := linear-val | linear-val linear-list
421/// linear-val := ssa-id-and-type `=` ssa-id-and-type
422static ParseResult parseLinearClause(
423 OpAsmParser &parser,
425 SmallVectorImpl<Type> &linearTypes,
427 return parser.parseCommaSeparatedList([&]() {
429 Type type;
431 if (parser.parseOperand(var) || parser.parseEqual() ||
432 parser.parseOperand(stepVar) || parser.parseColonType(type))
433 return failure();
434
435 linearVars.push_back(var);
436 linearTypes.push_back(type);
437 linearStepVars.push_back(stepVar);
438 return success();
439 });
440}
441
442/// Print Linear Clause
444 ValueRange linearVars, TypeRange linearTypes,
445 ValueRange linearStepVars) {
446 size_t linearVarsSize = linearVars.size();
447 for (unsigned i = 0; i < linearVarsSize; ++i) {
448 std::string separator = i == linearVarsSize - 1 ? "" : ", ";
449 p << linearVars[i];
450 if (linearStepVars.size() > i)
451 p << " = " << linearStepVars[i];
452 p << " : " << linearVars[i].getType() << separator;
453 }
454}
455
456//===----------------------------------------------------------------------===//
457// Verifier for Nontemporal Clause
458//===----------------------------------------------------------------------===//
459
460static LogicalResult verifyNontemporalClause(Operation *op,
461 OperandRange nontemporalVars) {
462
463 // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
464 DenseSet<Value> nontemporalItems;
465 for (const auto &it : nontemporalVars)
466 if (!nontemporalItems.insert(it).second)
467 return op->emitOpError() << "nontemporal variable used more than once";
468
469 return success();
470}
471
472//===----------------------------------------------------------------------===//
473// Parser, verifier and printer for Aligned Clause
474//===----------------------------------------------------------------------===//
475static LogicalResult verifyAlignedClause(Operation *op,
476 std::optional<ArrayAttr> alignments,
477 OperandRange alignedVars) {
478 // Check if number of alignment values equals to number of aligned variables
479 if (!alignedVars.empty()) {
480 if (!alignments || alignments->size() != alignedVars.size())
481 return op->emitOpError()
482 << "expected as many alignment values as aligned variables";
483 } else {
484 if (alignments)
485 return op->emitOpError() << "unexpected alignment values attribute";
486 return success();
487 }
488
489 // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
490 DenseSet<Value> alignedItems;
491 for (auto it : alignedVars)
492 if (!alignedItems.insert(it).second)
493 return op->emitOpError() << "aligned variable used more than once";
494
495 if (!alignments)
496 return success();
497
498 // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
499 for (unsigned i = 0; i < (*alignments).size(); ++i) {
500 if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
501 if (intAttr.getValue().sle(0))
502 return op->emitOpError() << "alignment should be greater than 0";
503 } else {
504 return op->emitOpError() << "expected integer alignment";
505 }
506 }
507
508 return success();
509}
510
511/// aligned ::= `aligned` `(` aligned-list `)`
512/// aligned-list := aligned-val | aligned-val aligned-list
513/// aligned-val := ssa-id-and-type `->` alignment
514static ParseResult
517 SmallVectorImpl<Type> &alignedTypes,
518 ArrayAttr &alignmentsAttr) {
519 SmallVector<Attribute> alignmentVec;
520 if (failed(parser.parseCommaSeparatedList([&]() {
521 if (parser.parseOperand(alignedVars.emplace_back()) ||
522 parser.parseColonType(alignedTypes.emplace_back()) ||
523 parser.parseArrow() ||
524 parser.parseAttribute(alignmentVec.emplace_back())) {
525 return failure();
526 }
527 return success();
528 })))
529 return failure();
530 SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
531 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
532 return success();
533}
534
535/// Print Aligned Clause
537 ValueRange alignedVars, TypeRange alignedTypes,
538 std::optional<ArrayAttr> alignments) {
539 for (unsigned i = 0; i < alignedVars.size(); ++i) {
540 if (i != 0)
541 p << ", ";
542 p << alignedVars[i] << " : " << alignedVars[i].getType();
543 p << " -> " << (*alignments)[i];
544 }
545}
546
547//===----------------------------------------------------------------------===//
548// Parser, printer and verifier for Schedule Clause
549//===----------------------------------------------------------------------===//
550
551static ParseResult
553 SmallVectorImpl<SmallString<12>> &modifiers) {
554 if (modifiers.size() > 2)
555 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
556 for (const auto &mod : modifiers) {
557 // Translate the string. If it has no value, then it was not a valid
558 // modifier!
559 auto symbol = symbolizeScheduleModifier(mod);
560 if (!symbol)
561 return parser.emitError(parser.getNameLoc())
562 << " unknown modifier type: " << mod;
563 }
564
565 // If we have one modifier that is "simd", then stick a "none" modifier in
566 // index 0.
567 if (modifiers.size() == 1) {
568 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
569 modifiers.push_back(modifiers[0]);
570 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
571 }
572 } else if (modifiers.size() == 2) {
573 // If there are two modifier:
574 // First modifier should not be simd, second one should be simd
575 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
576 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
577 return parser.emitError(parser.getNameLoc())
578 << " incorrect modifier order";
579 }
580 return success();
581}
582
583/// schedule ::= `schedule` `(` sched-list `)`
584/// sched-list ::= sched-val | sched-val sched-list |
585/// sched-val `,` sched-modifier
586/// sched-val ::= sched-with-chunk | sched-wo-chunk
587/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
588/// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
589/// sched-wo-chunk ::= `auto` | `runtime`
590/// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
591/// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
592static ParseResult
593parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
594 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
595 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
596 Type &chunkType) {
597 StringRef keyword;
598 if (parser.parseKeyword(&keyword))
599 return failure();
600 std::optional<mlir::omp::ClauseScheduleKind> schedule =
601 symbolizeClauseScheduleKind(keyword);
602 if (!schedule)
603 return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
604
605 scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
606 switch (*schedule) {
607 case ClauseScheduleKind::Static:
608 case ClauseScheduleKind::Dynamic:
609 case ClauseScheduleKind::Guided:
610 if (succeeded(parser.parseOptionalEqual())) {
611 chunkSize = OpAsmParser::UnresolvedOperand{};
612 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
613 return failure();
614 } else {
615 chunkSize = std::nullopt;
616 }
617 break;
618 case ClauseScheduleKind::Auto:
619 case ClauseScheduleKind::Runtime:
620 case ClauseScheduleKind::Distribute:
621 chunkSize = std::nullopt;
622 }
623
624 // If there is a comma, we have one or more modifiers..
626 while (succeeded(parser.parseOptionalComma())) {
627 StringRef mod;
628 if (parser.parseKeyword(&mod))
629 return failure();
630 modifiers.push_back(mod);
631 }
632
633 if (verifyScheduleModifiers(parser, modifiers))
634 return failure();
635
636 if (!modifiers.empty()) {
637 SMLoc loc = parser.getCurrentLocation();
638 if (std::optional<ScheduleModifier> mod =
639 symbolizeScheduleModifier(modifiers[0])) {
640 scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
641 } else {
642 return parser.emitError(loc, "invalid schedule modifier");
643 }
644 // Only SIMD attribute is allowed here!
645 if (modifiers.size() > 1) {
646 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
647 scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
648 }
649 }
650
651 return success();
652}
653
654/// Print schedule clause
656 ClauseScheduleKindAttr scheduleKind,
657 ScheduleModifierAttr scheduleMod,
658 UnitAttr scheduleSimd, Value scheduleChunk,
659 Type scheduleChunkType) {
660 p << stringifyClauseScheduleKind(scheduleKind.getValue());
661 if (scheduleChunk)
662 p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
663 if (scheduleMod)
664 p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
665 if (scheduleSimd)
666 p << ", simd";
667}
668
669//===----------------------------------------------------------------------===//
670// Parser and printer for Order Clause
671//===----------------------------------------------------------------------===//
672
673// order ::= `order` `(` [order-modifier ':'] concurrent `)`
674// order-modifier ::= reproducible | unconstrained
675static ParseResult parseOrderClause(OpAsmParser &parser,
676 ClauseOrderKindAttr &order,
677 OrderModifierAttr &orderMod) {
678 StringRef enumStr;
679 SMLoc loc = parser.getCurrentLocation();
680 if (parser.parseKeyword(&enumStr))
681 return failure();
682 if (std::optional<OrderModifier> enumValue =
683 symbolizeOrderModifier(enumStr)) {
684 orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
685 if (parser.parseOptionalColon())
686 return failure();
687 loc = parser.getCurrentLocation();
688 if (parser.parseKeyword(&enumStr))
689 return failure();
690 }
691 if (std::optional<ClauseOrderKind> enumValue =
692 symbolizeClauseOrderKind(enumStr)) {
693 order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
694 return success();
695 }
696 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
697}
698
700 ClauseOrderKindAttr order,
701 OrderModifierAttr orderMod) {
702 if (orderMod)
703 p << stringifyOrderModifier(orderMod.getValue()) << ":";
704 if (order)
705 p << stringifyClauseOrderKind(order.getValue());
706}
707
708template <typename ClauseTypeAttr, typename ClauseType>
709static ParseResult
710parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
711 std::optional<OpAsmParser::UnresolvedOperand> &operand,
712 Type &operandType,
713 std::optional<ClauseType> (*symbolizeClause)(StringRef),
714 StringRef clauseName) {
715 StringRef enumStr;
716 if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
717 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
718 prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
719 if (parser.parseComma())
720 return failure();
721 } else {
722 return parser.emitError(parser.getCurrentLocation())
723 << "invalid " << clauseName << " modifier : '" << enumStr << "'";
724 ;
725 }
726 }
727
729 if (succeeded(parser.parseOperand(var))) {
730 operand = var;
731 } else {
732 return parser.emitError(parser.getCurrentLocation())
733 << "expected " << clauseName << " operand";
734 }
735
736 if (operand.has_value()) {
737 if (parser.parseColonType(operandType))
738 return failure();
739 }
740
741 return success();
742}
743
744template <typename ClauseTypeAttr, typename ClauseType>
745static void
747 ClauseTypeAttr prescriptiveness, Value operand,
748 mlir::Type operandType,
749 StringRef (*stringifyClauseType)(ClauseType)) {
750
751 if (prescriptiveness)
752 p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
753
754 if (operand)
755 p << operand << ": " << operandType;
756}
757
758//===----------------------------------------------------------------------===//
759// Parser and printer for grainsize Clause
760//===----------------------------------------------------------------------===//
761
762// grainsize ::= `grainsize` `(` [strict ','] grain-size `)`
763static ParseResult
764parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
765 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
766 Type &grainsizeType) {
768 parser, grainsizeMod, grainsize, grainsizeType,
769 &symbolizeClauseGrainsizeType, "grainsize");
770}
771
773 ClauseGrainsizeTypeAttr grainsizeMod,
774 Value grainsize, mlir::Type grainsizeType) {
776 p, op, grainsizeMod, grainsize, grainsizeType,
777 &stringifyClauseGrainsizeType);
778}
779
780//===----------------------------------------------------------------------===//
781// Parser and printer for num_tasks Clause
782//===----------------------------------------------------------------------===//
783
784// numtask ::= `num_tasks` `(` [strict ','] num-tasks `)`
785static ParseResult
786parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
787 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
788 Type &numTasksType) {
790 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
791 "num_tasks");
792}
793
795 ClauseNumTasksTypeAttr numTasksMod,
796 Value numTasks, mlir::Type numTasksType) {
798 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
799}
800
801//===----------------------------------------------------------------------===//
802// Parsers for operations including clauses that define entry block arguments.
803//===----------------------------------------------------------------------===//
804
805namespace {
806struct MapParseArgs {
807 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
808 SmallVectorImpl<Type> &types;
809 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
810 SmallVectorImpl<Type> &types)
811 : vars(vars), types(types) {}
812};
813struct PrivateParseArgs {
814 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
815 llvm::SmallVectorImpl<Type> &types;
816 ArrayAttr &syms;
817 UnitAttr &needsBarrier;
818 DenseI64ArrayAttr *mapIndices;
819 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
820 SmallVectorImpl<Type> &types, ArrayAttr &syms,
821 UnitAttr &needsBarrier,
822 DenseI64ArrayAttr *mapIndices = nullptr)
823 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
824 mapIndices(mapIndices) {}
825};
826
827struct ReductionParseArgs {
828 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
829 SmallVectorImpl<Type> &types;
830 DenseBoolArrayAttr &byref;
831 ArrayAttr &syms;
832 ReductionModifierAttr *modifier;
833 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
834 SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
835 ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
836 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
837};
838
839struct AllRegionParseArgs {
840 std::optional<MapParseArgs> hasDeviceAddrArgs;
841 std::optional<MapParseArgs> hostEvalArgs;
842 std::optional<ReductionParseArgs> inReductionArgs;
843 std::optional<MapParseArgs> mapArgs;
844 std::optional<PrivateParseArgs> privateArgs;
845 std::optional<ReductionParseArgs> reductionArgs;
846 std::optional<ReductionParseArgs> taskReductionArgs;
847 std::optional<MapParseArgs> useDeviceAddrArgs;
848 std::optional<MapParseArgs> useDevicePtrArgs;
849};
850} // namespace
851
852static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
853 return "private_barrier";
854}
855
856static ParseResult parseClauseWithRegionArgs(
857 OpAsmParser &parser,
861 ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
862 DenseBoolArrayAttr *byref = nullptr,
863 ReductionModifierAttr *modifier = nullptr,
864 UnitAttr *needsBarrier = nullptr) {
866 SmallVector<int64_t> mapIndicesVec;
867 SmallVector<bool> isByRefVec;
868 unsigned regionArgOffset = regionPrivateArgs.size();
869
870 if (parser.parseLParen())
871 return failure();
872
873 if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
874 StringRef enumStr;
875 if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
876 parser.parseComma())
877 return failure();
878 std::optional<ReductionModifier> enumValue =
879 symbolizeReductionModifier(enumStr);
880 if (!enumValue.has_value())
881 return failure();
882 *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
883 if (!*modifier)
884 return failure();
885 }
886
887 if (parser.parseCommaSeparatedList([&]() {
888 if (byref)
889 isByRefVec.push_back(
890 parser.parseOptionalKeyword("byref").succeeded());
891
892 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
893 return failure();
894
895 if (parser.parseOperand(operands.emplace_back()) ||
896 parser.parseArrow() ||
897 parser.parseArgument(regionPrivateArgs.emplace_back()))
898 return failure();
899
900 if (mapIndices) {
901 if (parser.parseOptionalLSquare().succeeded()) {
902 if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
903 parser.parseInteger(mapIndicesVec.emplace_back()) ||
904 parser.parseRSquare())
905 return failure();
906 } else {
907 mapIndicesVec.push_back(-1);
908 }
909 }
910
911 return success();
912 }))
913 return failure();
914
915 if (parser.parseColon())
916 return failure();
917
918 if (parser.parseCommaSeparatedList([&]() {
919 if (parser.parseType(types.emplace_back()))
920 return failure();
921
922 return success();
923 }))
924 return failure();
925
926 if (operands.size() != types.size())
927 return failure();
928
929 if (parser.parseRParen())
930 return failure();
931
932 if (needsBarrier) {
934 .succeeded())
935 *needsBarrier = mlir::UnitAttr::get(parser.getContext());
936 }
937
938 auto *argsBegin = regionPrivateArgs.begin();
939 MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
940 argsBegin + regionArgOffset + types.size());
941 for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
942 prv.type = type;
943 }
944
945 if (symbols) {
946 SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
947 *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
948 }
949
950 if (!mapIndicesVec.empty())
951 *mapIndices =
952 mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
953
954 if (byref)
955 *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
956
957 return success();
958}
959
960static ParseResult parseBlockArgClause(
961 OpAsmParser &parser,
963 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
964 if (succeeded(parser.parseOptionalKeyword(keyword))) {
965 if (!mapArgs)
966 return failure();
967
968 if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
969 entryBlockArgs)))
970 return failure();
971 }
972 return success();
973}
974
975static ParseResult parseBlockArgClause(
976 OpAsmParser &parser,
978 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
979 if (succeeded(parser.parseOptionalKeyword(keyword))) {
980 if (!privateArgs)
981 return failure();
982
983 if (failed(parseClauseWithRegionArgs(
984 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
985 &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
986 /*modifier=*/nullptr, &privateArgs->needsBarrier)))
987 return failure();
988 }
989 return success();
990}
991
992static ParseResult parseBlockArgClause(
993 OpAsmParser &parser,
995 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
996 if (succeeded(parser.parseOptionalKeyword(keyword))) {
997 if (!reductionArgs)
998 return failure();
999 if (failed(parseClauseWithRegionArgs(
1000 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1001 &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
1002 reductionArgs->modifier)))
1003 return failure();
1004 }
1005 return success();
1006}
1007
1008static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
1009 AllRegionParseArgs args) {
1011
1012 if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",
1013 args.hasDeviceAddrArgs)))
1014 return parser.emitError(parser.getCurrentLocation())
1015 << "invalid `has_device_addr` format";
1016
1017 if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
1018 args.hostEvalArgs)))
1019 return parser.emitError(parser.getCurrentLocation())
1020 << "invalid `host_eval` format";
1021
1022 if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
1023 args.inReductionArgs)))
1024 return parser.emitError(parser.getCurrentLocation())
1025 << "invalid `in_reduction` format";
1026
1027 if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
1028 args.mapArgs)))
1029 return parser.emitError(parser.getCurrentLocation())
1030 << "invalid `map_entries` format";
1031
1032 if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
1033 args.privateArgs)))
1034 return parser.emitError(parser.getCurrentLocation())
1035 << "invalid `private` format";
1036
1037 if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
1038 args.reductionArgs)))
1039 return parser.emitError(parser.getCurrentLocation())
1040 << "invalid `reduction` format";
1041
1042 if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
1043 args.taskReductionArgs)))
1044 return parser.emitError(parser.getCurrentLocation())
1045 << "invalid `task_reduction` format";
1046
1047 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
1048 args.useDeviceAddrArgs)))
1049 return parser.emitError(parser.getCurrentLocation())
1050 << "invalid `use_device_addr` format";
1051
1052 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
1053 args.useDevicePtrArgs)))
1054 return parser.emitError(parser.getCurrentLocation())
1055 << "invalid `use_device_addr` format";
1056
1057 return parser.parseRegion(region, entryBlockArgs);
1058}
1059
1060// These parseXyz functions correspond to the custom<Xyz> definitions
1061// in the .td file(s).
1062static ParseResult parseTargetOpRegion(
1063 OpAsmParser &parser, Region &region,
1065 SmallVectorImpl<Type> &hasDeviceAddrTypes,
1067 SmallVectorImpl<Type> &hostEvalTypes,
1069 SmallVectorImpl<Type> &inReductionTypes,
1070 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1072 SmallVectorImpl<Type> &mapTypes,
1074 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1075 UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
1076 AllRegionParseArgs args;
1077 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1078 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1079 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1080 inReductionByref, inReductionSyms);
1081 args.mapArgs.emplace(mapVars, mapTypes);
1082 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1083 privateNeedsBarrier, &privateMaps);
1084 return parseBlockArgRegion(parser, region, args);
1085}
1086
1088 OpAsmParser &parser, Region &region,
1090 SmallVectorImpl<Type> &inReductionTypes,
1091 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1093 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1094 UnitAttr &privateNeedsBarrier) {
1095 AllRegionParseArgs args;
1096 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1097 inReductionByref, inReductionSyms);
1098 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1099 privateNeedsBarrier);
1100 return parseBlockArgRegion(parser, region, args);
1101}
1102
1104 OpAsmParser &parser, Region &region,
1106 SmallVectorImpl<Type> &inReductionTypes,
1107 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1109 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1110 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1112 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1113 ArrayAttr &reductionSyms) {
1114 AllRegionParseArgs args;
1115 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1116 inReductionByref, inReductionSyms);
1117 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1118 privateNeedsBarrier);
1119 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1120 reductionSyms, &reductionMod);
1121 return parseBlockArgRegion(parser, region, args);
1122}
1123
1124static ParseResult parsePrivateRegion(
1125 OpAsmParser &parser, Region &region,
1127 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1128 UnitAttr &privateNeedsBarrier) {
1129 AllRegionParseArgs args;
1130 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1131 privateNeedsBarrier);
1132 return parseBlockArgRegion(parser, region, args);
1133}
1134
1136 OpAsmParser &parser, Region &region,
1138 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1139 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1141 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1142 ArrayAttr &reductionSyms) {
1143 AllRegionParseArgs args;
1144 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1145 privateNeedsBarrier);
1146 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1147 reductionSyms, &reductionMod);
1148 return parseBlockArgRegion(parser, region, args);
1149}
1150
1151static ParseResult parseTaskReductionRegion(
1152 OpAsmParser &parser, Region &region,
1154 SmallVectorImpl<Type> &taskReductionTypes,
1155 DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
1156 AllRegionParseArgs args;
1157 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1158 taskReductionByref, taskReductionSyms);
1159 return parseBlockArgRegion(parser, region, args);
1160}
1161
1163 OpAsmParser &parser, Region &region,
1165 SmallVectorImpl<Type> &useDeviceAddrTypes,
1167 SmallVectorImpl<Type> &useDevicePtrTypes) {
1168 AllRegionParseArgs args;
1169 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1170 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1171 return parseBlockArgRegion(parser, region, args);
1172}
1173
1174//===----------------------------------------------------------------------===//
1175// Printers for operations including clauses that define entry block arguments.
1176//===----------------------------------------------------------------------===//
1177
1178namespace {
1179struct MapPrintArgs {
1180 ValueRange vars;
1181 TypeRange types;
1182 MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
1183};
1184struct PrivatePrintArgs {
1185 ValueRange vars;
1186 TypeRange types;
1187 ArrayAttr syms;
1188 UnitAttr needsBarrier;
1189 DenseI64ArrayAttr mapIndices;
1190 PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
1191 UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
1192 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1193 mapIndices(mapIndices) {}
1194};
1195struct ReductionPrintArgs {
1196 ValueRange vars;
1197 TypeRange types;
1198 DenseBoolArrayAttr byref;
1199 ArrayAttr syms;
1200 ReductionModifierAttr modifier;
1201 ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
1202 ArrayAttr syms, ReductionModifierAttr mod = nullptr)
1203 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1204};
1205struct AllRegionPrintArgs {
1206 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1207 std::optional<MapPrintArgs> hostEvalArgs;
1208 std::optional<ReductionPrintArgs> inReductionArgs;
1209 std::optional<MapPrintArgs> mapArgs;
1210 std::optional<PrivatePrintArgs> privateArgs;
1211 std::optional<ReductionPrintArgs> reductionArgs;
1212 std::optional<ReductionPrintArgs> taskReductionArgs;
1213 std::optional<MapPrintArgs> useDeviceAddrArgs;
1214 std::optional<MapPrintArgs> useDevicePtrArgs;
1215};
1216} // namespace
1217
1219 OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1220 ValueRange argsSubrange, ValueRange operands, TypeRange types,
1221 ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
1222 DenseBoolArrayAttr byref = nullptr,
1223 ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
1224 if (argsSubrange.empty())
1225 return;
1226
1227 p << clauseName << "(";
1228
1229 if (modifier)
1230 p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
1231
1232 if (!symbols) {
1233 llvm::SmallVector<Attribute> values(operands.size(), nullptr);
1234 symbols = ArrayAttr::get(ctx, values);
1235 }
1236
1237 if (!mapIndices) {
1238 llvm::SmallVector<int64_t> values(operands.size(), -1);
1239 mapIndices = DenseI64ArrayAttr::get(ctx, values);
1240 }
1241
1242 if (!byref) {
1243 mlir::SmallVector<bool> values(operands.size(), false);
1244 byref = DenseBoolArrayAttr::get(ctx, values);
1245 }
1246
1247 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1248 mapIndices.asArrayRef(),
1249 byref.asArrayRef()),
1250 p, [&p](auto t) {
1251 auto [op, arg, sym, map, isByRef] = t;
1252 if (isByRef)
1253 p << "byref ";
1254 if (sym)
1255 p << sym << " ";
1256
1257 p << op << " -> " << arg;
1258
1259 if (map != -1)
1260 p << " [map_idx=" << map << "]";
1261 });
1262 p << " : ";
1263 llvm::interleaveComma(types, p);
1264 p << ") ";
1265
1266 if (needsBarrier)
1267 p << getPrivateNeedsBarrierSpelling() << " ";
1268}
1269
1271 StringRef clauseName, ValueRange argsSubrange,
1272 std::optional<MapPrintArgs> mapArgs) {
1273 if (mapArgs)
1274 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
1275 mapArgs->types);
1276}
1277
1279 StringRef clauseName, ValueRange argsSubrange,
1280 std::optional<PrivatePrintArgs> privateArgs) {
1281 if (privateArgs)
1283 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1284 privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1285 /*modifier=*/nullptr, privateArgs->needsBarrier);
1286}
1287
1288static void
1289printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1290 ValueRange argsSubrange,
1291 std::optional<ReductionPrintArgs> reductionArgs) {
1292 if (reductionArgs)
1293 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1294 reductionArgs->vars, reductionArgs->types,
1295 reductionArgs->syms, /*mapIndices=*/nullptr,
1296 reductionArgs->byref, reductionArgs->modifier);
1297}
1298
1300 const AllRegionPrintArgs &args) {
1301 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1302 MLIRContext *ctx = op->getContext();
1303
1304 printBlockArgClause(p, ctx, "has_device_addr",
1305 iface.getHasDeviceAddrBlockArgs(),
1306 args.hasDeviceAddrArgs);
1307 printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
1308 args.hostEvalArgs);
1309 printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
1310 args.inReductionArgs);
1311 printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
1312 args.mapArgs);
1313 printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
1314 args.privateArgs);
1315 printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
1316 args.reductionArgs);
1317 printBlockArgClause(p, ctx, "task_reduction",
1318 iface.getTaskReductionBlockArgs(),
1319 args.taskReductionArgs);
1320 printBlockArgClause(p, ctx, "use_device_addr",
1321 iface.getUseDeviceAddrBlockArgs(),
1322 args.useDeviceAddrArgs);
1323 printBlockArgClause(p, ctx, "use_device_ptr",
1324 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1325
1326 p.printRegion(region, /*printEntryBlockArgs=*/false);
1327}
1328
// These printXyz functions correspond to the custom<Xyz> definitions
// in the .td file(s).
1332 OpAsmPrinter &p, Operation *op, Region &region,
1333 ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
1334 ValueRange hostEvalVars, TypeRange hostEvalTypes,
1335 ValueRange inReductionVars, TypeRange inReductionTypes,
1336 DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
1337 ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
1338 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1339 DenseI64ArrayAttr privateMaps) {
1340 AllRegionPrintArgs args;
1341 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1342 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1343 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1344 inReductionByref, inReductionSyms);
1345 args.mapArgs.emplace(mapVars, mapTypes);
1346 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1347 privateNeedsBarrier, privateMaps);
1348 printBlockArgRegion(p, op, region, args);
1349}
1350
1352 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1353 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1354 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1355 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1356 AllRegionPrintArgs args;
1357 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1358 inReductionByref, inReductionSyms);
1359 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1360 privateNeedsBarrier,
1361 /*mapIndices=*/nullptr);
1362 printBlockArgRegion(p, op, region, args);
1363}
1364
1366 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1367 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1368 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1369 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1370 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1371 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1372 ArrayAttr reductionSyms) {
1373 AllRegionPrintArgs args;
1374 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1375 inReductionByref, inReductionSyms);
1376 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1377 privateNeedsBarrier,
1378 /*mapIndices=*/nullptr);
1379 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1380 reductionSyms, reductionMod);
1381 printBlockArgRegion(p, op, region, args);
1382}
1383
1385 ValueRange privateVars, TypeRange privateTypes,
1386 ArrayAttr privateSyms,
1387 UnitAttr privateNeedsBarrier) {
1388 AllRegionPrintArgs args;
1389 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1390 privateNeedsBarrier,
1391 /*mapIndices=*/nullptr);
1392 printBlockArgRegion(p, op, region, args);
1393}
1394
1396 OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1397 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1398 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1399 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1400 ArrayAttr reductionSyms) {
1401 AllRegionPrintArgs args;
1402 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1403 privateNeedsBarrier,
1404 /*mapIndices=*/nullptr);
1405 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1406 reductionSyms, reductionMod);
1407 printBlockArgRegion(p, op, region, args);
1408}
1409
1411 Region &region,
1412 ValueRange taskReductionVars,
1413 TypeRange taskReductionTypes,
1414 DenseBoolArrayAttr taskReductionByref,
1415 ArrayAttr taskReductionSyms) {
1416 AllRegionPrintArgs args;
1417 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1418 taskReductionByref, taskReductionSyms);
1419 printBlockArgRegion(p, op, region, args);
1420}
1421
1423 Region &region,
1424 ValueRange useDeviceAddrVars,
1425 TypeRange useDeviceAddrTypes,
1426 ValueRange useDevicePtrVars,
1427 TypeRange useDevicePtrTypes) {
1428 AllRegionPrintArgs args;
1429 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1430 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1431 printBlockArgRegion(p, op, region, args);
1432}
1433
1434/// Verifies Reduction Clause
1435static LogicalResult
1436verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1437 OperandRange reductionVars,
1438 std::optional<ArrayRef<bool>> reductionByref) {
1439 if (!reductionVars.empty()) {
1440 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1441 return op->emitOpError()
1442 << "expected as many reduction symbol references "
1443 "as reduction variables";
1444 if (reductionByref && reductionByref->size() != reductionVars.size())
1445 return op->emitError() << "expected as many reduction variable by "
1446 "reference attributes as reduction variables";
1447 } else {
1448 if (reductionSyms)
1449 return op->emitOpError() << "unexpected reduction symbol references";
1450 return success();
1451 }
1452
  // TODO: The following checks should be done in
  // SymbolUserOpInterface::verifySymbolUses.
1455 DenseSet<Value> accumulators;
1456 for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1457 Value accum = std::get<0>(args);
1458
1459 if (!accumulators.insert(accum).second)
1460 return op->emitOpError() << "accumulator variable used more than once";
1461
1462 Type varType = accum.getType();
1463 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1464 auto decl =
1466 if (!decl)
1467 return op->emitOpError() << "expected symbol reference " << symbolRef
1468 << " to point to a reduction declaration";
1469
1470 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1471 return op->emitOpError()
1472 << "expected accumulator (" << varType
1473 << ") to be the same type as reduction declaration ("
1474 << decl.getAccumulatorType() << ")";
1475 }
1476
1477 return success();
1478}
1479
1480//===----------------------------------------------------------------------===//
1481// Parser, printer and verifier for Copyprivate
1482//===----------------------------------------------------------------------===//
1483
1484/// copyprivate-entry-list ::= copyprivate-entry
1485/// | copyprivate-entry-list `,` copyprivate-entry
1486/// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1487static ParseResult parseCopyprivate(
1488 OpAsmParser &parser,
1490 SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1492 if (failed(parser.parseCommaSeparatedList([&]() {
1493 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1494 parser.parseArrow() ||
1495 parser.parseAttribute(symsVec.emplace_back()) ||
1496 parser.parseColonType(copyprivateTypes.emplace_back()))
1497 return failure();
1498 return success();
1499 })))
1500 return failure();
1501 SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1502 copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1503 return success();
1504}
1505
1506/// Print Copyprivate clause
1508 OperandRange copyprivateVars,
1509 TypeRange copyprivateTypes,
1510 std::optional<ArrayAttr> copyprivateSyms) {
1511 if (!copyprivateSyms.has_value())
1512 return;
1513 llvm::interleaveComma(
1514 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1515 [&](const auto &args) {
1516 p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1517 << std::get<2>(args);
1518 });
1519}
1520
1521/// Verifies CopyPrivate Clause
1522static LogicalResult
1524 std::optional<ArrayAttr> copyprivateSyms) {
1525 size_t copyprivateSymsSize =
1526 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1527 if (copyprivateSymsSize != copyprivateVars.size())
1528 return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1529 << copyprivateVars.size()
1530 << ") and functions (= " << copyprivateSymsSize
1531 << "), both must be equal";
1532 if (!copyprivateSyms.has_value())
1533 return success();
1534
1535 for (auto copyprivateVarAndSym :
1536 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1537 auto symbolRef =
1538 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1539 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1540 funcOp;
1541 if (mlir::func::FuncOp mlirFuncOp =
1543 symbolRef))
1544 funcOp = mlirFuncOp;
1545 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1547 op, symbolRef))
1548 funcOp = llvmFuncOp;
1549
1550 auto getNumArguments = [&] {
1551 return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1552 };
1553
1554 auto getArgumentType = [&](unsigned i) {
1555 return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1556 *funcOp);
1557 };
1558
1559 if (!funcOp)
1560 return op->emitOpError() << "expected symbol reference " << symbolRef
1561 << " to point to a copy function";
1562
1563 if (getNumArguments() != 2)
1564 return op->emitOpError()
1565 << "expected copy function " << symbolRef << " to have 2 operands";
1566
1567 Type argTy = getArgumentType(0);
1568 if (argTy != getArgumentType(1))
1569 return op->emitOpError() << "expected copy function " << symbolRef
1570 << " arguments to have the same type";
1571
1572 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1573 if (argTy != varType)
1574 return op->emitOpError()
1575 << "expected copy function arguments' type (" << argTy
1576 << ") to be the same as copyprivate variable's type (" << varType
1577 << ")";
1578 }
1579
1580 return success();
1581}
1582
1583//===----------------------------------------------------------------------===//
1584// Parser, printer and verifier for DependVarList
1585//===----------------------------------------------------------------------===//
1586
1587/// depend-entry-list ::= depend-entry
1588/// | depend-entry-list `,` depend-entry
1589/// depend-entry ::= depend-kind `->` ssa-id `:` type
1590static ParseResult
1593 SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
1595 if (failed(parser.parseCommaSeparatedList([&]() {
1596 StringRef keyword;
1597 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1598 parser.parseOperand(dependVars.emplace_back()) ||
1599 parser.parseColonType(dependTypes.emplace_back()))
1600 return failure();
1601 if (std::optional<ClauseTaskDepend> keywordDepend =
1602 (symbolizeClauseTaskDepend(keyword)))
1603 kindsVec.emplace_back(
1604 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1605 else
1606 return failure();
1607 return success();
1608 })))
1609 return failure();
1610 SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1611 dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1612 return success();
1613}
1614
1615/// Print Depend clause
1617 OperandRange dependVars, TypeRange dependTypes,
1618 std::optional<ArrayAttr> dependKinds) {
1619
1620 for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1621 if (i != 0)
1622 p << ", ";
1623 p << stringifyClauseTaskDepend(
1624 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1625 .getValue())
1626 << " -> " << dependVars[i] << " : " << dependTypes[i];
1627 }
1628}
1629
1630/// Verifies Depend clause
1631static LogicalResult verifyDependVarList(Operation *op,
1632 std::optional<ArrayAttr> dependKinds,
1633 OperandRange dependVars) {
1634 if (!dependVars.empty()) {
1635 if (!dependKinds || dependKinds->size() != dependVars.size())
1636 return op->emitOpError() << "expected as many depend values"
1637 " as depend variables";
1638 } else {
1639 if (dependKinds && !dependKinds->empty())
1640 return op->emitOpError() << "unexpected depend values";
1641 return success();
1642 }
1643
1644 return success();
1645}
1646
1647//===----------------------------------------------------------------------===//
1648// Parser, printer and verifier for Synchronization Hint (2.17.12)
1649//===----------------------------------------------------------------------===//
1650
1651/// Parses a Synchronization Hint clause. The value of hint is an integer
1652/// which is a combination of different hints from `omp_sync_hint_t`.
1653///
1654/// hint-clause = `hint` `(` hint-value `)`
1655static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1656 IntegerAttr &hintAttr) {
1657 StringRef hintKeyword;
1658 int64_t hint = 0;
1659 if (succeeded(parser.parseOptionalKeyword("none"))) {
1660 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1661 return success();
1662 }
1663 auto parseKeyword = [&]() -> ParseResult {
1664 if (failed(parser.parseKeyword(&hintKeyword)))
1665 return failure();
1666 if (hintKeyword == "uncontended")
1667 hint |= 1;
1668 else if (hintKeyword == "contended")
1669 hint |= 2;
1670 else if (hintKeyword == "nonspeculative")
1671 hint |= 4;
1672 else if (hintKeyword == "speculative")
1673 hint |= 8;
1674 else
1675 return parser.emitError(parser.getCurrentLocation())
1676 << hintKeyword << " is not a valid hint";
1677 return success();
1678 };
1679 if (parser.parseCommaSeparatedList(parseKeyword))
1680 return failure();
1681 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1682 return success();
1683}
1684
1685/// Prints a Synchronization Hint clause
1687 IntegerAttr hintAttr) {
1688 int64_t hint = hintAttr.getInt();
1689
1690 if (hint == 0) {
1691 p << "none";
1692 return;
1693 }
1694
1695 // Helper function to get n-th bit from the right end of `value`
1696 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1697
1698 bool uncontended = bitn(hint, 0);
1699 bool contended = bitn(hint, 1);
1700 bool nonspeculative = bitn(hint, 2);
1701 bool speculative = bitn(hint, 3);
1702
1704 if (uncontended)
1705 hints.push_back("uncontended");
1706 if (contended)
1707 hints.push_back("contended");
1708 if (nonspeculative)
1709 hints.push_back("nonspeculative");
1710 if (speculative)
1711 hints.push_back("speculative");
1712
1713 llvm::interleaveComma(hints, p);
1714}
1715
1716/// Verifies a synchronization hint clause
1717static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1718
1719 // Helper function to get n-th bit from the right end of `value`
1720 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1721
1722 bool uncontended = bitn(hint, 0);
1723 bool contended = bitn(hint, 1);
1724 bool nonspeculative = bitn(hint, 2);
1725 bool speculative = bitn(hint, 3);
1726
1727 if (uncontended && contended)
1728 return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1729 "omp_sync_hint_contended cannot be combined";
1730 if (nonspeculative && speculative)
1731 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1732 "omp_sync_hint_speculative cannot be combined.";
1733 return success();
1734}
1735
1736//===----------------------------------------------------------------------===//
1737// Parser, printer and verifier for Target
1738//===----------------------------------------------------------------------===//
1739
1740// Helper function to get bitwise AND of `value` and 'flag' then return it as a
1741// boolean
1742static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag) {
1743 return (value & flag) == flag;
1744}
1745
1746/// Parses a map_entries map type from a string format back into its numeric
1747/// value.
1748///
1749/// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
1750/// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
1751static ParseResult parseMapClause(OpAsmParser &parser,
1752 ClauseMapFlagsAttr &mapType) {
1753 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
1754 // This simply verifies the correct keyword is read in, the
1755 // keyword itself is stored inside of the operation
1756 auto parseTypeAndMod = [&]() -> ParseResult {
1757 StringRef mapTypeMod;
1758 if (parser.parseKeyword(&mapTypeMod))
1759 return failure();
1760
1761 if (mapTypeMod == "always")
1762 mapTypeBits |= ClauseMapFlags::always;
1763
1764 if (mapTypeMod == "implicit")
1765 mapTypeBits |= ClauseMapFlags::implicit;
1766
1767 if (mapTypeMod == "ompx_hold")
1768 mapTypeBits |= ClauseMapFlags::ompx_hold;
1769
1770 if (mapTypeMod == "close")
1771 mapTypeBits |= ClauseMapFlags::close;
1772
1773 if (mapTypeMod == "present")
1774 mapTypeBits |= ClauseMapFlags::present;
1775
1776 if (mapTypeMod == "to")
1777 mapTypeBits |= ClauseMapFlags::to;
1778
1779 if (mapTypeMod == "from")
1780 mapTypeBits |= ClauseMapFlags::from;
1781
1782 if (mapTypeMod == "tofrom")
1783 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
1784
1785 if (mapTypeMod == "delete")
1786 mapTypeBits |= ClauseMapFlags::del;
1787
1788 if (mapTypeMod == "storage")
1789 mapTypeBits |= ClauseMapFlags::storage;
1790
1791 if (mapTypeMod == "return_param")
1792 mapTypeBits |= ClauseMapFlags::return_param;
1793
1794 if (mapTypeMod == "private")
1795 mapTypeBits |= ClauseMapFlags::priv;
1796
1797 if (mapTypeMod == "literal")
1798 mapTypeBits |= ClauseMapFlags::literal;
1799
1800 if (mapTypeMod == "attach")
1801 mapTypeBits |= ClauseMapFlags::attach;
1802
1803 if (mapTypeMod == "attach_always")
1804 mapTypeBits |= ClauseMapFlags::attach_always;
1805
1806 if (mapTypeMod == "attach_never")
1807 mapTypeBits |= ClauseMapFlags::attach_never;
1808
1809 if (mapTypeMod == "attach_auto")
1810 mapTypeBits |= ClauseMapFlags::attach_auto;
1811
1812 if (mapTypeMod == "ref_ptr")
1813 mapTypeBits |= ClauseMapFlags::ref_ptr;
1814
1815 if (mapTypeMod == "ref_ptee")
1816 mapTypeBits |= ClauseMapFlags::ref_ptee;
1817
1818 if (mapTypeMod == "ref_ptr_ptee")
1819 mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
1820
1821 if (mapTypeMod == "is_device_ptr")
1822 mapTypeBits |= ClauseMapFlags::is_device_ptr;
1823
1824 return success();
1825 };
1826
1827 if (parser.parseCommaSeparatedList(parseTypeAndMod))
1828 return failure();
1829
1830 mapType =
1831 parser.getBuilder().getAttr<mlir::omp::ClauseMapFlagsAttr>(mapTypeBits);
1832
1833 return success();
1834}
1835
1836/// Prints a map_entries map type from its numeric value out into its string
1837/// format.
1838static void printMapClause(OpAsmPrinter &p, Operation *op,
1839 ClauseMapFlagsAttr mapType) {
1841 ClauseMapFlags mapFlags = mapType.getValue();
1842
1843 // handling of always, close, present placed at the beginning of the string
1844 // to aid readability
1845 if (mapTypeToBool(mapFlags, ClauseMapFlags::always))
1846 mapTypeStrs.push_back("always");
1847 if (mapTypeToBool(mapFlags, ClauseMapFlags::implicit))
1848 mapTypeStrs.push_back("implicit");
1849 if (mapTypeToBool(mapFlags, ClauseMapFlags::ompx_hold))
1850 mapTypeStrs.push_back("ompx_hold");
1851 if (mapTypeToBool(mapFlags, ClauseMapFlags::close))
1852 mapTypeStrs.push_back("close");
1853 if (mapTypeToBool(mapFlags, ClauseMapFlags::present))
1854 mapTypeStrs.push_back("present");
1855
1856 // special handling of to/from/tofrom/delete and release/alloc, release +
1857 // alloc are the abscense of one of the other flags, whereas tofrom requires
1858 // both the to and from flag to be set.
1859 bool to = mapTypeToBool(mapFlags, ClauseMapFlags::to);
1860 bool from = mapTypeToBool(mapFlags, ClauseMapFlags::from);
1861
1862 if (to && from)
1863 mapTypeStrs.push_back("tofrom");
1864 else if (from)
1865 mapTypeStrs.push_back("from");
1866 else if (to)
1867 mapTypeStrs.push_back("to");
1868
1869 if (mapTypeToBool(mapFlags, ClauseMapFlags::del))
1870 mapTypeStrs.push_back("delete");
1871 if (mapTypeToBool(mapFlags, ClauseMapFlags::return_param))
1872 mapTypeStrs.push_back("return_param");
1873 if (mapTypeToBool(mapFlags, ClauseMapFlags::storage))
1874 mapTypeStrs.push_back("storage");
1875 if (mapTypeToBool(mapFlags, ClauseMapFlags::priv))
1876 mapTypeStrs.push_back("private");
1877 if (mapTypeToBool(mapFlags, ClauseMapFlags::literal))
1878 mapTypeStrs.push_back("literal");
1879 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach))
1880 mapTypeStrs.push_back("attach");
1881 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_always))
1882 mapTypeStrs.push_back("attach_always");
1883 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_never))
1884 mapTypeStrs.push_back("attach_never");
1885 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_auto))
1886 mapTypeStrs.push_back("attach_auto");
1887 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr))
1888 mapTypeStrs.push_back("ref_ptr");
1889 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptee))
1890 mapTypeStrs.push_back("ref_ptee");
1891 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
1892 mapTypeStrs.push_back("ref_ptr_ptee");
1893 if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr))
1894 mapTypeStrs.push_back("is_device_ptr");
1895 if (mapFlags == ClauseMapFlags::none)
1896 mapTypeStrs.push_back("none");
1897
1898 for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1899 p << mapTypeStrs[i];
1900 if (i + 1 < mapTypeStrs.size()) {
1901 p << ", ";
1902 }
1903 }
1904}
1905
1906static ParseResult parseMembersIndex(OpAsmParser &parser,
1907 ArrayAttr &membersIdx) {
1908 SmallVector<Attribute> values, memberIdxs;
1909
1910 auto parseIndices = [&]() -> ParseResult {
1911 int64_t value;
1912 if (parser.parseInteger(value))
1913 return failure();
1914 values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1915 APInt(64, value, /*isSigned=*/false)));
1916 return success();
1917 };
1918
1919 do {
1920 if (failed(parser.parseLSquare()))
1921 return failure();
1922
1923 if (parser.parseCommaSeparatedList(parseIndices))
1924 return failure();
1925
1926 if (failed(parser.parseRSquare()))
1927 return failure();
1928
1929 memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1930 values.clear();
1931 } while (succeeded(parser.parseOptionalComma()));
1932
1933 if (!memberIdxs.empty())
1934 membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
1935
1936 return success();
1937}
1938
1939static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1940 ArrayAttr membersIdx) {
1941 if (!membersIdx)
1942 return;
1943
1944 llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
1945 p << "[";
1946 auto memberIdx = cast<ArrayAttr>(v);
1947 llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
1948 p << cast<IntegerAttr>(v2).getInt();
1949 });
1950 p << "]";
1951 });
1952}
1953
1955 VariableCaptureKindAttr mapCaptureType) {
1956 std::string typeCapStr;
1957 llvm::raw_string_ostream typeCap(typeCapStr);
1958 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1959 typeCap << "ByRef";
1960 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1961 typeCap << "ByCopy";
1962 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1963 typeCap << "VLAType";
1964 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1965 typeCap << "This";
1966 p << typeCapStr;
1967}
1968
1969static ParseResult parseCaptureType(OpAsmParser &parser,
1970 VariableCaptureKindAttr &mapCaptureType) {
1971 StringRef mapCaptureKey;
1972 if (parser.parseKeyword(&mapCaptureKey))
1973 return failure();
1974
1975 if (mapCaptureKey == "This")
1976 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1977 parser.getContext(), mlir::omp::VariableCaptureKind::This);
1978 if (mapCaptureKey == "ByRef")
1979 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1980 parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1981 if (mapCaptureKey == "ByCopy")
1982 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1983 parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1984 if (mapCaptureKey == "VLAType")
1985 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1986 parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1987
1988 return success();
1989}
1990
1991static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
1994
1995 for (auto mapOp : mapVars) {
1996 if (!mapOp.getDefiningOp())
1997 return emitError(op->getLoc(), "missing map operation");
1998
1999 if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
2000 mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
2001
2002 bool to = mapTypeToBool(mapTypeBits, ClauseMapFlags::to);
2003 bool from = mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
2004 bool del = mapTypeToBool(mapTypeBits, ClauseMapFlags::del);
2005
2006 bool always = mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2007 bool close = mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2008 bool implicit = mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2009
2010 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2011 return emitError(op->getLoc(),
2012 "to, from, tofrom and alloc map types are permitted");
2013
2014 if (isa<TargetEnterDataOp>(op) && (from || del))
2015 return emitError(op->getLoc(), "to and alloc map types are permitted");
2016
2017 if (isa<TargetExitDataOp>(op) && to)
2018 return emitError(op->getLoc(),
2019 "from, release and delete map types are permitted");
2020
2021 if (isa<TargetUpdateOp>(op)) {
2022 if (del) {
2023 return emitError(op->getLoc(),
2024 "at least one of to or from map types must be "
2025 "specified, other map types are not permitted");
2026 }
2027
2028 if (!to && !from) {
2029 return emitError(op->getLoc(),
2030 "at least one of to or from map types must be "
2031 "specified, other map types are not permitted");
2032 }
2033
2034 auto updateVar = mapInfoOp.getVarPtr();
2035
2036 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2037 (from && updateToVars.contains(updateVar))) {
2038 return emitError(
2039 op->getLoc(),
2040 "either to or from map types can be specified, not both");
2041 }
2042
2043 if (always || close || implicit) {
2044 return emitError(
2045 op->getLoc(),
2046 "present, mapper and iterator map type modifiers are permitted");
2047 }
2048
2049 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2050 }
2051 } else if (!isa<DeclareMapperInfoOp>(op)) {
2052 return emitError(op->getLoc(),
2053 "map argument is not a map entry operation");
2054 }
2055 }
2056
2057 return success();
2058}
2059
2060static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
2061 std::optional<DenseI64ArrayAttr> privateMapIndices =
2062 targetOp.getPrivateMapsAttr();
2063
2064 // None of the private operands are mapped.
2065 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2066 return success();
2067
2068 OperandRange privateVars = targetOp.getPrivateVars();
2069
2070 if (privateMapIndices.value().size() !=
2071 static_cast<int64_t>(privateVars.size()))
2072 return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
2073 "`private_maps` attribute mismatch");
2074
2075 return success();
2076}
2077
2078//===----------------------------------------------------------------------===//
2079// MapInfoOp
2080//===----------------------------------------------------------------------===//
2081
2082static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
2083 StringRef clauseName,
2084 OperandRange vars) {
2085 for (Value var : vars)
2086 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2087 return op->emitOpError()
2088 << "'" << clauseName
2089 << "' arguments must be defined by 'omp.map.info' ops";
2090 return success();
2091}
2092
2093LogicalResult MapInfoOp::verify() {
2094 if (getMapperId() &&
2096 *this, getMapperIdAttr())) {
2097 return emitError("invalid mapper id");
2098 }
2099
2100 if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
2101 return failure();
2102
2103 return success();
2104}
2105
2106//===----------------------------------------------------------------------===//
2107// TargetDataOp
2108//===----------------------------------------------------------------------===//
2109
2110void TargetDataOp::build(OpBuilder &builder, OperationState &state,
2111 const TargetDataOperands &clauses) {
2112 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2113 clauses.mapVars, clauses.useDeviceAddrVars,
2114 clauses.useDevicePtrVars);
2115}
2116
2117LogicalResult TargetDataOp::verify() {
2118 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2119 getUseDeviceAddrVars().empty()) {
2120 return ::emitError(this->getLoc(),
2121 "At least one of map, use_device_ptr_vars, or "
2122 "use_device_addr_vars operand must be present");
2123 }
2124
2125 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
2126 getUseDevicePtrVars())))
2127 return failure();
2128
2129 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
2130 getUseDeviceAddrVars())))
2131 return failure();
2132
2133 return verifyMapClause(*this, getMapVars());
2134}
2135
2136//===----------------------------------------------------------------------===//
2137// TargetEnterDataOp
2138//===----------------------------------------------------------------------===//
2139
2140void TargetEnterDataOp::build(
2141 OpBuilder &builder, OperationState &state,
2142 const TargetEnterExitUpdateDataOperands &clauses) {
2143 MLIRContext *ctx = builder.getContext();
2144 TargetEnterDataOp::build(builder, state,
2145 makeArrayAttr(ctx, clauses.dependKinds),
2146 clauses.dependVars, clauses.device, clauses.ifExpr,
2147 clauses.mapVars, clauses.nowait);
2148}
2149
2150LogicalResult TargetEnterDataOp::verify() {
2151 LogicalResult verifyDependVars =
2152 verifyDependVarList(*this, getDependKinds(), getDependVars());
2153 return failed(verifyDependVars) ? verifyDependVars
2154 : verifyMapClause(*this, getMapVars());
2155}
2156
2157//===----------------------------------------------------------------------===//
2158// TargetExitDataOp
2159//===----------------------------------------------------------------------===//
2160
2161void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
2162 const TargetEnterExitUpdateDataOperands &clauses) {
2163 MLIRContext *ctx = builder.getContext();
2164 TargetExitDataOp::build(builder, state,
2165 makeArrayAttr(ctx, clauses.dependKinds),
2166 clauses.dependVars, clauses.device, clauses.ifExpr,
2167 clauses.mapVars, clauses.nowait);
2168}
2169
2170LogicalResult TargetExitDataOp::verify() {
2171 LogicalResult verifyDependVars =
2172 verifyDependVarList(*this, getDependKinds(), getDependVars());
2173 return failed(verifyDependVars) ? verifyDependVars
2174 : verifyMapClause(*this, getMapVars());
2175}
2176
2177//===----------------------------------------------------------------------===//
2178// TargetUpdateOp
2179//===----------------------------------------------------------------------===//
2180
2181void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
2182 const TargetEnterExitUpdateDataOperands &clauses) {
2183 MLIRContext *ctx = builder.getContext();
2184 TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2185 clauses.dependVars, clauses.device, clauses.ifExpr,
2186 clauses.mapVars, clauses.nowait);
2187}
2188
2189LogicalResult TargetUpdateOp::verify() {
2190 LogicalResult verifyDependVars =
2191 verifyDependVarList(*this, getDependKinds(), getDependVars());
2192 return failed(verifyDependVars) ? verifyDependVars
2193 : verifyMapClause(*this, getMapVars());
2194}
2195
2196//===----------------------------------------------------------------------===//
2197// TargetOp
2198//===----------------------------------------------------------------------===//
2199
2200void TargetOp::build(OpBuilder &builder, OperationState &state,
2201 const TargetOperands &clauses) {
2202 MLIRContext *ctx = builder.getContext();
2203 // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
2204 // inReductionByref, inReductionSyms.
2205 TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2206 clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
2207 clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
2208 clauses.hostEvalVars, clauses.ifExpr,
2209 /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
2210 /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
2211 clauses.mapVars, clauses.nowait, clauses.privateVars,
2212 makeArrayAttr(ctx, clauses.privateSyms),
2213 clauses.privateNeedsBarrier, clauses.threadLimitVars,
2214 /*private_maps=*/nullptr);
2215}
2216
2217LogicalResult TargetOp::verify() {
2218 if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
2219 return failure();
2220
2221 if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
2222 getHasDeviceAddrVars())))
2223 return failure();
2224
2225 if (failed(verifyMapClause(*this, getMapVars())))
2226 return failure();
2227
2228 return verifyPrivateVarsMapping(*this);
2229}
2230
2231LogicalResult TargetOp::verifyRegions() {
2232 auto teamsOps = getOps<TeamsOp>();
2233 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2234 return emitError("target containing multiple 'omp.teams' nested ops");
2235
2236 // Check that host_eval values are only used in legal ways.
2237 Operation *capturedOp = getInnermostCapturedOmpOp();
2238 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
2239 for (Value hostEvalArg :
2240 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
2241 for (Operation *user : hostEvalArg.getUsers()) {
2242 if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
2243 // Check if used in num_teams_lower or any of num_teams_upper_vars
2244 if (hostEvalArg == teamsOp.getNumTeamsLower() ||
2245 llvm::is_contained(teamsOp.getNumTeamsUpperVars(), hostEvalArg) ||
2246 llvm::is_contained(teamsOp.getThreadLimitVars(), hostEvalArg))
2247 continue;
2248
2249 return emitOpError() << "host_eval argument only legal as 'num_teams' "
2250 "and 'thread_limit' in 'omp.teams'";
2251 }
2252 if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
2253 if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
2254 parallelOp->isAncestor(capturedOp) &&
2255 llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg))
2256 continue;
2257
2258 return emitOpError()
2259 << "host_eval argument only legal as 'num_threads' in "
2260 "'omp.parallel' when representing target SPMD";
2261 }
2262 if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2263 if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2264 loopNestOp.getOperation() == capturedOp &&
2265 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2266 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2267 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2268 continue;
2269
2270 return emitOpError() << "host_eval argument only legal as loop bounds "
2271 "and steps in 'omp.loop_nest' when trip count "
2272 "must be evaluated in the host";
2273 }
2274
2275 return emitOpError() << "host_eval argument illegal use in '"
2276 << user->getName() << "' operation";
2277 }
2278 }
2279 return success();
2280}
2281
2282static Operation *
2283findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
2284 llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
2285 assert(rootOp && "expected valid operation");
2286
2287 Dialect *ompDialect = rootOp->getDialect();
2288 Operation *capturedOp = nullptr;
2289 DominanceInfo domInfo;
2290
2291 // Process in pre-order to check operations from outermost to innermost,
2292 // ensuring we only enter the region of an operation if it meets the criteria
2293 // for being captured. We stop the exploration of nested operations as soon as
2294 // we process a region holding no operations to be captured.
2295 rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
2296 if (op == rootOp)
2297 return WalkResult::advance();
2298
2299 // Ignore operations of other dialects or omp operations with no regions,
2300 // because these will only be checked if they are siblings of an omp
2301 // operation that can potentially be captured.
2302 bool isOmpDialect = op->getDialect() == ompDialect;
2303 bool hasRegions = op->getNumRegions() > 0;
2304 if (!isOmpDialect || !hasRegions)
2305 return WalkResult::skip();
2306
2307 // This operation cannot be captured if it can be executed more than once
2308 // (i.e. its block's successors can reach it) or if it's not guaranteed to
2309 // be executed before all exits of the region (i.e. it doesn't dominate all
2310 // blocks with no successors reachable from the entry block).
2311 if (checkSingleMandatoryExec) {
2312 Region *parentRegion = op->getParentRegion();
2313 Block *parentBlock = op->getBlock();
2314
2315 for (Block *successor : parentBlock->getSuccessors())
2316 if (successor->isReachable(parentBlock))
2317 return WalkResult::interrupt();
2318
2319 for (Block &block : *parentRegion)
2320 if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2321 !domInfo.dominates(parentBlock, &block))
2322 return WalkResult::interrupt();
2323 }
2324
2325 // Don't capture this op if it has a not-allowed sibling, and stop recursing
2326 // into nested operations.
2327 for (Operation &sibling : op->getParentRegion()->getOps())
2328 if (&sibling != op && !siblingAllowedFn(&sibling))
2329 return WalkResult::interrupt();
2330
2331 // Don't continue capturing nested operations if we reach an omp.loop_nest.
2332 // Otherwise, process the contents of this operation.
2333 capturedOp = op;
2334 return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2336 });
2337
2338 return capturedOp;
2339}
2340
2341Operation *TargetOp::getInnermostCapturedOmpOp() {
2342 auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
2343
2344 // Only allow OpenMP terminators and non-OpenMP ops that have known memory
2345 // effects, but don't include a memory write effect.
2346 return findCapturedOmpOp(
2347 *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2348 if (!sibling)
2349 return false;
2350
2351 if (ompDialect == sibling->getDialect())
2352 return sibling->hasTrait<OpTrait::IsTerminator>();
2353
2354 if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2356 effects;
2357 memOp.getEffects(effects);
2358 return !llvm::any_of(
2359 effects, [&](MemoryEffects::EffectInstance &effect) {
2360 return isa<MemoryEffects::Write>(effect.getEffect()) &&
2361 isa<SideEffects::AutomaticAllocationScopeResource>(
2362 effect.getResource());
2363 });
2364 }
2365 return true;
2366 });
2367}
2368
2369/// Check if we can promote SPMD kernel to No-Loop kernel.
2370static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp,
2371 WsloopOp *wsLoopOp) {
2372 // num_teams clause can break no-loop teams/threads assumption.
2373 if (!teamsOp.getNumTeamsUpperVars().empty())
2374 return false;
2375
2376 // Reduction kernels are slower in no-loop mode.
2377 if (teamsOp.getNumReductionVars())
2378 return false;
2379 if (wsLoopOp->getNumReductionVars())
2380 return false;
2381
2382 // Check if the user allows the promotion of kernels to no-loop mode.
2383 OffloadModuleInterface offloadMod =
2384 capturedOp->getParentOfType<omp::OffloadModuleInterface>();
2385 if (!offloadMod)
2386 return false;
2387 auto ompFlags = offloadMod.getFlags();
2388 if (!ompFlags)
2389 return false;
2390 return ompFlags.getAssumeTeamsOversubscription() &&
2391 ompFlags.getAssumeThreadsOversubscription();
2392}
2393
2394TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2395 // A non-null captured op is only valid if it resides inside of a TargetOp
2396 // and is the result of calling getInnermostCapturedOmpOp() on it.
2397 TargetOp targetOp =
2398 capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
2399 assert((!capturedOp ||
2400 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2401 "unexpected captured op");
2402
2403 // If it's not capturing a loop, it's a default target region.
2404 if (!isa_and_present<LoopNestOp>(capturedOp))
2405 return TargetRegionFlags::generic;
2406
2407 // Get the innermost non-simd loop wrapper.
2409 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2410 assert(!loopWrappers.empty());
2411
2412 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2413 if (isa<SimdOp>(innermostWrapper))
2414 innermostWrapper = std::next(innermostWrapper);
2415
2416 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2417 if (numWrappers != 1 && numWrappers != 2)
2418 return TargetRegionFlags::generic;
2419
2420 // Detect target-teams-distribute-parallel-wsloop[-simd].
2421 if (numWrappers == 2) {
2422 WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2423 if (!wsloopOp)
2424 return TargetRegionFlags::generic;
2425
2426 innermostWrapper = std::next(innermostWrapper);
2427 if (!isa<DistributeOp>(innermostWrapper))
2428 return TargetRegionFlags::generic;
2429
2430 Operation *parallelOp = (*innermostWrapper)->getParentOp();
2431 if (!isa_and_present<ParallelOp>(parallelOp))
2432 return TargetRegionFlags::generic;
2433
2434 TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp());
2435 if (!teamsOp)
2436 return TargetRegionFlags::generic;
2437
2438 if (teamsOp->getParentOp() == targetOp.getOperation()) {
2439 TargetRegionFlags result =
2440 TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2441 if (canPromoteToNoLoop(capturedOp, teamsOp, wsloopOp))
2442 result = result | TargetRegionFlags::no_loop;
2443 return result;
2444 }
2445 }
2446 // Detect target-teams-distribute[-simd] and target-teams-loop.
2447 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2448 Operation *teamsOp = (*innermostWrapper)->getParentOp();
2449 if (!isa_and_present<TeamsOp>(teamsOp))
2450 return TargetRegionFlags::generic;
2451
2452 if (teamsOp->getParentOp() != targetOp.getOperation())
2453 return TargetRegionFlags::generic;
2454
2455 if (isa<LoopOp>(innermostWrapper))
2456 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2457
2458 // Find single immediately nested captured omp.parallel and add spmd flag
2459 // (generic-spmd case).
2460 //
2461 // TODO: This shouldn't have to be done here, as it is too easy to break.
2462 // The openmp-opt pass should be updated to be able to promote kernels like
2463 // this from "Generic" to "Generic-SPMD". However, the use of the
2464 // `kmpc_distribute_static_loop` family of functions produced by the
2465 // OMPIRBuilder for these kernels prevents that from working.
2466 Dialect *ompDialect = targetOp->getDialect();
2467 Operation *nestedCapture = findCapturedOmpOp(
2468 capturedOp, /*checkSingleMandatoryExec=*/false,
2469 [&](Operation *sibling) {
2470 return sibling && (ompDialect != sibling->getDialect() ||
2471 sibling->hasTrait<OpTrait::IsTerminator>());
2472 });
2473
2474 TargetRegionFlags result =
2475 TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2476
2477 if (!nestedCapture)
2478 return result;
2479
2480 while (nestedCapture->getParentOp() != capturedOp)
2481 nestedCapture = nestedCapture->getParentOp();
2482
2483 return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2484 : result;
2485 }
2486 // Detect target-parallel-wsloop[-simd].
2487 else if (isa<WsloopOp>(innermostWrapper)) {
2488 Operation *parallelOp = (*innermostWrapper)->getParentOp();
2489 if (!isa_and_present<ParallelOp>(parallelOp))
2490 return TargetRegionFlags::generic;
2491
2492 if (parallelOp->getParentOp() == targetOp.getOperation())
2493 return TargetRegionFlags::spmd;
2494 }
2495
2496 return TargetRegionFlags::generic;
2497}
2498
2499//===----------------------------------------------------------------------===//
2500// ParallelOp
2501//===----------------------------------------------------------------------===//
2502
2503void ParallelOp::build(OpBuilder &builder, OperationState &state,
2504 ArrayRef<NamedAttribute> attributes) {
2505 ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
2506 /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2507 /*num_threads_vars=*/ValueRange(),
2508 /*private_vars=*/ValueRange(),
2509 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2510 /*proc_bind_kind=*/nullptr,
2511 /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2512 /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2513 state.addAttributes(attributes);
2514}
2515
2516void ParallelOp::build(OpBuilder &builder, OperationState &state,
2517 const ParallelOperands &clauses) {
2518 MLIRContext *ctx = builder.getContext();
2519 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2520 clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars,
2521 makeArrayAttr(ctx, clauses.privateSyms),
2522 clauses.privateNeedsBarrier, clauses.procBindKind,
2523 clauses.reductionMod, clauses.reductionVars,
2524 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2525 makeArrayAttr(ctx, clauses.reductionSyms));
2526}
2527
2528template <typename OpType>
2529static LogicalResult verifyPrivateVarList(OpType &op) {
2530 auto privateVars = op.getPrivateVars();
2531 auto privateSyms = op.getPrivateSymsAttr();
2532
2533 if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
2534 return success();
2535
2536 auto numPrivateVars = privateVars.size();
2537 auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
2538
2539 if (numPrivateVars != numPrivateSyms)
2540 return op.emitError() << "inconsistent number of private variables and "
2541 "privatizer op symbols, private vars: "
2542 << numPrivateVars
2543 << " vs. privatizer op symbols: " << numPrivateSyms;
2544
2545 for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2546 Type varType = std::get<0>(privateVarInfo).getType();
2547 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2548 PrivateClauseOp privatizerOp =
2550
2551 if (privatizerOp == nullptr)
2552 return op.emitError() << "failed to lookup privatizer op with symbol: '"
2553 << privateSym << "'";
2554
2555 Type privatizerType = privatizerOp.getArgType();
2556
2557 if (privatizerType && (varType != privatizerType))
2558 return op.emitError()
2559 << "type mismatch between a "
2560 << (privatizerOp.getDataSharingType() ==
2561 DataSharingClauseType::Private
2562 ? "private"
2563 : "firstprivate")
2564 << " variable and its privatizer op, var type: " << varType
2565 << " vs. privatizer op type: " << privatizerType;
2566 }
2567
2568 return success();
2569}
2570
2571LogicalResult ParallelOp::verify() {
2572 if (getAllocateVars().size() != getAllocatorVars().size())
2573 return emitError(
2574 "expected equal sizes for allocate and allocator variables");
2575
2576 if (failed(verifyPrivateVarList(*this)))
2577 return failure();
2578
2579 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2580 getReductionByref());
2581}
2582
2583LogicalResult ParallelOp::verifyRegions() {
2584 auto distChildOps = getOps<DistributeOp>();
2585 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2586 if (numDistChildOps > 1)
2587 return emitError()
2588 << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2589
2590 if (numDistChildOps == 1) {
2591 if (!isComposite())
2592 return emitError()
2593 << "'omp.composite' attribute missing from composite operation";
2594
2595 auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2596 Operation &distributeOp = **distChildOps.begin();
2597 for (Operation &childOp : getOps()) {
2598 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2599 continue;
2600
2601 if (!childOp.hasTrait<OpTrait::IsTerminator>())
2602 return emitError() << "unexpected OpenMP operation inside of composite "
2603 "'omp.parallel': "
2604 << childOp.getName();
2605 }
2606 } else if (isComposite()) {
2607 return emitError()
2608 << "'omp.composite' attribute present in non-composite operation";
2609 }
2610 return success();
2611}
2612
2613//===----------------------------------------------------------------------===//
2614// TeamsOp
2615//===----------------------------------------------------------------------===//
2616
2618 while ((op = op->getParentOp()))
2619 if (isa<OpenMPDialect>(op->getDialect()))
2620 return false;
2621 return true;
2622}
2623
2624void TeamsOp::build(OpBuilder &builder, OperationState &state,
2625 const TeamsOperands &clauses) {
2626 MLIRContext *ctx = builder.getContext();
2627 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2628 TeamsOp::build(
2629 builder, state, clauses.allocateVars, clauses.allocatorVars,
2630 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpperVars,
2631 /*private_vars=*/{}, /*private_syms=*/nullptr,
2632 /*private_needs_barrier=*/nullptr, clauses.reductionMod,
2633 clauses.reductionVars,
2634 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2635 makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars);
2636}
2637
2638// Verify num_teams clause
2639static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower,
2640 OperandRange numTeamsUpperVars) {
2641 // If lower is specified, upper must have exactly one value
2642 if (numTeamsLower) {
2643 if (numTeamsUpperVars.size() != 1)
2644 return op->emitError(
2645 "expected exactly one num_teams upper bound when lower bound is "
2646 "specified");
2647 if (numTeamsLower.getType() != numTeamsUpperVars[0].getType())
2648 return op->emitError(
2649 "expected num_teams upper bound and lower bound to be "
2650 "the same type");
2651 }
2652
2653 return success();
2654}
2655
2656LogicalResult TeamsOp::verify() {
2657 // Check parent region
2658 // TODO If nested inside of a target region, also check that it does not
2659 // contain any statements, declarations or directives other than this
2660 // omp.teams construct. The issue is how to support the initialization of
2661 // this operation's own arguments (allow SSA values across omp.target?).
2662 Operation *op = getOperation();
2663 if (!isa<TargetOp>(op->getParentOp()) &&
2665 return emitError("expected to be nested inside of omp.target or not nested "
2666 "in any OpenMP dialect operations");
2667
2668 // Check for num_teams clause restrictions
2669 if (failed(verifyNumTeamsClause(op, this->getNumTeamsLower(),
2670 this->getNumTeamsUpperVars())))
2671 return failure();
2672
2673 // Check for allocate clause restrictions
2674 if (getAllocateVars().size() != getAllocatorVars().size())
2675 return emitError(
2676 "expected equal sizes for allocate and allocator variables");
2677
2678 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2679 getReductionByref());
2680}
2681
2682//===----------------------------------------------------------------------===//
2683// SectionOp
2684//===----------------------------------------------------------------------===//
2685
2686OperandRange SectionOp::getPrivateVars() {
2687 return getParentOp().getPrivateVars();
2688}
2689
2690OperandRange SectionOp::getReductionVars() {
2691 return getParentOp().getReductionVars();
2692}
2693
2694//===----------------------------------------------------------------------===//
2695// SectionsOp
2696//===----------------------------------------------------------------------===//
2697
2698void SectionsOp::build(OpBuilder &builder, OperationState &state,
2699 const SectionsOperands &clauses) {
2700 MLIRContext *ctx = builder.getContext();
2701 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2702 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2703 clauses.nowait, /*private_vars=*/{},
2704 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2705 clauses.reductionMod, clauses.reductionVars,
2706 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2707 makeArrayAttr(ctx, clauses.reductionSyms));
2708}
2709
2710LogicalResult SectionsOp::verify() {
2711 if (getAllocateVars().size() != getAllocatorVars().size())
2712 return emitError(
2713 "expected equal sizes for allocate and allocator variables");
2714
2715 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2716 getReductionByref());
2717}
2718
2719LogicalResult SectionsOp::verifyRegions() {
2720 for (auto &inst : *getRegion().begin()) {
2721 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2722 return emitOpError()
2723 << "expected omp.section op or terminator op inside region";
2724 }
2725 }
2726
2727 return success();
2728}
2729
2730//===----------------------------------------------------------------------===//
2731// SingleOp
2732//===----------------------------------------------------------------------===//
2733
2734void SingleOp::build(OpBuilder &builder, OperationState &state,
2735 const SingleOperands &clauses) {
2736 MLIRContext *ctx = builder.getContext();
2737 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2738 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2739 clauses.copyprivateVars,
2740 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2741 /*private_vars=*/{}, /*private_syms=*/nullptr,
2742 /*private_needs_barrier=*/nullptr);
2743}
2744
2745LogicalResult SingleOp::verify() {
2746 // Check for allocate clause restrictions
2747 if (getAllocateVars().size() != getAllocatorVars().size())
2748 return emitError(
2749 "expected equal sizes for allocate and allocator variables");
2750
2751 return verifyCopyprivateVarList(*this, getCopyprivateVars(),
2752 getCopyprivateSyms());
2753}
2754
2755//===----------------------------------------------------------------------===//
2756// WorkshareOp
2757//===----------------------------------------------------------------------===//
2758
2759void WorkshareOp::build(OpBuilder &builder, OperationState &state,
2760 const WorkshareOperands &clauses) {
2761 WorkshareOp::build(builder, state, clauses.nowait);
2762}
2763
2764//===----------------------------------------------------------------------===//
2765// WorkshareLoopWrapperOp
2766//===----------------------------------------------------------------------===//
2767
2768LogicalResult WorkshareLoopWrapperOp::verify() {
2769 if (!(*this)->getParentOfType<WorkshareOp>())
2770 return emitOpError() << "must be nested in an omp.workshare";
2771 return success();
2772}
2773
2774LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2775 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2776 getNestedWrapper())
2777 return emitOpError() << "expected to be a standalone loop wrapper";
2778
2779 return success();
2780}
2781
2782//===----------------------------------------------------------------------===//
2783// LoopWrapperInterface
2784//===----------------------------------------------------------------------===//
2785
2786LogicalResult LoopWrapperInterface::verifyImpl() {
2787 Operation *op = this->getOperation();
2788 if (!op->hasTrait<OpTrait::NoTerminator>() ||
2790 return emitOpError() << "loop wrapper must also have the `NoTerminator` "
2791 "and `SingleBlock` traits";
2792
2793 if (op->getNumRegions() != 1)
2794 return emitOpError() << "loop wrapper does not contain exactly one region";
2795
2796 Region &region = op->getRegion(0);
2797 if (range_size(region.getOps()) != 1)
2798 return emitOpError()
2799 << "loop wrapper does not contain exactly one nested op";
2800
2801 Operation &firstOp = *region.op_begin();
2802 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2803 return emitOpError() << "nested in loop wrapper is not another loop "
2804 "wrapper or `omp.loop_nest`";
2805
2806 return success();
2807}
2808
2809//===----------------------------------------------------------------------===//
2810// LoopOp
2811//===----------------------------------------------------------------------===//
2812
2813void LoopOp::build(OpBuilder &builder, OperationState &state,
2814 const LoopOperands &clauses) {
2815 MLIRContext *ctx = builder.getContext();
2816
2817 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2818 makeArrayAttr(ctx, clauses.privateSyms),
2819 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2820 clauses.reductionMod, clauses.reductionVars,
2821 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2822 makeArrayAttr(ctx, clauses.reductionSyms));
2823}
2824
2825LogicalResult LoopOp::verify() {
2826 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2827 getReductionByref());
2828}
2829
2830LogicalResult LoopOp::verifyRegions() {
2831 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2832 getNestedWrapper())
2833 return emitOpError() << "expected to be a standalone loop wrapper";
2834
2835 return success();
2836}
2837
2838//===----------------------------------------------------------------------===//
2839// WsloopOp
2840//===----------------------------------------------------------------------===//
2841
2842void WsloopOp::build(OpBuilder &builder, OperationState &state,
2843 ArrayRef<NamedAttribute> attributes) {
2844 build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2845 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
2846 /*linear_var_types*/ nullptr,
2847 /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
2848 /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
2849 /*private_needs_barrier=*/false,
2850 /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
2851 /*reduction_byref=*/nullptr,
2852 /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
2853 /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
2854 /*schedule_simd=*/false);
2855 state.addAttributes(attributes);
2856}
2857
2858void WsloopOp::build(OpBuilder &builder, OperationState &state,
2859 const WsloopOperands &clauses) {
2860 MLIRContext *ctx = builder.getContext();
2861 // TODO: Store clauses in op: allocateVars, allocatorVars
2862 WsloopOp::build(
2863 builder, state,
2864 /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
2865 clauses.linearStepVars, clauses.linearVarTypes, clauses.nowait,
2866 clauses.order, clauses.orderMod, clauses.ordered, clauses.privateVars,
2867 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2868 clauses.reductionMod, clauses.reductionVars,
2869 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2870 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
2871 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
2872}
2873
2874LogicalResult WsloopOp::verify() {
2875 if (getLinearVars().size() &&
2876 getLinearVarTypes().value().size() != getLinearVars().size())
2877 return emitError() << "Ill-formed type attributes for linear variables";
2878 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2879 getReductionByref());
2880}
2881
2882LogicalResult WsloopOp::verifyRegions() {
2883 bool isCompositeChildLeaf =
2884 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2885
2886 if (LoopWrapperInterface nested = getNestedWrapper()) {
2887 if (!isComposite())
2888 return emitError()
2889 << "'omp.composite' attribute missing from composite wrapper";
2890
2891 // Check for the allowed leaf constructs that may appear in a composite
2892 // construct directly after DO/FOR.
2893 if (!isa<SimdOp>(nested))
2894 return emitError() << "only supported nested wrapper is 'omp.simd'";
2895
2896 } else if (isComposite() && !isCompositeChildLeaf) {
2897 return emitError()
2898 << "'omp.composite' attribute present in non-composite wrapper";
2899 } else if (!isComposite() && isCompositeChildLeaf) {
2900 return emitError()
2901 << "'omp.composite' attribute missing from composite wrapper";
2902 }
2903
2904 return success();
2905}
2906
2907//===----------------------------------------------------------------------===//
2908// Simd construct [2.9.3.1]
2909//===----------------------------------------------------------------------===//
2910
2911void SimdOp::build(OpBuilder &builder, OperationState &state,
2912 const SimdOperands &clauses) {
2913 MLIRContext *ctx = builder.getContext();
2914 SimdOp::build(
2915 builder, state, clauses.alignedVars,
2916 makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
2917 clauses.linearVars, clauses.linearStepVars, clauses.linearVarTypes,
2918 clauses.nontemporalVars, clauses.order, clauses.orderMod,
2919 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
2920 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
2921 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2922 makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
2923 clauses.simdlen);
2924}
2925
2926LogicalResult SimdOp::verify() {
2927 if (getSimdlen().has_value() && getSafelen().has_value() &&
2928 getSimdlen().value() > getSafelen().value())
2929 return emitOpError()
2930 << "simdlen clause and safelen clause are both present, but the "
2931 "simdlen value is not less than or equal to safelen value";
2932
2933 if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
2934 return failure();
2935
2936 if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
2937 return failure();
2938
2939 bool isCompositeChildLeaf =
2940 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2941
2942 if (!isComposite() && isCompositeChildLeaf)
2943 return emitError()
2944 << "'omp.composite' attribute missing from composite wrapper";
2945
2946 if (isComposite() && !isCompositeChildLeaf)
2947 return emitError()
2948 << "'omp.composite' attribute present in non-composite wrapper";
2949
2950 // Firstprivate is not allowed for SIMD in the standard. Check that none of
2951 // the private decls are for firstprivate.
2952 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
2953 if (privateSyms) {
2954 for (const Attribute &sym : *privateSyms) {
2955 auto symRef = cast<SymbolRefAttr>(sym);
2956 omp::PrivateClauseOp privatizer =
2958 getOperation(), symRef);
2959 if (!privatizer)
2960 return emitError() << "Cannot find privatizer '" << symRef << "'";
2961 if (privatizer.getDataSharingType() ==
2962 DataSharingClauseType::FirstPrivate)
2963 return emitError() << "FIRSTPRIVATE cannot be used with SIMD";
2964 }
2965 }
2966
2967 if (getLinearVars().size() &&
2968 getLinearVarTypes().value().size() != getLinearVars().size())
2969 return emitError() << "Ill-formed type attributes for linear variables";
2970 return success();
2971}
2972
2973LogicalResult SimdOp::verifyRegions() {
2974 if (getNestedWrapper())
2975 return emitOpError() << "must wrap an 'omp.loop_nest' directly";
2976
2977 return success();
2978}
2979
2980//===----------------------------------------------------------------------===//
2981// Distribute construct [2.9.4.1]
2982//===----------------------------------------------------------------------===//
2983
2984void DistributeOp::build(OpBuilder &builder, OperationState &state,
2985 const DistributeOperands &clauses) {
2986 DistributeOp::build(builder, state, clauses.allocateVars,
2987 clauses.allocatorVars, clauses.distScheduleStatic,
2988 clauses.distScheduleChunkSize, clauses.order,
2989 clauses.orderMod, clauses.privateVars,
2990 makeArrayAttr(builder.getContext(), clauses.privateSyms),
2991 clauses.privateNeedsBarrier);
2992}
2993
2994LogicalResult DistributeOp::verify() {
2995 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2996 return emitOpError() << "chunk size set without "
2997 "dist_schedule_static being present";
2998
2999 if (getAllocateVars().size() != getAllocatorVars().size())
3000 return emitError(
3001 "expected equal sizes for allocate and allocator variables");
3002
3003 return success();
3004}
3005
3006LogicalResult DistributeOp::verifyRegions() {
3007 if (LoopWrapperInterface nested = getNestedWrapper()) {
3008 if (!isComposite())
3009 return emitError()
3010 << "'omp.composite' attribute missing from composite wrapper";
3011 // Check for the allowed leaf constructs that may appear in a composite
3012 // construct directly after DISTRIBUTE.
3013 if (isa<WsloopOp>(nested)) {
3014 Operation *parentOp = (*this)->getParentOp();
3015 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
3016 !cast<ComposableOpInterface>(parentOp).isComposite()) {
3017 return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
3018 "when a composite 'omp.parallel' is the direct "
3019 "parent";
3020 }
3021 } else if (!isa<SimdOp>(nested))
3022 return emitError() << "only supported nested wrappers are 'omp.simd' and "
3023 "'omp.wsloop'";
3024 } else if (isComposite()) {
3025 return emitError()
3026 << "'omp.composite' attribute present in non-composite wrapper";
3027 }
3028
3029 return success();
3030}
3031
3032//===----------------------------------------------------------------------===//
3033// DeclareMapperOp / DeclareMapperInfoOp
3034//===----------------------------------------------------------------------===//
3035
3036LogicalResult DeclareMapperInfoOp::verify() {
3037 return verifyMapClause(*this, getMapVars());
3038}
3039
3040LogicalResult DeclareMapperOp::verifyRegions() {
3041 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3042 getRegion().getBlocks().front().getTerminator()))
3043 return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
3044
3045 return success();
3046}
3047
3048//===----------------------------------------------------------------------===//
3049// DeclareReductionOp
3050//===----------------------------------------------------------------------===//
3051
3052LogicalResult DeclareReductionOp::verifyRegions() {
3053 if (!getAllocRegion().empty()) {
3054 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3055 if (yieldOp.getResults().size() != 1 ||
3056 yieldOp.getResults().getTypes()[0] != getType())
3057 return emitOpError() << "expects alloc region to yield a value "
3058 "of the reduction type";
3059 }
3060 }
3061
3062 if (getInitializerRegion().empty())
3063 return emitOpError() << "expects non-empty initializer region";
3064 Block &initializerEntryBlock = getInitializerRegion().front();
3065
3066 if (initializerEntryBlock.getNumArguments() == 1) {
3067 if (!getAllocRegion().empty())
3068 return emitOpError() << "expects two arguments to the initializer region "
3069 "when an allocation region is used";
3070 } else if (initializerEntryBlock.getNumArguments() == 2) {
3071 if (getAllocRegion().empty())
3072 return emitOpError() << "expects one argument to the initializer region "
3073 "when no allocation region is used";
3074 } else {
3075 return emitOpError()
3076 << "expects one or two arguments to the initializer region";
3077 }
3078
3079 for (mlir::Value arg : initializerEntryBlock.getArguments())
3080 if (arg.getType() != getType())
3081 return emitOpError() << "expects initializer region argument to match "
3082 "the reduction type";
3083
3084 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3085 if (yieldOp.getResults().size() != 1 ||
3086 yieldOp.getResults().getTypes()[0] != getType())
3087 return emitOpError() << "expects initializer region to yield a value "
3088 "of the reduction type";
3089 }
3090
3091 if (getReductionRegion().empty())
3092 return emitOpError() << "expects non-empty reduction region";
3093 Block &reductionEntryBlock = getReductionRegion().front();
3094 if (reductionEntryBlock.getNumArguments() != 2 ||
3095 reductionEntryBlock.getArgumentTypes()[0] !=
3096 reductionEntryBlock.getArgumentTypes()[1] ||
3097 reductionEntryBlock.getArgumentTypes()[0] != getType())
3098 return emitOpError() << "expects reduction region with two arguments of "
3099 "the reduction type";
3100 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3101 if (yieldOp.getResults().size() != 1 ||
3102 yieldOp.getResults().getTypes()[0] != getType())
3103 return emitOpError() << "expects reduction region to yield a value "
3104 "of the reduction type";
3105 }
3106
3107 if (!getAtomicReductionRegion().empty()) {
3108 Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
3109 if (atomicReductionEntryBlock.getNumArguments() != 2 ||
3110 atomicReductionEntryBlock.getArgumentTypes()[0] !=
3111 atomicReductionEntryBlock.getArgumentTypes()[1])
3112 return emitOpError() << "expects atomic reduction region with two "
3113 "arguments of the same type";
3114 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3115 atomicReductionEntryBlock.getArgumentTypes()[0]);
3116 if (!ptrType ||
3117 (ptrType.getElementType() && ptrType.getElementType() != getType()))
3118 return emitOpError() << "expects atomic reduction region arguments to "
3119 "be accumulators containing the reduction type";
3120 }
3121
3122 if (getCleanupRegion().empty())
3123 return success();
3124 Block &cleanupEntryBlock = getCleanupRegion().front();
3125 if (cleanupEntryBlock.getNumArguments() != 1 ||
3126 cleanupEntryBlock.getArgument(0).getType() != getType())
3127 return emitOpError() << "expects cleanup region with one argument "
3128 "of the reduction type";
3129
3130 return success();
3131}
3132
3133//===----------------------------------------------------------------------===//
3134// TaskOp
3135//===----------------------------------------------------------------------===//
3136
3137void TaskOp::build(OpBuilder &builder, OperationState &state,
3138 const TaskOperands &clauses) {
3139 MLIRContext *ctx = builder.getContext();
3140 TaskOp::build(builder, state, clauses.affinityVars, clauses.allocateVars,
3141 clauses.allocatorVars, makeArrayAttr(ctx, clauses.dependKinds),
3142 clauses.dependVars, clauses.final, clauses.ifExpr,
3143 clauses.inReductionVars,
3144 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3145 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3146 clauses.priority, /*private_vars=*/clauses.privateVars,
3147 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3148 clauses.privateNeedsBarrier, clauses.untied,
3149 clauses.eventHandle);
3150}
3151
3152LogicalResult TaskOp::verify() {
3153 LogicalResult verifyDependVars =
3154 verifyDependVarList(*this, getDependKinds(), getDependVars());
3155 return failed(verifyDependVars)
3156 ? verifyDependVars
3157 : verifyReductionVarList(*this, getInReductionSyms(),
3158 getInReductionVars(),
3159 getInReductionByref());
3160}
3161
3162//===----------------------------------------------------------------------===//
3163// TaskgroupOp
3164//===----------------------------------------------------------------------===//
3165
3166void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
3167 const TaskgroupOperands &clauses) {
3168 MLIRContext *ctx = builder.getContext();
3169 TaskgroupOp::build(builder, state, clauses.allocateVars,
3170 clauses.allocatorVars, clauses.taskReductionVars,
3171 makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
3172 makeArrayAttr(ctx, clauses.taskReductionSyms));
3173}
3174
3175LogicalResult TaskgroupOp::verify() {
3176 return verifyReductionVarList(*this, getTaskReductionSyms(),
3177 getTaskReductionVars(),
3178 getTaskReductionByref());
3179}
3180
3181//===----------------------------------------------------------------------===//
3182// TaskloopOp
3183//===----------------------------------------------------------------------===//
3184
3185void TaskloopOp::build(OpBuilder &builder, OperationState &state,
3186 const TaskloopOperands &clauses) {
3187 MLIRContext *ctx = builder.getContext();
3188 TaskloopOp::build(
3189 builder, state, clauses.allocateVars, clauses.allocatorVars,
3190 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3191 clauses.inReductionVars,
3192 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3193 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3194 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3195 /*private_vars=*/clauses.privateVars,
3196 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3197 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3198 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3199 makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
3200}
3201
3202LogicalResult TaskloopOp::verify() {
3203 if (getAllocateVars().size() != getAllocatorVars().size())
3204 return emitError(
3205 "expected equal sizes for allocate and allocator variables");
3206 if (failed(verifyReductionVarList(*this, getReductionSyms(),
3207 getReductionVars(), getReductionByref())) ||
3208 failed(verifyReductionVarList(*this, getInReductionSyms(),
3209 getInReductionVars(),
3210 getInReductionByref())))
3211 return failure();
3212
3213 if (!getReductionVars().empty() && getNogroup())
3214 return emitError("if a reduction clause is present on the taskloop "
3215 "directive, the nogroup clause must not be specified");
3216 for (auto var : getReductionVars()) {
3217 if (llvm::is_contained(getInReductionVars(), var))
3218 return emitError("the same list item cannot appear in both a reduction "
3219 "and an in_reduction clause");
3220 }
3221
3222 if (getGrainsize() && getNumTasks()) {
3223 return emitError(
3224 "the grainsize clause and num_tasks clause are mutually exclusive and "
3225 "may not appear on the same taskloop directive");
3226 }
3227
3228 return success();
3229}
3230
3231LogicalResult TaskloopOp::verifyRegions() {
3232 if (LoopWrapperInterface nested = getNestedWrapper()) {
3233 if (!isComposite())
3234 return emitError()
3235 << "'omp.composite' attribute missing from composite wrapper";
3236
3237 // Check for the allowed leaf constructs that may appear in a composite
3238 // construct directly after TASKLOOP.
3239 if (!isa<SimdOp>(nested))
3240 return emitError() << "only supported nested wrapper is 'omp.simd'";
3241 } else if (isComposite()) {
3242 return emitError()
3243 << "'omp.composite' attribute present in non-composite wrapper";
3244 }
3245
3246 return success();
3247}
3248
3249//===----------------------------------------------------------------------===//
3250// LoopNestOp
3251//===----------------------------------------------------------------------===//
3252
3253ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
3254 // Parse an opening `(` followed by induction variables followed by `)`
3257 Type loopVarType;
3259 parser.parseColonType(loopVarType) ||
3260 // Parse loop bounds.
3261 parser.parseEqual() ||
3262 parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
3263 parser.parseKeyword("to") ||
3264 parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
3265 return failure();
3266
3267 for (auto &iv : ivs)
3268 iv.type = loopVarType;
3269
3270 auto *ctx = parser.getBuilder().getContext();
3271 // Parse "inclusive" flag.
3272 if (succeeded(parser.parseOptionalKeyword("inclusive")))
3273 result.addAttribute("loop_inclusive", UnitAttr::get(ctx));
3274
3275 // Parse step values.
3277 if (parser.parseKeyword("step") ||
3278 parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
3279 return failure();
3280
3281 // Parse collapse
3282 int64_t value = 0;
3283 if (!parser.parseOptionalKeyword("collapse") &&
3284 (parser.parseLParen() || parser.parseInteger(value) ||
3285 parser.parseRParen()))
3286 return failure();
3287 if (value > 1)
3288 result.addAttribute(
3289 "collapse_num_loops",
3290 IntegerAttr::get(parser.getBuilder().getI64Type(), value));
3291
3292 // Parse tiles
3294 auto parseTiles = [&]() -> ParseResult {
3295 int64_t tile;
3296 if (parser.parseInteger(tile))
3297 return failure();
3298 tiles.push_back(tile);
3299 return success();
3300 };
3301
3302 if (!parser.parseOptionalKeyword("tiles") &&
3303 (parser.parseLParen() || parser.parseCommaSeparatedList(parseTiles) ||
3304 parser.parseRParen()))
3305 return failure();
3306
3307 if (tiles.size() > 0)
3308 result.addAttribute("tile_sizes", DenseI64ArrayAttr::get(ctx, tiles));
3309
3310 // Parse the body.
3311 Region *region = result.addRegion();
3312 if (parser.parseRegion(*region, ivs))
3313 return failure();
3314
3315 // Resolve operands.
3316 if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
3317 parser.resolveOperands(ubs, loopVarType, result.operands) ||
3318 parser.resolveOperands(steps, loopVarType, result.operands))
3319 return failure();
3320
3321 // Parse the optional attribute list.
3322 return parser.parseOptionalAttrDict(result.attributes);
3323}
3324
3325void LoopNestOp::print(OpAsmPrinter &p) {
3326 Region &region = getRegion();
3327 auto args = region.getArguments();
3328 p << " (" << args << ") : " << args[0].getType() << " = ("
3329 << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
3330 if (getLoopInclusive())
3331 p << "inclusive ";
3332 p << "step (" << getLoopSteps() << ") ";
3333 if (int64_t numCollapse = getCollapseNumLoops())
3334 if (numCollapse > 1)
3335 p << "collapse(" << numCollapse << ") ";
3336
3337 if (const auto tiles = getTileSizes())
3338 p << "tiles(" << tiles.value() << ") ";
3339
3340 p.printRegion(region, /*printEntryBlockArgs=*/false);
3341}
3342
3343void LoopNestOp::build(OpBuilder &builder, OperationState &state,
3344 const LoopNestOperands &clauses) {
3345 MLIRContext *ctx = builder.getContext();
3346 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3347 clauses.loopLowerBounds, clauses.loopUpperBounds,
3348 clauses.loopSteps, clauses.loopInclusive,
3349 makeDenseI64ArrayAttr(ctx, clauses.tileSizes));
3350}
3351
3352LogicalResult LoopNestOp::verify() {
3353 if (getLoopLowerBounds().empty())
3354 return emitOpError() << "must represent at least one loop";
3355
3356 if (getLoopLowerBounds().size() != getIVs().size())
3357 return emitOpError() << "number of range arguments and IVs do not match";
3358
3359 for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3360 if (lb.getType() != iv.getType())
3361 return emitOpError()
3362 << "range argument type does not match corresponding IV type";
3363 }
3364
3365 uint64_t numIVs = getIVs().size();
3366
3367 if (const auto &numCollapse = getCollapseNumLoops())
3368 if (numCollapse > numIVs)
3369 return emitOpError()
3370 << "collapse value is larger than the number of loops";
3371
3372 if (const auto &tiles = getTileSizes())
3373 if (tiles.value().size() > numIVs)
3374 return emitOpError() << "too few canonical loops for tile dimensions";
3375
3376 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3377 return emitOpError() << "expects parent op to be a loop wrapper";
3378
3379 return success();
3380}
3381
3382void LoopNestOp::gatherWrappers(
3384 Operation *parent = (*this)->getParentOp();
3385 while (auto wrapper =
3386 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3387 wrappers.push_back(wrapper);
3388 parent = parent->getParentOp();
3389 }
3390}
3391
3392//===----------------------------------------------------------------------===//
3393// OpenMP canonical loop handling
3394//===----------------------------------------------------------------------===//
3395
3396std::tuple<NewCliOp, OpOperand *, OpOperand *>
3397mlir::omp ::decodeCli(Value cli) {
3398
3399 // Defining a CLI for a generated loop is optional; if there is none then
3400 // there is no followup-tranformation
3401 if (!cli)
3402 return {{}, nullptr, nullptr};
3403
3404 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3405 "Unexpected type of cli");
3406
3407 NewCliOp create = cast<NewCliOp>(cli.getDefiningOp());
3408 OpOperand *gen = nullptr;
3409 OpOperand *cons = nullptr;
3410 for (OpOperand &use : cli.getUses()) {
3411 auto op = cast<LoopTransformationInterface>(use.getOwner());
3412
3413 unsigned opnum = use.getOperandNumber();
3414 if (op.isGeneratee(opnum)) {
3415 assert(!gen && "Each CLI may have at most one def");
3416 gen = &use;
3417 } else if (op.isApplyee(opnum)) {
3418 assert(!cons && "Each CLI may have at most one consumer");
3419 cons = &use;
3420 } else {
3421 llvm_unreachable("Unexpected operand for a CLI");
3422 }
3423 }
3424
3425 return {create, gen, cons};
3426}
3427
3428void NewCliOp::build(::mlir::OpBuilder &odsBuilder,
3429 ::mlir::OperationState &odsState) {
3430 odsState.addTypes(CanonicalLoopInfoType::get(odsBuilder.getContext()));
3431}
3432
3433void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3434 Value result = getResult();
3435 auto [newCli, gen, cons] = decodeCli(result);
3436
3437 // Structured binding `gen` cannot be captured in lambdas before C++20
3438 OpOperand *generator = gen;
3439
3440 // Derive the CLI variable name from its generator:
3441 // * "canonloop" for omp.canonical_loop
3442 // * custom name for loop transformation generatees
3443 // * "cli" as fallback if no generator
3444 // * "_r<idx>" suffix for nested loops, where <idx> is the sequential order
3445 // at that level
3446 // * "_s<idx>" suffix for operations with multiple regions, where <idx> is
3447 // the index of that region
3448 std::string cliName{"cli"};
3449 if (gen) {
3450 cliName =
3452 .Case([&](CanonicalLoopOp op) {
3453 return generateLoopNestingName("canonloop", op);
3454 })
3455 .Case([&](UnrollHeuristicOp op) -> std::string {
3456 llvm_unreachable("heuristic unrolling does not generate a loop");
3457 })
3458 .Case([&](FuseOp op) -> std::string {
3459 unsigned opnum = generator->getOperandNumber();
3460 // The position of the first loop to be fused is the same position
3461 // as the resulting fused loop
3462 if (op.getFirst().has_value() && opnum != op.getFirst().value())
3463 return "canonloop_fuse";
3464 else
3465 return "fused";
3466 })
3467 .Case([&](TileOp op) -> std::string {
3468 auto [generateesFirst, generateesCount] =
3469 op.getGenerateesODSOperandIndexAndLength();
3470 unsigned firstGrid = generateesFirst;
3471 unsigned firstIntratile = generateesFirst + generateesCount / 2;
3472 unsigned end = generateesFirst + generateesCount;
3473 unsigned opnum = generator->getOperandNumber();
3474 // In the OpenMP apply and looprange clauses, indices are 1-based
3475 if (firstGrid <= opnum && opnum < firstIntratile) {
3476 unsigned gridnum = opnum - firstGrid + 1;
3477 return ("grid" + Twine(gridnum)).str();
3478 }
3479 if (firstIntratile <= opnum && opnum < end) {
3480 unsigned intratilenum = opnum - firstIntratile + 1;
3481 return ("intratile" + Twine(intratilenum)).str();
3482 }
3483 llvm_unreachable("Unexpected generatee argument");
3484 })
3485 .DefaultUnreachable("TODO: Custom name for this operation");
3486 }
3487
3488 setNameFn(result, cliName);
3489}
3490
3491LogicalResult NewCliOp::verify() {
3492 Value cli = getResult();
3493
3494 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3495 "Unexpected type of cli");
3496
3497 // Check that the CLI is used in at most generator and one consumer
3498 OpOperand *gen = nullptr;
3499 OpOperand *cons = nullptr;
3500 for (mlir::OpOperand &use : cli.getUses()) {
3501 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3502
3503 unsigned opnum = use.getOperandNumber();
3504 if (op.isGeneratee(opnum)) {
3505 if (gen) {
3506 InFlightDiagnostic error =
3507 emitOpError("CLI must have at most one generator");
3508 error.attachNote(gen->getOwner()->getLoc())
3509 .append("first generator here:");
3510 error.attachNote(use.getOwner()->getLoc())
3511 .append("second generator here:");
3512 return error;
3513 }
3514
3515 gen = &use;
3516 } else if (op.isApplyee(opnum)) {
3517 if (cons) {
3518 InFlightDiagnostic error =
3519 emitOpError("CLI must have at most one consumer");
3520 error.attachNote(cons->getOwner()->getLoc())
3521 .append("first consumer here:")
3522 .appendOp(*cons->getOwner(),
3523 OpPrintingFlags().printGenericOpForm());
3524 error.attachNote(use.getOwner()->getLoc())
3525 .append("second consumer here:")
3526 .appendOp(*use.getOwner(), OpPrintingFlags().printGenericOpForm());
3527 return error;
3528 }
3529
3530 cons = &use;
3531 } else {
3532 llvm_unreachable("Unexpected operand for a CLI");
3533 }
3534 }
3535
3536 // If the CLI is source of a transformation, it must have a generator
3537 if (cons && !gen) {
3538 InFlightDiagnostic error = emitOpError("CLI has no generator");
3539 error.attachNote(cons->getOwner()->getLoc())
3540 .append("see consumer here: ")
3541 .appendOp(*cons->getOwner(), OpPrintingFlags().printGenericOpForm());
3542 return error;
3543 }
3544
3545 return success();
3546}
3547
3548void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3549 Value tripCount) {
3550 odsState.addOperands(tripCount);
3551 odsState.addOperands(Value());
3552 (void)odsState.addRegion();
3553}
3554
3555void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3556 Value tripCount, ::mlir::Value cli) {
3557 odsState.addOperands(tripCount);
3558 odsState.addOperands(cli);
3559 (void)odsState.addRegion();
3560}
3561
3562void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
3563 setNameFn(&getRegion().front(), "body_entry");
3564}
3565
3566void CanonicalLoopOp::getAsmBlockArgumentNames(Region &region,
3567 OpAsmSetValueNameFn setNameFn) {
3568 std::string ivName = generateLoopNestingName("iv", *this);
3569 setNameFn(region.getArgument(0), ivName);
3570}
3571
3572void CanonicalLoopOp::print(OpAsmPrinter &p) {
3573 if (getCli())
3574 p << '(' << getCli() << ')';
3575 p << ' ' << getInductionVar() << " : " << getInductionVar().getType()
3576 << " in range(" << getTripCount() << ") ";
3577
3578 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3579 /*printBlockTerminators=*/true);
3580
3581 p.printOptionalAttrDict((*this)->getAttrs());
3582}
3583
/// Parses the custom assembly form: an optional parenthesized CLI operand, the
/// typed induction variable, `in range(<tripcount>)`, the loop body region and
/// an optional attribute dictionary. Returns failure as soon as any of the
/// parser calls below reports an error.
3584mlir::ParseResult CanonicalLoopOp::parse(::mlir::OpAsmParser &parser,
3586 CanonicalLoopInfoType cliType =
3587 CanonicalLoopInfoType::get(parser.getContext());
3588
3589 // Parse the (optional) omp.cli identifier. It is resolved into a temporary
// list because it has to be appended after the trip count operand.
3591 SmallVector<mlir::Value, 1> cliOperand;
3592 if (!parser.parseOptionalLParen()) {
3593 if (parser.parseOperand(cli) ||
3594 parser.resolveOperand(cli, cliType, cliOperand) || parser.parseRParen())
3595 return failure();
3596 }
3597
3598 // We derive the type of tripCount from inductionVariable. MLIR requires the
3599 // type of tripCount to be known when calling resolveOperand, so we have to
3600 // parse the type before processing the inductionVariable.
3601 OpAsmParser::Argument inductionVariable;
3603 if (parser.parseArgument(inductionVariable, /*allowType*/ true) ||
3604 parser.parseKeyword("in") || parser.parseKeyword("range") ||
3605 parser.parseLParen() || parser.parseOperand(tripcount) ||
3606 parser.parseRParen() ||
3607 parser.resolveOperand(tripcount, inductionVariable.type, result.operands))
3608 return failure();
3609
3610 // Parse the loop body; the induction variable becomes its entry argument.
3611 Region *region = result.addRegion();
3612 if (parser.parseRegion(*region, {inductionVariable}))
3613 return failure();
3614
3615 // We parsed the cli operand first, but because it is optional, it must be
3616 // last in the operand list.
3617 result.operands.append(cliOperand);
3618
3619 // Parse the optional attribute list.
3620 if (parser.parseOptionalAttrDict(result.attributes))
3621 return failure();
3622
3623 return mlir::success();
3624}
3625
3626LogicalResult CanonicalLoopOp::verify() {
3627 // The region's entry must accept the induction variable
3628 // It can also be empty if just created
3629 if (!getRegion().empty()) {
3630 Region &region = getRegion();
3631 if (region.getNumArguments() != 1)
3632 return emitOpError(
3633 "Canonical loop region must have exactly one argument");
3634
3635 if (getInductionVar().getType() != getTripCount().getType())
3636 return emitOpError(
3637 "Region argument must be the same type as the trip count");
3638 }
3639
3640 return success();
3641}
3642
3643Value CanonicalLoopOp::getInductionVar() { return getRegion().getArgument(0); }
3644
3645std::pair<unsigned, unsigned>
3646CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3647 // No applyees
3648 return {0, 0};
3649}
3650
3651std::pair<unsigned, unsigned>
3652CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3653 return getODSOperandIndexAndLength(odsIndex_cli);
3654}
3655
3656//===----------------------------------------------------------------------===//
3657// UnrollHeuristicOp
3658//===----------------------------------------------------------------------===//
3659
3660void UnrollHeuristicOp::build(::mlir::OpBuilder &odsBuilder,
3661 ::mlir::OperationState &odsState,
3662 ::mlir::Value cli) {
3663 odsState.addOperands(cli);
3664}
3665
3666void UnrollHeuristicOp::print(OpAsmPrinter &p) {
3667 p << '(' << getApplyee() << ')';
3668
3669 p.printOptionalAttrDict((*this)->getAttrs());
3670}
3671
/// Parses `(<applyee>)`, optionally followed by `-> ()`, and an optional
/// attribute dictionary. The applyee is resolved against the CLI type and
/// added to the operands of the operation being parsed.
3672mlir::ParseResult UnrollHeuristicOp::parse(::mlir::OpAsmParser &parser,
3674 auto cliType = CanonicalLoopInfoType::get(parser.getContext());
3675
3676 if (parser.parseLParen())
3677 return failure();
3678
3680 if (parser.parseOperand(applyee) ||
3681 parser.resolveOperand(applyee, cliType, result.operands))
3682 return failure();
3683
3684 if (parser.parseRParen())
3685 return failure();
3686
3687 // Optional output loop (full unrolling has none). Only the empty list `()`
// is accepted after the arrow.
3688 if (!parser.parseOptionalArrow()) {
3689 if (parser.parseLParen() || parser.parseRParen())
3690 return failure();
3691 }
3692
3693 // Parse the optional attribute list.
3694 if (parser.parseOptionalAttrDict(result.attributes))
3695 return failure();
3696
3697 return mlir::success();
3698}
3699
3700std::pair<unsigned, unsigned>
3701UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3702 return getODSOperandIndexAndLength(odsIndex_applyee);
3703}
3704
3705std::pair<unsigned, unsigned>
3706UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3707 return {0, 0};
3708}
3709
3710//===----------------------------------------------------------------------===//
3711// TileOp
3712//===----------------------------------------------------------------------===//
3713
3714static void printLoopTransformClis(OpAsmPrinter &p, TileOp op,
3715 OperandRange generatees,
3716 OperandRange applyees) {
3717 if (!generatees.empty())
3718 p << '(' << llvm::interleaved(generatees) << ')';
3719
3720 if (!applyees.empty())
3721 p << " <- (" << llvm::interleaved(applyees) << ')';
3722}
3723
/// Parses the CLI operands of a loop transformation: an optional generatee
/// operand list, the `<-` token and the applyee operand list. Returns failure
/// when any of the parser calls below fails.
3724static ParseResult parseLoopTransformClis(
3725 OpAsmParser &parser,
// parseOptionalLess() fails when the next token is not `<`, in which case the
// input has to start with the generatee list.
3728 if (parser.parseOptionalLess()) {
3729 // Syntax 1: generatees present
3730
3731 if (parser.parseOperandList(generateesOperands,
3733 return failure();
3734
3735 if (parser.parseLess())
3736 return failure();
3737 } else {
3738 // Syntax 2: generatees omitted
3739 }
3740
3741 // Parse `<-` (`<` has already been parsed)
3742 if (parser.parseMinus())
3743 return failure();
3744
3745 if (parser.parseOperandList(applyeesOperands,
3747 return failure();
3748
3749 return success();
3750}
3751
3752LogicalResult TileOp::verify() {
3753 if (getApplyees().empty())
3754 return emitOpError() << "must apply to at least one loop";
3755
3756 if (getSizes().size() != getApplyees().size())
3757 return emitOpError() << "there must be one tile size for each applyee";
3758
3759 if (!getGeneratees().empty() &&
3760 2 * getSizes().size() != getGeneratees().size())
3761 return emitOpError()
3762 << "expecting two times the number of generatees than applyees";
3763
3764 DenseSet<Value> parentIVs;
3765
3766 Value parent = getApplyees().front();
3767 for (auto &&applyee : llvm::drop_begin(getApplyees())) {
3768 auto [parentCreate, parentGen, parentCons] = decodeCli(parent);
3769 auto [create, gen, cons] = decodeCli(applyee);
3770
3771 if (!parentGen)
3772 return emitOpError() << "applyee CLI has no generator";
3773
3774 auto parentLoop = dyn_cast_or_null<CanonicalLoopOp>(parentGen->getOwner());
3775 if (!parentGen)
3776 return emitOpError()
3777 << "currently only supports omp.canonical_loop as applyee";
3778
3779 parentIVs.insert(parentLoop.getInductionVar());
3780
3781 if (!gen)
3782 return emitOpError() << "applyee CLI has no generator";
3783 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
3784 if (!loop)
3785 return emitOpError()
3786 << "currently only supports omp.canonical_loop as applyee";
3787
3788 // Canonical loop must be perfectly nested, i.e. the body of the parent must
3789 // only contain the omp.canonical_loop of the nested loops, and
3790 // omp.terminator
3791 bool isPerfectlyNested = [&]() {
3792 auto &parentBody = parentLoop.getRegion();
3793 if (!parentBody.hasOneBlock())
3794 return false;
3795 auto &parentBlock = parentBody.getBlocks().front();
3796
3797 auto nestedLoopIt = parentBlock.begin();
3798 if (nestedLoopIt == parentBlock.end() ||
3799 (&*nestedLoopIt != loop.getOperation()))
3800 return false;
3801
3802 auto termIt = std::next(nestedLoopIt);
3803 if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
3804 return false;
3805
3806 if (std::next(termIt) != parentBlock.end())
3807 return false;
3808
3809 return true;
3810 }();
3811 if (!isPerfectlyNested)
3812 return emitOpError() << "tiled loop nest must be perfectly nested";
3813
3814 if (parentIVs.contains(loop.getTripCount()))
3815 return emitOpError() << "tiled loop nest must be rectangular";
3816
3817 parent = applyee;
3818 }
3819
3820 // TODO: The tile sizes must be computed before the loop, but checking this
3821 // requires dominance analysis. For instance:
3822 //
3823 // %canonloop = omp.new_cli
3824 // omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
3825 // // write to %x
3826 // omp.terminator
3827 // }
3828 // %ts = llvm.load %x
3829 // omp.tile <- (%canonloop) sizes(%ts : i32)
3830
3831 return success();
3832}
3833
3834std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
3835 return getODSOperandIndexAndLength(odsIndex_applyees);
3836}
3837
3838std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
3839 return getODSOperandIndexAndLength(odsIndex_generatees);
3840}
3841
3842//===----------------------------------------------------------------------===//
3843// FuseOp
3844//===----------------------------------------------------------------------===//
3845
3846static void printLoopTransformClis(OpAsmPrinter &p, FuseOp op,
3847 OperandRange generatees,
3848 OperandRange applyees) {
3849 if (!generatees.empty())
3850 p << '(' << llvm::interleaved(generatees) << ')';
3851
3852 if (!applyees.empty())
3853 p << " <- (" << llvm::interleaved(applyees) << ')';
3854}
3855
3856LogicalResult FuseOp::verify() {
3857 if (getApplyees().size() < 2)
3858 return emitOpError() << "must apply to at least two loops";
3859
3860 if (getFirst().has_value() && getCount().has_value()) {
3861 int64_t first = getFirst().value();
3862 int64_t count = getCount().value();
3863 if ((unsigned)(first + count - 1) > getApplyees().size())
3864 return emitOpError() << "the numbers of applyees must be at least first "
3865 "minus one plus count attributes";
3866 if (!getGeneratees().empty() &&
3867 getGeneratees().size() != getApplyees().size() + 1 - count)
3868 return emitOpError() << "the number of generatees must be the number of "
3869 "aplyees plus one minus count";
3870
3871 } else {
3872 if (!getGeneratees().empty() && getGeneratees().size() != 1)
3873 return emitOpError()
3874 << "in a complete fuse the number of generatees must be exactly 1";
3875 }
3876 for (auto &&applyee : getApplyees()) {
3877 auto [create, gen, cons] = decodeCli(applyee);
3878
3879 if (!gen)
3880 return emitOpError() << "applyee CLI has no generator";
3881 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
3882 if (!loop)
3883 return emitOpError()
3884 << "currently only supports omp.canonical_loop as applyee";
3885 }
3886 return success();
3887}
3888std::pair<unsigned, unsigned> FuseOp::getApplyeesODSOperandIndexAndLength() {
3889 return getODSOperandIndexAndLength(odsIndex_applyees);
3890}
3891
3892std::pair<unsigned, unsigned> FuseOp::getGenerateesODSOperandIndexAndLength() {
3893 return getODSOperandIndexAndLength(odsIndex_generatees);
3894}
3895
3896//===----------------------------------------------------------------------===//
3897// Critical construct (2.17.1)
3898//===----------------------------------------------------------------------===//
3899
3900void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
3901 const CriticalDeclareOperands &clauses) {
3902 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
3903}
3904
3905LogicalResult CriticalDeclareOp::verify() {
3906 return verifySynchronizationHint(*this, getHint());
3907}
3908
3909LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3910 if (getNameAttr()) {
3911 SymbolRefAttr symbolRef = getNameAttr();
3912 auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
3913 *this, symbolRef);
3914 if (!decl) {
3915 return emitOpError() << "expected symbol reference " << symbolRef
3916 << " to point to a critical declaration";
3917 }
3918 }
3919
3920 return success();
3921}
3922
3923//===----------------------------------------------------------------------===//
3924// Ordered construct
3925//===----------------------------------------------------------------------===//
3926
3927static LogicalResult verifyOrderedParent(Operation &op) {
3928 bool hasRegion = op.getNumRegions() > 0;
3929 auto loopOp = op.getParentOfType<LoopNestOp>();
3930 if (!loopOp) {
3931 if (hasRegion)
3932 return success();
3933
3934 // TODO: Consider if this needs to be the case only for the standalone
3935 // variant of the ordered construct.
3936 return op.emitOpError() << "must be nested inside of a loop";
3937 }
3938
3939 Operation *wrapper = loopOp->getParentOp();
3940 if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
3941 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
3942 if (!orderedAttr)
3943 return op.emitOpError() << "the enclosing worksharing-loop region must "
3944 "have an ordered clause";
3945
3946 if (hasRegion && orderedAttr.getInt() != 0)
3947 return op.emitOpError() << "the enclosing loop's ordered clause must not "
3948 "have a parameter present";
3949
3950 if (!hasRegion && orderedAttr.getInt() == 0)
3951 return op.emitOpError() << "the enclosing loop's ordered clause must "
3952 "have a parameter present";
3953 } else if (!isa<SimdOp>(wrapper)) {
3954 return op.emitOpError() << "must be nested inside of a worksharing, simd "
3955 "or worksharing simd loop";
3956 }
3957 return success();
3958}
3959
3960void OrderedOp::build(OpBuilder &builder, OperationState &state,
3961 const OrderedOperands &clauses) {
3962 OrderedOp::build(builder, state, clauses.doacrossDependType,
3963 clauses.doacrossNumLoops, clauses.doacrossDependVars);
3964}
3965
3966LogicalResult OrderedOp::verify() {
3967 if (failed(verifyOrderedParent(**this)))
3968 return failure();
3969
3970 auto wrapper = (*this)->getParentOfType<WsloopOp>();
3971 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
3972 return emitOpError() << "number of variables in depend clause does not "
3973 << "match number of iteration variables in the "
3974 << "doacross loop";
3975
3976 return success();
3977}
3978
3979void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
3980 const OrderedRegionOperands &clauses) {
3981 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
3982}
3983
3984LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
3985
3986//===----------------------------------------------------------------------===//
3987// TaskwaitOp
3988//===----------------------------------------------------------------------===//
3989
3990void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
3991 const TaskwaitOperands &clauses) {
3992 // TODO Store clauses in op: dependKinds, dependVars, nowait.
3993 TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
3994 /*depend_vars=*/{}, /*nowait=*/nullptr);
3995}
3996
3997//===----------------------------------------------------------------------===//
3998// Verifier for AtomicReadOp
3999//===----------------------------------------------------------------------===//
4000
4001LogicalResult AtomicReadOp::verify() {
4002 if (verifyCommon().failed())
4003 return mlir::failure();
4004
4005 if (auto mo = getMemoryOrder()) {
4006 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4007 *mo == ClauseMemoryOrderKind::Release) {
4008 return emitError(
4009 "memory-order must not be acq_rel or release for atomic reads");
4010 }
4011 }
4012 return verifySynchronizationHint(*this, getHint());
4013}
4014
4015//===----------------------------------------------------------------------===//
4016// Verifier for AtomicWriteOp
4017//===----------------------------------------------------------------------===//
4018
4019LogicalResult AtomicWriteOp::verify() {
4020 if (verifyCommon().failed())
4021 return mlir::failure();
4022
4023 if (auto mo = getMemoryOrder()) {
4024 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4025 *mo == ClauseMemoryOrderKind::Acquire) {
4026 return emitError(
4027 "memory-order must not be acq_rel or acquire for atomic writes");
4028 }
4029 }
4030 return verifySynchronizationHint(*this, getHint());
4031}
4032
4033//===----------------------------------------------------------------------===//
4034// Verifier for AtomicUpdateOp
4035//===----------------------------------------------------------------------===//
4036
4037LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4038 PatternRewriter &rewriter) {
4039 if (op.isNoOp()) {
4040 rewriter.eraseOp(op);
4041 return success();
4042 }
4043 if (Value writeVal = op.getWriteOpVal()) {
4044 rewriter.replaceOpWithNewOp<AtomicWriteOp>(
4045 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
4046 return success();
4047 }
4048 return failure();
4049}
4050
4051LogicalResult AtomicUpdateOp::verify() {
4052 if (verifyCommon().failed())
4053 return mlir::failure();
4054
4055 if (auto mo = getMemoryOrder()) {
4056 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4057 *mo == ClauseMemoryOrderKind::Acquire) {
4058 return emitError(
4059 "memory-order must not be acq_rel or acquire for atomic updates");
4060 }
4061 }
4062
4063 return verifySynchronizationHint(*this, getHint());
4064}
4065
4066LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4067
4068//===----------------------------------------------------------------------===//
4069// Verifier for AtomicCaptureOp
4070//===----------------------------------------------------------------------===//
4071
4072AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4073 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4074 return op;
4075 return dyn_cast<AtomicReadOp>(getSecondOp());
4076}
4077
4078AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4079 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4080 return op;
4081 return dyn_cast<AtomicWriteOp>(getSecondOp());
4082}
4083
4084AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4085 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4086 return op;
4087 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4088}
4089
4090LogicalResult AtomicCaptureOp::verify() {
4091 return verifySynchronizationHint(*this, getHint());
4092}
4093
4094LogicalResult AtomicCaptureOp::verifyRegions() {
4095 if (verifyRegionsCommon().failed())
4096 return mlir::failure();
4097
4098 if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
4099 return emitOpError(
4100 "operations inside capture region must not have hint clause");
4101
4102 if (getFirstOp()->getAttr("memory_order") ||
4103 getSecondOp()->getAttr("memory_order"))
4104 return emitOpError(
4105 "operations inside capture region must not have memory_order clause");
4106 return success();
4107}
4108
4109//===----------------------------------------------------------------------===//
4110// CancelOp
4111//===----------------------------------------------------------------------===//
4112
4113void CancelOp::build(OpBuilder &builder, OperationState &state,
4114 const CancelOperands &clauses) {
4115 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4116}
4117
// Walk up the chain of parent operations and return the closest ancestor that
// belongs to the same dialect as `thisOp`; nullptr when no ancestor matches.
4119 Operation *parent = thisOp->getParentOp();
4120 while (parent) {
4121 if (parent->getDialect() == thisOp->getDialect())
4122 return parent;
4123 parent = parent->getParentOp();
4124 }
4125 return nullptr;
4126}
4127
4128LogicalResult CancelOp::verify() {
4129 ClauseCancellationConstructType cct = getCancelDirective();
4130 // The next OpenMP operation in the chain of parents
4131 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4132 if (!structuralParent)
4133 return emitOpError() << "Orphaned cancel construct";
4134
4135 if ((cct == ClauseCancellationConstructType::Parallel) &&
4136 !mlir::isa<ParallelOp>(structuralParent)) {
4137 return emitOpError() << "cancel parallel must appear "
4138 << "inside a parallel region";
4139 }
4140 if (cct == ClauseCancellationConstructType::Loop) {
4141 // structural parent will be omp.loop_nest, directly nested inside
4142 // omp.wsloop
4143 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
4144
4145 if (!wsloopOp) {
4146 return emitOpError()
4147 << "cancel loop must appear inside a worksharing-loop region";
4148 }
4149 if (wsloopOp.getNowaitAttr()) {
4150 return emitError() << "A worksharing construct that is canceled "
4151 << "must not have a nowait clause";
4152 }
4153 if (wsloopOp.getOrderedAttr()) {
4154 return emitError() << "A worksharing construct that is canceled "
4155 << "must not have an ordered clause";
4156 }
4157
4158 } else if (cct == ClauseCancellationConstructType::Sections) {
4159 // structural parent will be an omp.section, directly nested inside
4160 // omp.sections
4161 auto sectionsOp =
4162 mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
4163 if (!sectionsOp) {
4164 return emitOpError() << "cancel sections must appear "
4165 << "inside a sections region";
4166 }
4167 if (sectionsOp.getNowait()) {
4168 return emitError() << "A sections construct that is canceled "
4169 << "must not have a nowait clause";
4170 }
4171 }
4172 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4173 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4174 !mlir::isa<omp::TaskloopOp>(structuralParent->getParentOp()))) {
4175 return emitOpError() << "cancel taskgroup must appear "
4176 << "inside a task region";
4177 }
4178 return success();
4179}
4180
4181//===----------------------------------------------------------------------===//
4182// CancellationPointOp
4183//===----------------------------------------------------------------------===//
4184
4185void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
4186 const CancellationPointOperands &clauses) {
4187 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4188}
4189
4190LogicalResult CancellationPointOp::verify() {
4191 ClauseCancellationConstructType cct = getCancelDirective();
4192 // The next OpenMP operation in the chain of parents
4193 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4194 if (!structuralParent)
4195 return emitOpError() << "Orphaned cancellation point";
4196
4197 if ((cct == ClauseCancellationConstructType::Parallel) &&
4198 !mlir::isa<ParallelOp>(structuralParent)) {
4199 return emitOpError() << "cancellation point parallel must appear "
4200 << "inside a parallel region";
4201 }
4202 // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
4203 // find the wsloop
4204 if ((cct == ClauseCancellationConstructType::Loop) &&
4205 !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
4206 return emitOpError() << "cancellation point loop must appear "
4207 << "inside a worksharing-loop region";
4208 }
4209 if ((cct == ClauseCancellationConstructType::Sections) &&
4210 !mlir::isa<omp::SectionOp>(structuralParent)) {
4211 return emitOpError() << "cancellation point sections must appear "
4212 << "inside a sections region";
4213 }
4214 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4215 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4216 !mlir::isa<omp::TaskloopOp>(structuralParent->getParentOp()))) {
4217 return emitOpError() << "cancellation point taskgroup must appear "
4218 << "inside a task region";
4219 }
4220 return success();
4221}
4222
4223//===----------------------------------------------------------------------===//
4224// MapBoundsOp
4225//===----------------------------------------------------------------------===//
4226
4227LogicalResult MapBoundsOp::verify() {
4228 auto extent = getExtent();
4229 auto upperbound = getUpperBound();
4230 if (!extent && !upperbound)
4231 return emitError("expected extent or upperbound.");
4232 return success();
4233}
4234
4235void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4236 TypeRange /*result_types*/, StringAttr symName,
4237 TypeAttr type) {
4238 PrivateClauseOp::build(
4239 odsBuilder, odsState, symName, type,
4240 DataSharingClauseTypeAttr::get(odsBuilder.getContext(),
4241 DataSharingClauseType::Private));
4242}
4243
4244LogicalResult PrivateClauseOp::verifyRegions() {
4245 Type argType = getArgType();
4246 auto verifyTerminator = [&](Operation *terminator,
4247 bool yieldsValue) -> LogicalResult {
4248 if (!terminator->getBlock()->getSuccessors().empty())
4249 return success();
4250
4251 if (!llvm::isa<YieldOp>(terminator))
4252 return mlir::emitError(terminator->getLoc())
4253 << "expected exit block terminator to be an `omp.yield` op.";
4254
4255 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4256 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4257
4258 if (!yieldsValue) {
4259 if (yieldedTypes.empty())
4260 return success();
4261
4262 return mlir::emitError(terminator->getLoc())
4263 << "Did not expect any values to be yielded.";
4264 }
4265
4266 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4267 return success();
4268
4269 auto error = mlir::emitError(yieldOp.getLoc())
4270 << "Invalid yielded value. Expected type: " << argType
4271 << ", got: ";
4272
4273 if (yieldedTypes.empty())
4274 error << "None";
4275 else
4276 error << yieldedTypes;
4277
4278 return error;
4279 };
4280
4281 auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
4282 StringRef regionName,
4283 bool yieldsValue) -> LogicalResult {
4284 assert(!region.empty());
4285
4286 if (region.getNumArguments() != expectedNumArgs)
4287 return mlir::emitError(region.getLoc())
4288 << "`" << regionName << "`: "
4289 << "expected " << expectedNumArgs
4290 << " region arguments, got: " << region.getNumArguments();
4291
4292 for (Block &block : region) {
4293 // MLIR will verify the absence of the terminator for us.
4294 if (!block.mightHaveTerminator())
4295 continue;
4296
4297 if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
4298 return failure();
4299 }
4300
4301 return success();
4302 };
4303
4304 // Ensure all of the region arguments have the same type
4305 for (Region *region : getRegions())
4306 for (Type ty : region->getArgumentTypes())
4307 if (ty != argType)
4308 return emitError() << "Region argument type mismatch: got " << ty
4309 << " expected " << argType << ".";
4310
4311 mlir::Region &initRegion = getInitRegion();
4312 if (!initRegion.empty() &&
4313 failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
4314 /*yieldsValue=*/true)))
4315 return failure();
4316
4317 DataSharingClauseType dsType = getDataSharingType();
4318
4319 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4320 return emitError("`private` clauses do not require a `copy` region.");
4321
4322 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4323 return emitError(
4324 "`firstprivate` clauses require at least a `copy` region.");
4325
4326 if (dsType == DataSharingClauseType::FirstPrivate &&
4327 failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
4328 /*yieldsValue=*/true)))
4329 return failure();
4330
4331 if (!getDeallocRegion().empty() &&
4332 failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
4333 /*yieldsValue=*/false)))
4334 return failure();
4335
4336 return success();
4337}
4338
4339//===----------------------------------------------------------------------===//
4340// Spec 5.2: Masked construct (10.5)
4341//===----------------------------------------------------------------------===//
4342
4343void MaskedOp::build(OpBuilder &builder, OperationState &state,
4344 const MaskedOperands &clauses) {
4345 MaskedOp::build(builder, state, clauses.filteredThreadId);
4346}
4347
4348//===----------------------------------------------------------------------===//
4349// Spec 5.2: Scan construct (5.6)
4350//===----------------------------------------------------------------------===//
4351
4352void ScanOp::build(OpBuilder &builder, OperationState &state,
4353 const ScanOperands &clauses) {
4354 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4355}
4356
4357LogicalResult ScanOp::verify() {
4358 if (hasExclusiveVars() == hasInclusiveVars())
4359 return emitError(
4360 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4361 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4362 if (parentWsLoopOp.getReductionModAttr() &&
4363 parentWsLoopOp.getReductionModAttr().getValue() ==
4364 ReductionModifier::inscan)
4365 return success();
4366 }
4367 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4368 if (parentSimdOp.getReductionModAttr() &&
4369 parentSimdOp.getReductionModAttr().getValue() ==
4370 ReductionModifier::inscan)
4371 return success();
4372 }
4373 return emitError("SCAN directive needs to be enclosed within a parent "
4374 "worksharing loop construct or SIMD construct with INSCAN "
4375 "reduction modifier");
4376}
4377
4378/// Verifies align clause in allocate directive
4379
4380LogicalResult AllocateDirOp::verify() {
4381 std::optional<uint64_t> align = this->getAlign();
4382
4383 if (align.has_value()) {
4384 if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
4385 return emitError() << "ALIGN value : " << align.value()
4386 << " must be power of 2";
4387 }
4388
4389 return success();
4390}
4391
4392//===----------------------------------------------------------------------===//
4393// TargetAllocMemOp
4394//===----------------------------------------------------------------------===//
4395
4396mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
4397 return getInTypeAttr().getValue();
4398}
4399
/// Parses the custom assembly form of `omp.target_alloc_mem`.
///
4400/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype,
4401/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
4402/// attr-dict-without-keyword
///
/// The parsed type is recorded in the `in_type` attribute, the operand segment
/// sizes {1, #typeparams, #shape} in `operandSegmentSizes`, and a 64-bit
/// integer type is added as the result type.
4403static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
4405 auto &builder = parser.getBuilder();
4406 bool hasOperands = false;
4407 std::int32_t typeparamsSize = 0;
4408
4409 // Parse device number as a new operand
4411 mlir::Type deviceType;
4412 if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
4413 return mlir::failure();
4414 if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
4415 return mlir::failure();
4416 if (parser.parseComma())
4417 return mlir::failure();
4418
4419 mlir::Type intype;
4420 if (parser.parseType(intype))
4421 return mlir::failure();
4422 result.addAttribute("in_type", mlir::TypeAttr::get(intype));
4425 if (!parser.parseOptionalLParen()) {
4426 // parse the LEN params of the derived type. (<params> : <types>)
4428 parser.parseColonTypeList(typeVec) || parser.parseRParen())
4429 return mlir::failure();
4430 typeparamsSize = operands.size();
4431 hasOperands = true;
4432 }
4433 std::int32_t shapeSize = 0;
4434 if (!parser.parseOptionalComma()) {
4435 // parse size to scale by, vector of n dimensions of type index
4437 return mlir::failure();
4438 shapeSize = operands.size() - typeparamsSize;
4439 auto idxTy = builder.getIndexType();
// Every shape operand is given the index type.
4440 for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
4441 typeVec.push_back(idxTy);
4442 hasOperands = true;
4443 }
// Type parameters and shape operands are resolved together, after the device
// operand.
4444 if (hasOperands &&
4445 parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
4446 result.operands))
4447 return mlir::failure();
4448
// NOTE(review): getIntegerType(64) presumably never yields a null type, which
// would make the check below unreachable - confirm.
4449 mlir::Type restype = builder.getIntegerType(64);
4450 if (!restype) {
4451 parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
4452 return mlir::failure();
4453 }
4454 llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
4455 result.addAttribute("operandSegmentSizes",
4456 builder.getDenseI32ArrayAttr(segmentSizes));
4457 if (parser.parseOptionalAttrDict(result.attributes) ||
4458 parser.addTypeToList(restype, result.types))
4459 return mlir::failure();
4460 return mlir::success();
4461}
4462
4463mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
4465 return parseTargetAllocMemOp(parser, result);
4466}
4467
4468void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
4469 p << " ";
4471 p << " : ";
4472 p << getDevice().getType();
4473 p << ", ";
4474 p << getInType();
4475 if (!getTypeparams().empty()) {
4476 p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
4477 }
4478 for (auto sh : getShape()) {
4479 p << ", ";
4480 p.printOperand(sh);
4481 }
4482 p.printOptionalAttrDict((*this)->getAttrs(),
4483 {"in_type", "operandSegmentSizes"});
4484}
4485
4486llvm::LogicalResult omp::TargetAllocMemOp::verify() {
4487 mlir::Type outType = getType();
4488 if (!mlir::dyn_cast<IntegerType>(outType))
4489 return emitOpError("must be a integer type");
4490 return mlir::success();
4491}
4492
4493//===----------------------------------------------------------------------===//
4494// WorkdistributeOp
4495//===----------------------------------------------------------------------===//
4496
4497LogicalResult WorkdistributeOp::verify() {
4498 // Check that region exists and is not empty
4499 Region &region = getRegion();
4500 if (region.empty())
4501 return emitOpError("region cannot be empty");
4502 // Verify single entry point.
4503 Block &entryBlock = region.front();
4504 if (entryBlock.empty())
4505 return emitOpError("region must contain a structured block");
4506 // Verify single exit point.
4507 bool hasTerminator = false;
4508 for (Block &block : region) {
4509 if (isa<TerminatorOp>(block.back())) {
4510 if (hasTerminator) {
4511 return emitOpError("region must have exactly one terminator");
4512 }
4513 hasTerminator = true;
4514 }
4515 }
4516 if (!hasTerminator) {
4517 return emitOpError("region must be terminated with omp.terminator");
4518 }
4519 auto walkResult = region.walk([&](Operation *op) -> WalkResult {
4520 // No implicit barrier at end
4521 if (isa<BarrierOp>(op)) {
4522 return emitOpError(
4523 "explicit barriers are not allowed in workdistribute region");
4524 }
4525 // Check for invalid nested constructs
4526 if (isa<ParallelOp>(op)) {
4527 return emitOpError(
4528 "nested parallel constructs not allowed in workdistribute");
4529 }
4530 if (isa<TeamsOp>(op)) {
4531 return emitOpError(
4532 "nested teams constructs not allowed in workdistribute");
4533 }
4534 return WalkResult::advance();
4535 });
4536 if (walkResult.wasInterrupted())
4537 return failure();
4538
4539 Operation *parentOp = (*this)->getParentOp();
4540 if (!llvm::dyn_cast<TeamsOp>(parentOp))
4541 return emitOpError("workdistribute must be nested under teams");
4542 return success();
4543}
4544
4545//===----------------------------------------------------------------------===//
4546// Declare simd [7.7]
4547//===----------------------------------------------------------------------===//
4548
4549LogicalResult DeclareSimdOp::verify() {
4550 // Must be nested inside a function-like op
4551 auto func =
4552 dyn_cast_if_present<mlir::FunctionOpInterface>((*this)->getParentOp());
4553 if (!func)
4554 return emitOpError() << "must be nested inside a function";
4555
4556 if (getInbranch() && getNotinbranch())
4557 return emitOpError("cannot have both 'inbranch' and 'notinbranch'");
4558
4559 return verifyAlignedClause(*this, getAlignments(), getAlignedVars());
4560}
4561
4562void DeclareSimdOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4563 const DeclareSimdOperands &clauses) {
4564 MLIRContext *ctx = odsBuilder.getContext();
4565 DeclareSimdOp::build(odsBuilder, odsState, clauses.alignedVars,
4566 makeArrayAttr(ctx, clauses.alignments), clauses.inbranch,
4567 clauses.linearVars, clauses.linearStepVars,
4568 clauses.linearVarTypes, clauses.notinbranch,
4569 clauses.simdlen, clauses.uniformVars);
4570}
4571
4572//===----------------------------------------------------------------------===//
4573// Parser and printer for Uniform Clause
4574//===----------------------------------------------------------------------===//
4575
4576/// uniform ::= `uniform` `(` uniform-list `)`
4577/// uniform-list := uniform-val (`,` uniform-val)*
4578/// uniform-val := ssa-id `:` type
4579static ParseResult
4582 SmallVectorImpl<Type> &uniformTypes) {
4583 return parser.parseCommaSeparatedList([&]() -> mlir::ParseResult {
4584 if (parser.parseOperand(uniformVars.emplace_back()) ||
4585 parser.parseColonType(uniformTypes.emplace_back()))
4586 return mlir::failure();
4587 return mlir::success();
4588 });
4589}
4590
4591/// Print Uniform Clauses
4593 ValueRange uniformVars, TypeRange uniformTypes) {
4594 for (unsigned i = 0; i < uniformVars.size(); ++i) {
4595 if (i != 0)
4596 p << ", ";
4597 p << uniformVars[i] << " : " << uniformTypes[i];
4598 }
4599}
4600
4601//===----------------------------------------------------------------------===//
4602// Parser and printer for Affinity Clause
4603//===----------------------------------------------------------------------===//
4604
4605static ParseResult parseAffinityClause(
4606 OpAsmParser &parser,
4608 SmallVectorImpl<Type> &affinityTypes) {
4609 return parser.parseCommaSeparatedList([&]() -> ParseResult {
4610 if (parser.parseOperand(affinityVars.emplace_back()) ||
4611 parser.parseColonType(affinityTypes.emplace_back()))
4612 return failure();
4613 return success();
4614 });
4615}
4616
4618 ValueRange affinityVars,
4619 TypeRange affinityTypes) {
4620 for (unsigned i = 0; i < affinityVars.size(); ++i) {
4621 if (i)
4622 p << ", ";
4623 p << affinityVars[i] << " : " << affinityTypes[i];
4624 }
4625}
4626
4627#define GET_ATTRDEF_CLASSES
4628#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
4629
4630#define GET_OP_CLASSES
4631#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
4632
4633#define GET_TYPEDEF_CLASSES
4634#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition EmitC.cpp:1477
static Type getElementType(Type type)
Determine the element type of type.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool > > reductionByref)
Verifies Reduction Clause.
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static void printAffinityClause(OpAsmPrinter &p, Operation *op, ValueRange affinityVars, TypeRange affinityTypes)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static ParseResult parseUniformClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &uniformVars, SmallVectorImpl< Type > &uniformTypes)
uniform ::= uniform ( uniform-list ) uniform-list := uniform-val (, uniform-val)* uniform-val := ssa-...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
operation ::= res = (omp.target_alloc_mem) $device : devicetype, $in_type ( ( $typeparams ) )?...
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static Operation * getParentInSameDialect(Operation *thisOp)
static void printUniformClause(OpAsmPrinter &p, Operation *op, ValueRange uniformVars, TypeRange uniformTypes)
Print Uniform Clauses.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static ParseResult parseAffinityClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &affinityVars, SmallVectorImpl< Type > &affinityTypes)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
return success()
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower, OperandRange numTeamsUpperVars)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 > > &modifiers)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
SuccessorRange getSuccessors()
Definition Block.h:280
BlockArgListType getArguments()
Definition Block.h:97
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerType getI64Type()
Definition Builders.cpp:69
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
Definition Builders.h:209
This class represents an operand of an operation.
Definition Value.h:257
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition Region.h:170
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
Location getLoc()
Return a location for this region.
Definition Region.cpp:31
BlockArgument getArgument(unsigned i)
Definition Region.h:124
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
SideEffects::EffectInstance< Effect > EffectInstance
bool isPerfectlyNested(ArrayRef< AffineForOp > loops)
Returns true if loops is a perfectly nested loop nest, where loops appear in it from outermost to inn...
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1294
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
function_ref< void(Block *, StringRef)> OpAsmSetBlockNameFn
A functor used to set the name of blocks in regions directly nested under an operation.
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.