MLIR 23.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the OpenMP dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Attributes.h"
23#include "mlir/IR/SymbolTable.h"
25
26#include "llvm/ADT/ArrayRef.h"
27#include "llvm/ADT/PostOrderIterator.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/ADT/STLForwardCompat.h"
30#include "llvm/ADT/SmallString.h"
31#include "llvm/ADT/StringExtras.h"
32#include "llvm/ADT/StringRef.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/ADT/bit.h"
35#include "llvm/Support/InterleavedRange.h"
36#include <cstddef>
37#include <iterator>
38#include <optional>
39#include <variant>
40
41#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
42#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
43#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
45
46using namespace mlir;
47using namespace mlir::omp;
48
51 return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
52}
53
56 return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
57}
58
61 return intArray.empty() ? nullptr : DenseI64ArrayAttr::get(ctx, intArray);
62}
63
64namespace {
65struct MemRefPointerLikeModel
66 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
67 MemRefType> {
68 Type getElementType(Type pointer) const {
69 return llvm::cast<MemRefType>(pointer).getElementType();
70 }
71};
72
73struct LLVMPointerPointerLikeModel
74 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
75 LLVM::LLVMPointerType> {
76 Type getElementType(Type pointer) const { return Type(); }
77};
78} // namespace
79
80/// Generate a name of a canonical loop nest of the format
81/// `<prefix>(_r<idx>_s<idx>)*`. Hereby, `_r<idx>` identifies the region
82/// argument index of an operation that has multiple regions, if the operation
83/// has multiple regions.
84/// `_s<idx>` identifies the position of an operation within a region, where
85/// only operations that may potentially contain loops ("container operations"
86/// i.e. have region arguments) are counted. Again, it is omitted if there is
87/// only one such operation in a region. If there are canonical loops nested
88/// inside each other, also may also use the format `_d<num>` where <num> is the
89/// nesting depth of the loop.
90///
91/// The generated name is a best-effort to make canonical loop unique within an
92/// SSA namespace. This also means that regions with IsolatedFromAbove property
93/// do not consider any parents or siblings.
94static std::string generateLoopNestingName(StringRef prefix,
95 CanonicalLoopOp op) {
96 struct Component {
97 /// If true, this component describes a region operand of an operation (the
98 /// operand's owner) If false, this component describes an operation located
99 /// in a parent region
100 bool isRegionArgOfOp;
101 bool skip = false;
102 bool isUnique = false;
103
104 size_t idx;
105 Operation *op;
106 Region *parentRegion;
107 size_t loopDepth;
108
109 Operation *&getOwnerOp() {
110 assert(isRegionArgOfOp && "Must describe a region operand");
111 return op;
112 }
113 size_t &getArgIdx() {
114 assert(isRegionArgOfOp && "Must describe a region operand");
115 return idx;
116 }
117
118 Operation *&getContainerOp() {
119 assert(!isRegionArgOfOp && "Must describe a operation of a region");
120 return op;
121 }
122 size_t &getOpPos() {
123 assert(!isRegionArgOfOp && "Must describe a operation of a region");
124 return idx;
125 }
126 bool isLoopOp() const {
127 assert(!isRegionArgOfOp && "Must describe a operation of a region");
128 return isa<CanonicalLoopOp>(op);
129 }
130 Region *&getParentRegion() {
131 assert(!isRegionArgOfOp && "Must describe a operation of a region");
132 return parentRegion;
133 }
134 size_t &getLoopDepth() {
135 assert(!isRegionArgOfOp && "Must describe a operation of a region");
136 return loopDepth;
137 }
138
139 void skipIf(bool v = true) { skip = skip || v; }
140 };
141
142 // List of ancestors, from inner to outer.
143 // Alternates between
144 // * region argument of an operation
145 // * operation within a region
146 SmallVector<Component> components;
147
148 // Gather a list of parent regions and operations, and the position within
149 // their parent
150 Operation *o = op.getOperation();
151 while (o) {
152 // Operation within a region
153 Region *r = o->getParentRegion();
154 if (!r)
155 break;
156
157 llvm::ReversePostOrderTraversal<Block *> traversal(&r->getBlocks().front());
158 size_t idx = 0;
159 bool found = false;
160 size_t sequentialIdx = -1;
161 bool isOnlyContainerOp = true;
162 for (Block *b : traversal) {
163 for (Operation &op : *b) {
164 if (&op == o && !found) {
165 sequentialIdx = idx;
166 found = true;
167 }
168 if (op.getNumRegions()) {
169 idx += 1;
170 if (idx > 1)
171 isOnlyContainerOp = false;
172 }
173 if (found && !isOnlyContainerOp)
174 break;
175 }
176 }
177
178 Component &containerOpInRegion = components.emplace_back();
179 containerOpInRegion.isRegionArgOfOp = false;
180 containerOpInRegion.isUnique = isOnlyContainerOp;
181 containerOpInRegion.getContainerOp() = o;
182 containerOpInRegion.getOpPos() = sequentialIdx;
183 containerOpInRegion.getParentRegion() = r;
184
185 Operation *parent = r->getParentOp();
186
187 // Region argument of an operation
188 Component &regionArgOfOperation = components.emplace_back();
189 regionArgOfOperation.isRegionArgOfOp = true;
190 regionArgOfOperation.isUnique = true;
191 regionArgOfOperation.getArgIdx() = 0;
192 regionArgOfOperation.getOwnerOp() = parent;
193
194 // The IsolatedFromAbove trait of the parent operation implies that each
195 // individual region argument has its own separate namespace, so no
196 // ambiguity.
197 if (!parent || parent->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
198 break;
199
200 // Component only needed if operation has multiple region operands. Region
201 // arguments may be optional, but we currently do not consider this.
202 if (parent->getRegions().size() > 1) {
203 auto getRegionIndex = [](Operation *o, Region *r) {
204 for (auto [idx, region] : llvm::enumerate(o->getRegions())) {
205 if (&region == r)
206 return idx;
207 }
208 llvm_unreachable("Region not child of its parent operation");
209 };
210 regionArgOfOperation.isUnique = false;
211 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
212 }
213
214 // next parent
215 o = parent;
216 }
217
218 // Determine whether a region-argument component is not needed
219 for (Component &c : components)
220 c.skipIf(c.isRegionArgOfOp && c.isUnique);
221
222 // Find runs of nested loops and determine each loop's depth in the loop nest
223 size_t numSurroundingLoops = 0;
224 for (Component &c : llvm::reverse(components)) {
225 if (c.skip)
226 continue;
227
228 // non-skipped multi-argument operands interrupt the loop nest
229 if (c.isRegionArgOfOp) {
230 numSurroundingLoops = 0;
231 continue;
232 }
233
234 // Multiple loops in a region means each of them is the outermost loop of a
235 // new loop nest
236 if (!c.isUnique)
237 numSurroundingLoops = 0;
238
239 c.getLoopDepth() = numSurroundingLoops;
240
241 // Next loop is surrounded by one more loop
242 if (isa<CanonicalLoopOp>(c.getContainerOp()))
243 numSurroundingLoops += 1;
244 }
245
246 // In loop nests, skip all but the innermost loop that contains the depth
247 // number
248 bool isLoopNest = false;
249 for (Component &c : components) {
250 if (c.skip || c.isRegionArgOfOp)
251 continue;
252
253 if (!isLoopNest && c.getLoopDepth() >= 1) {
254 // Innermost loop of a loop nest of at least two loops
255 isLoopNest = true;
256 } else if (isLoopNest) {
257 // Non-innermost loop of a loop nest
258 c.skipIf(c.isUnique);
259
260 // If there is no surrounding loop left, this must have been the outermost
261 // loop; leave loop-nest mode for the next iteration
262 if (c.getLoopDepth() == 0)
263 isLoopNest = false;
264 }
265 }
266
267 // Skip non-loop unambiguous regions (but they should interrupt loop nests, so
268 // we mark them as skipped only after computing loop nests)
269 for (Component &c : components)
270 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
271 !isa<CanonicalLoopOp>(c.getContainerOp()));
272
273 // Components can be skipped if they are already disambiguated by their parent
274 // (or does not have a parent)
275 bool newRegion = true;
276 for (Component &c : llvm::reverse(components)) {
277 c.skipIf(newRegion && c.isUnique);
278
279 // non-skipped components disambiguate unique children
280 if (!c.skip)
281 newRegion = true;
282
283 // ...except canonical loops that need a suffix for each nest
284 if (!c.isRegionArgOfOp && c.getContainerOp())
285 newRegion = false;
286 }
287
288 // Compile the nesting name string
289 SmallString<64> Name{prefix};
290 llvm::raw_svector_ostream NameOS(Name);
291 for (auto &c : llvm::reverse(components)) {
292 if (c.skip)
293 continue;
294
295 if (c.isRegionArgOfOp)
296 NameOS << "_r" << c.getArgIdx();
297 else if (c.getLoopDepth() >= 1)
298 NameOS << "_d" << c.getLoopDepth();
299 else
300 NameOS << "_s" << c.getOpPos();
301 }
302
303 return NameOS.str().str();
304}
305
306void OpenMPDialect::initialize() {
307 addOperations<
308#define GET_OP_LIST
309#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
310 >();
311 addAttributes<
312#define GET_ATTRDEF_LIST
313#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
314 >();
315 addTypes<
316#define GET_TYPEDEF_LIST
317#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
318 >();
319
320 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
321
322 MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
323 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
324 *getContext());
325
326 // Attach default offload module interface to module op to access
327 // offload functionality through
328 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
329 *getContext());
330
331 // Attach default declare target interfaces to operations which can be marked
332 // as declare target (Global Operations and Functions/Subroutines in dialects
333 // that Fortran (or other languages that lower to MLIR) translates too
334 mlir::LLVM::GlobalOp::attachInterface<
336 *getContext());
337 mlir::LLVM::LLVMFuncOp::attachInterface<
339 *getContext());
340 mlir::func::FuncOp::attachInterface<
342}
343
344//===----------------------------------------------------------------------===//
345// Parser and printer for Allocate Clause
346//===----------------------------------------------------------------------===//
347
348/// Parse an allocate clause with allocators and a list of operands with types.
349///
350/// allocate-operand-list :: = allocate-operand |
351/// allocator-operand `,` allocate-operand-list
352/// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
353/// ssa-id-and-type ::= ssa-id `:` type
354static ParseResult parseAllocateAndAllocator(
355 OpAsmParser &parser,
357 SmallVectorImpl<Type> &allocateTypes,
359 SmallVectorImpl<Type> &allocatorTypes) {
360
361 return parser.parseCommaSeparatedList([&]() {
363 Type type;
364 if (parser.parseOperand(operand) || parser.parseColonType(type))
365 return failure();
366 allocatorVars.push_back(operand);
367 allocatorTypes.push_back(type);
368 if (parser.parseArrow())
369 return failure();
370 if (parser.parseOperand(operand) || parser.parseColonType(type))
371 return failure();
372
373 allocateVars.push_back(operand);
374 allocateTypes.push_back(type);
375 return success();
376 });
377}
378
379/// Print allocate clause
381 OperandRange allocateVars,
382 TypeRange allocateTypes,
383 OperandRange allocatorVars,
384 TypeRange allocatorTypes) {
385 for (unsigned i = 0; i < allocateVars.size(); ++i) {
386 std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
387 p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
388 p << allocateVars[i] << " : " << allocateTypes[i] << separator;
389 }
390}
391
392//===----------------------------------------------------------------------===//
393// Parser and printer for a clause attribute (StringEnumAttr)
394//===----------------------------------------------------------------------===//
395
396template <typename ClauseAttr>
397static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
398 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
399 StringRef enumStr;
400 SMLoc loc = parser.getCurrentLocation();
401 if (parser.parseKeyword(&enumStr))
402 return failure();
403 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
404 attr = ClauseAttr::get(parser.getContext(), *enumValue);
405 return success();
406 }
407 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
408}
409
410template <typename ClauseAttr>
411static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
412 p << stringifyEnum(attr.getValue());
413}
414
415//===----------------------------------------------------------------------===//
416// Parser and printer for Linear Clause
417//===----------------------------------------------------------------------===//
418
419/// linear ::= `linear` `(` linear-list `)`
420/// linear-list := linear-val | linear-val linear-list
421/// linear-val := ssa-id-and-type `=` ssa-id-and-type
422static ParseResult parseLinearClause(
423 OpAsmParser &parser,
425 SmallVectorImpl<Type> &linearTypes,
427 SmallVectorImpl<Type> &linearStepTypes) {
428 return parser.parseCommaSeparatedList([&]() {
430 Type type, stepType;
432 if (parser.parseOperand(var) || parser.parseColonType(type) ||
433 parser.parseEqual() || parser.parseOperand(stepVar) ||
434 parser.parseColonType(stepType))
435 return failure();
436
437 linearVars.push_back(var);
438 linearTypes.push_back(type);
439 linearStepVars.push_back(stepVar);
440 linearStepTypes.push_back(stepType);
441 return success();
442 });
443}
444
445/// Print Linear Clause
447 ValueRange linearVars, TypeRange linearTypes,
448 ValueRange linearStepVars,
449 TypeRange stepVarTypes) {
450 size_t linearVarsSize = linearVars.size();
451 for (unsigned i = 0; i < linearVarsSize; ++i) {
452 std::string separator = i == linearVarsSize - 1 ? "" : ", ";
453 p << linearVars[i] << " : " << linearTypes[i];
454 p << " = " << linearStepVars[i] << " : " << stepVarTypes[i];
455 p << separator;
456 }
457}
458
459//===----------------------------------------------------------------------===//
460// Verifier for Nontemporal Clause
461//===----------------------------------------------------------------------===//
462
463static LogicalResult verifyNontemporalClause(Operation *op,
464 OperandRange nontemporalVars) {
465
466 // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
467 DenseSet<Value> nontemporalItems;
468 for (const auto &it : nontemporalVars)
469 if (!nontemporalItems.insert(it).second)
470 return op->emitOpError() << "nontemporal variable used more than once";
471
472 return success();
473}
474
475//===----------------------------------------------------------------------===//
476// Parser, verifier and printer for Aligned Clause
477//===----------------------------------------------------------------------===//
478static LogicalResult verifyAlignedClause(Operation *op,
479 std::optional<ArrayAttr> alignments,
480 OperandRange alignedVars) {
481 // Check if number of alignment values equals to number of aligned variables
482 if (!alignedVars.empty()) {
483 if (!alignments || alignments->size() != alignedVars.size())
484 return op->emitOpError()
485 << "expected as many alignment values as aligned variables";
486 } else {
487 if (alignments)
488 return op->emitOpError() << "unexpected alignment values attribute";
489 return success();
490 }
491
492 // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
493 DenseSet<Value> alignedItems;
494 for (auto it : alignedVars)
495 if (!alignedItems.insert(it).second)
496 return op->emitOpError() << "aligned variable used more than once";
497
498 if (!alignments)
499 return success();
500
501 // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
502 for (unsigned i = 0; i < (*alignments).size(); ++i) {
503 if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
504 if (intAttr.getValue().sle(0))
505 return op->emitOpError() << "alignment should be greater than 0";
506 } else {
507 return op->emitOpError() << "expected integer alignment";
508 }
509 }
510
511 return success();
512}
513
514/// aligned ::= `aligned` `(` aligned-list `)`
515/// aligned-list := aligned-val | aligned-val aligned-list
516/// aligned-val := ssa-id-and-type `->` alignment
517static ParseResult
520 SmallVectorImpl<Type> &alignedTypes,
521 ArrayAttr &alignmentsAttr) {
522 SmallVector<Attribute> alignmentVec;
523 if (failed(parser.parseCommaSeparatedList([&]() {
524 if (parser.parseOperand(alignedVars.emplace_back()) ||
525 parser.parseColonType(alignedTypes.emplace_back()) ||
526 parser.parseArrow() ||
527 parser.parseAttribute(alignmentVec.emplace_back())) {
528 return failure();
529 }
530 return success();
531 })))
532 return failure();
533 SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
534 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
535 return success();
536}
537
538/// Print Aligned Clause
540 ValueRange alignedVars, TypeRange alignedTypes,
541 std::optional<ArrayAttr> alignments) {
542 for (unsigned i = 0; i < alignedVars.size(); ++i) {
543 if (i != 0)
544 p << ", ";
545 p << alignedVars[i] << " : " << alignedVars[i].getType();
546 p << " -> " << (*alignments)[i];
547 }
548}
549
550//===----------------------------------------------------------------------===//
551// Parser, printer and verifier for Schedule Clause
552//===----------------------------------------------------------------------===//
553
554static ParseResult
556 SmallVectorImpl<SmallString<12>> &modifiers) {
557 if (modifiers.size() > 2)
558 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
559 for (const auto &mod : modifiers) {
560 // Translate the string. If it has no value, then it was not a valid
561 // modifier!
562 auto symbol = symbolizeScheduleModifier(mod);
563 if (!symbol)
564 return parser.emitError(parser.getNameLoc())
565 << " unknown modifier type: " << mod;
566 }
567
568 // If we have one modifier that is "simd", then stick a "none" modiifer in
569 // index 0.
570 if (modifiers.size() == 1) {
571 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
572 modifiers.push_back(modifiers[0]);
573 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
574 }
575 } else if (modifiers.size() == 2) {
576 // If there are two modifier:
577 // First modifier should not be simd, second one should be simd
578 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
579 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
580 return parser.emitError(parser.getNameLoc())
581 << " incorrect modifier order";
582 }
583 return success();
584}
585
586/// schedule ::= `schedule` `(` sched-list `)`
587/// sched-list ::= sched-val | sched-val sched-list |
588/// sched-val `,` sched-modifier
589/// sched-val ::= sched-with-chunk | sched-wo-chunk
590/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
591/// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
592/// sched-wo-chunk ::= `auto` | `runtime`
593/// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
594/// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
595static ParseResult
596parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
597 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
598 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
599 Type &chunkType) {
600 StringRef keyword;
601 if (parser.parseKeyword(&keyword))
602 return failure();
603 std::optional<mlir::omp::ClauseScheduleKind> schedule =
604 symbolizeClauseScheduleKind(keyword);
605 if (!schedule)
606 return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
607
608 scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
609 switch (*schedule) {
610 case ClauseScheduleKind::Static:
611 case ClauseScheduleKind::Dynamic:
612 case ClauseScheduleKind::Guided:
613 if (succeeded(parser.parseOptionalEqual())) {
614 chunkSize = OpAsmParser::UnresolvedOperand{};
615 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
616 return failure();
617 } else {
618 chunkSize = std::nullopt;
619 }
620 break;
621 case ClauseScheduleKind::Auto:
622 case ClauseScheduleKind::Runtime:
623 case ClauseScheduleKind::Distribute:
624 chunkSize = std::nullopt;
625 }
626
627 // If there is a comma, we have one or more modifiers..
629 while (succeeded(parser.parseOptionalComma())) {
630 StringRef mod;
631 if (parser.parseKeyword(&mod))
632 return failure();
633 modifiers.push_back(mod);
634 }
635
636 if (verifyScheduleModifiers(parser, modifiers))
637 return failure();
638
639 if (!modifiers.empty()) {
640 SMLoc loc = parser.getCurrentLocation();
641 if (std::optional<ScheduleModifier> mod =
642 symbolizeScheduleModifier(modifiers[0])) {
643 scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
644 } else {
645 return parser.emitError(loc, "invalid schedule modifier");
646 }
647 // Only SIMD attribute is allowed here!
648 if (modifiers.size() > 1) {
649 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
650 scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
651 }
652 }
653
654 return success();
655}
656
657/// Print schedule clause
659 ClauseScheduleKindAttr scheduleKind,
660 ScheduleModifierAttr scheduleMod,
661 UnitAttr scheduleSimd, Value scheduleChunk,
662 Type scheduleChunkType) {
663 p << stringifyClauseScheduleKind(scheduleKind.getValue());
664 if (scheduleChunk)
665 p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
666 if (scheduleMod)
667 p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
668 if (scheduleSimd)
669 p << ", simd";
670}
671
672//===----------------------------------------------------------------------===//
673// Parser and printer for Order Clause
674//===----------------------------------------------------------------------===//
675
676// order ::= `order` `(` [order-modifier ':'] concurrent `)`
677// order-modifier ::= reproducible | unconstrained
678static ParseResult parseOrderClause(OpAsmParser &parser,
679 ClauseOrderKindAttr &order,
680 OrderModifierAttr &orderMod) {
681 StringRef enumStr;
682 SMLoc loc = parser.getCurrentLocation();
683 if (parser.parseKeyword(&enumStr))
684 return failure();
685 if (std::optional<OrderModifier> enumValue =
686 symbolizeOrderModifier(enumStr)) {
687 orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
688 if (parser.parseOptionalColon())
689 return failure();
690 loc = parser.getCurrentLocation();
691 if (parser.parseKeyword(&enumStr))
692 return failure();
693 }
694 if (std::optional<ClauseOrderKind> enumValue =
695 symbolizeClauseOrderKind(enumStr)) {
696 order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
697 return success();
698 }
699 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
700}
701
703 ClauseOrderKindAttr order,
704 OrderModifierAttr orderMod) {
705 if (orderMod)
706 p << stringifyOrderModifier(orderMod.getValue()) << ":";
707 if (order)
708 p << stringifyClauseOrderKind(order.getValue());
709}
710
711template <typename ClauseTypeAttr, typename ClauseType>
712static ParseResult
713parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
714 std::optional<OpAsmParser::UnresolvedOperand> &operand,
715 Type &operandType,
716 std::optional<ClauseType> (*symbolizeClause)(StringRef),
717 StringRef clauseName) {
718 StringRef enumStr;
719 if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
720 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
721 prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
722 if (parser.parseComma())
723 return failure();
724 } else {
725 return parser.emitError(parser.getCurrentLocation())
726 << "invalid " << clauseName << " modifier : '" << enumStr << "'";
727 ;
728 }
729 }
730
732 if (succeeded(parser.parseOperand(var))) {
733 operand = var;
734 } else {
735 return parser.emitError(parser.getCurrentLocation())
736 << "expected " << clauseName << " operand";
737 }
738
739 if (operand.has_value()) {
740 if (parser.parseColonType(operandType))
741 return failure();
742 }
743
744 return success();
745}
746
747template <typename ClauseTypeAttr, typename ClauseType>
748static void
750 ClauseTypeAttr prescriptiveness, Value operand,
751 mlir::Type operandType,
752 StringRef (*stringifyClauseType)(ClauseType)) {
753
754 if (prescriptiveness)
755 p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
756
757 if (operand)
758 p << operand << ": " << operandType;
759}
760
761//===----------------------------------------------------------------------===//
762// Parser and printer for grainsize Clause
763//===----------------------------------------------------------------------===//
764
765// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
766static ParseResult
767parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
768 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
769 Type &grainsizeType) {
771 parser, grainsizeMod, grainsize, grainsizeType,
772 &symbolizeClauseGrainsizeType, "grainsize");
773}
774
776 ClauseGrainsizeTypeAttr grainsizeMod,
777 Value grainsize, mlir::Type grainsizeType) {
779 p, op, grainsizeMod, grainsize, grainsizeType,
780 &stringifyClauseGrainsizeType);
781}
782
783//===----------------------------------------------------------------------===//
784// Parser and printer for num_tasks Clause
785//===----------------------------------------------------------------------===//
786
787// numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
788static ParseResult
789parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
790 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
791 Type &numTasksType) {
793 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
794 "num_tasks");
795}
796
798 ClauseNumTasksTypeAttr numTasksMod,
799 Value numTasks, mlir::Type numTasksType) {
801 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
802}
803
804//===----------------------------------------------------------------------===//
805// Parsers for operations including clauses that define entry block arguments.
806//===----------------------------------------------------------------------===//
807
808namespace {
809struct MapParseArgs {
810 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
811 SmallVectorImpl<Type> &types;
812 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
813 SmallVectorImpl<Type> &types)
814 : vars(vars), types(types) {}
815};
816struct PrivateParseArgs {
817 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
818 llvm::SmallVectorImpl<Type> &types;
819 ArrayAttr &syms;
820 UnitAttr &needsBarrier;
821 DenseI64ArrayAttr *mapIndices;
822 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
823 SmallVectorImpl<Type> &types, ArrayAttr &syms,
824 UnitAttr &needsBarrier,
825 DenseI64ArrayAttr *mapIndices = nullptr)
826 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
827 mapIndices(mapIndices) {}
828};
829
830struct ReductionParseArgs {
831 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
832 SmallVectorImpl<Type> &types;
833 DenseBoolArrayAttr &byref;
834 ArrayAttr &syms;
835 ReductionModifierAttr *modifier;
836 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
837 SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
838 ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
839 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
840};
841
842struct AllRegionParseArgs {
843 std::optional<MapParseArgs> hasDeviceAddrArgs;
844 std::optional<MapParseArgs> hostEvalArgs;
845 std::optional<ReductionParseArgs> inReductionArgs;
846 std::optional<MapParseArgs> mapArgs;
847 std::optional<PrivateParseArgs> privateArgs;
848 std::optional<ReductionParseArgs> reductionArgs;
849 std::optional<ReductionParseArgs> taskReductionArgs;
850 std::optional<MapParseArgs> useDeviceAddrArgs;
851 std::optional<MapParseArgs> useDevicePtrArgs;
852};
853} // namespace
854
855static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
856 return "private_barrier";
857}
858
859static ParseResult parseClauseWithRegionArgs(
860 OpAsmParser &parser,
864 ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
865 DenseBoolArrayAttr *byref = nullptr,
866 ReductionModifierAttr *modifier = nullptr,
867 UnitAttr *needsBarrier = nullptr) {
869 SmallVector<int64_t> mapIndicesVec;
870 SmallVector<bool> isByRefVec;
871 unsigned regionArgOffset = regionPrivateArgs.size();
872
873 if (parser.parseLParen())
874 return failure();
875
876 if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
877 StringRef enumStr;
878 if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
879 parser.parseComma())
880 return failure();
881 std::optional<ReductionModifier> enumValue =
882 symbolizeReductionModifier(enumStr);
883 if (!enumValue.has_value())
884 return failure();
885 *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
886 if (!*modifier)
887 return failure();
888 }
889
890 if (parser.parseCommaSeparatedList([&]() {
891 if (byref)
892 isByRefVec.push_back(
893 parser.parseOptionalKeyword("byref").succeeded());
894
895 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
896 return failure();
897
898 if (parser.parseOperand(operands.emplace_back()) ||
899 parser.parseArrow() ||
900 parser.parseArgument(regionPrivateArgs.emplace_back()))
901 return failure();
902
903 if (mapIndices) {
904 if (parser.parseOptionalLSquare().succeeded()) {
905 if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
906 parser.parseInteger(mapIndicesVec.emplace_back()) ||
907 parser.parseRSquare())
908 return failure();
909 } else {
910 mapIndicesVec.push_back(-1);
911 }
912 }
913
914 return success();
915 }))
916 return failure();
917
918 if (parser.parseColon())
919 return failure();
920
921 if (parser.parseCommaSeparatedList([&]() {
922 if (parser.parseType(types.emplace_back()))
923 return failure();
924
925 return success();
926 }))
927 return failure();
928
929 if (operands.size() != types.size())
930 return failure();
931
932 if (parser.parseRParen())
933 return failure();
934
935 if (needsBarrier) {
937 .succeeded())
938 *needsBarrier = mlir::UnitAttr::get(parser.getContext());
939 }
940
941 auto *argsBegin = regionPrivateArgs.begin();
942 MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
943 argsBegin + regionArgOffset + types.size());
944 for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
945 prv.type = type;
946 }
947
948 if (symbols) {
949 SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
950 *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
951 }
952
953 if (!mapIndicesVec.empty())
954 *mapIndices =
955 mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
956
957 if (byref)
958 *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
959
960 return success();
961}
962
/// Parses an optional clause introduced by `keyword` whose entries are plain
/// `operand -> argument` pairs followed by a `:`-separated list of their types
/// (no symbols, byref flags, modifier or map indices). MapParseArgs clauses
/// (has_device_addr, host_eval, map_entries, use_device_addr, use_device_ptr)
/// go through this overload.
///
/// Succeeds without consuming input when `keyword` is absent. Fails when the
/// keyword is present but `mapArgs` is empty (the operation does not accept
/// this clause) or when the clause body does not parse. The region arguments
/// of the clause are collected into `entryBlockArgs`.
963static ParseResult parseBlockArgClause(
964 OpAsmParser &parser,
966 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
967 if (succeeded(parser.parseOptionalKeyword(keyword))) {
// The keyword was written but this operation did not enable the clause.
968 if (!mapArgs)
969 return failure();
970
971 if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
972 entryBlockArgs)))
973 return failure();
974 }
975 return success();
976}
977
/// Parses an optional `private`-style clause introduced by `keyword`.
///
/// The clause body carries more than operands, region arguments and types:
/// - symbol references, stored into `privateArgs->syms`;
/// - per-entry `[map_idx=N]` annotations, stored into
///   `privateArgs->mapIndices` only when that pointer is non-null;
/// - an optional trailing needs-barrier keyword, stored into
///   `privateArgs->needsBarrier`.
///
/// Byref flags and a reduction modifier are not accepted. Succeeds without
/// consuming input when `keyword` is absent. Fails when the keyword is present
/// but `privateArgs` is empty, or when the clause body does not parse.
978static ParseResult parseBlockArgClause(
979 OpAsmParser &parser,
981 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
982 if (succeeded(parser.parseOptionalKeyword(keyword))) {
// The keyword was written but this operation did not enable the clause.
983 if (!privateArgs)
984 return failure();
985
986 if (failed(parseClauseWithRegionArgs(
987 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
988 &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
989 /*modifier=*/nullptr, &privateArgs->needsBarrier)))
990 return failure();
991 }
992 return success();
993}
994
/// Parses an optional reduction-like clause introduced by `keyword`. The
/// in_reduction, reduction and task_reduction clauses use this overload.
///
/// Each entry may carry a `byref` prefix and a reduction symbol reference.
/// A reduction modifier, when one is parsed, is written through
/// `reductionArgs->modifier`. Map indices are not accepted.
///
/// Succeeds without consuming input when `keyword` is absent. Fails when the
/// keyword is present but `reductionArgs` is empty, or when the clause body
/// does not parse.
995static ParseResult parseBlockArgClause(
996 OpAsmParser &parser,
998 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
999 if (succeeded(parser.parseOptionalKeyword(keyword))) {
// The keyword was written but this operation did not enable the clause.
1000 if (!reductionArgs)
1001 return failure();
1002 if (failed(parseClauseWithRegionArgs(
1003 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1004 &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
1005 reductionArgs->modifier)))
1006 return failure();
1007 }
1008 return success();
1009}
1010
1011static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
1012 AllRegionParseArgs args) {
1014
1015 if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",
1016 args.hasDeviceAddrArgs)))
1017 return parser.emitError(parser.getCurrentLocation())
1018 << "invalid `has_device_addr` format";
1019
1020 if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
1021 args.hostEvalArgs)))
1022 return parser.emitError(parser.getCurrentLocation())
1023 << "invalid `host_eval` format";
1024
1025 if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
1026 args.inReductionArgs)))
1027 return parser.emitError(parser.getCurrentLocation())
1028 << "invalid `in_reduction` format";
1029
1030 if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
1031 args.mapArgs)))
1032 return parser.emitError(parser.getCurrentLocation())
1033 << "invalid `map_entries` format";
1034
1035 if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
1036 args.privateArgs)))
1037 return parser.emitError(parser.getCurrentLocation())
1038 << "invalid `private` format";
1039
1040 if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
1041 args.reductionArgs)))
1042 return parser.emitError(parser.getCurrentLocation())
1043 << "invalid `reduction` format";
1044
1045 if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
1046 args.taskReductionArgs)))
1047 return parser.emitError(parser.getCurrentLocation())
1048 << "invalid `task_reduction` format";
1049
1050 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
1051 args.useDeviceAddrArgs)))
1052 return parser.emitError(parser.getCurrentLocation())
1053 << "invalid `use_device_addr` format";
1054
1055 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
1056 args.useDevicePtrArgs)))
1057 return parser.emitError(parser.getCurrentLocation())
1058 << "invalid `use_device_addr` format";
1059
1060 return parser.parseRegion(region, entryBlockArgs);
1061}
1062
1063// These parseXyz functions correspond to the custom<Xyz> definitions
1064// in the .td file(s).
// Parses the region of the target operation (custom<TargetOpRegion>).
// The has_device_addr, host_eval, in_reduction, map_entries and private
// clauses may precede the region. Only the private clause receives a
// destination for `[map_idx=N]` annotations (`privateMaps`).
1065static ParseResult parseTargetOpRegion(
1066 OpAsmParser &parser, Region &region,
1068 SmallVectorImpl<Type> &hasDeviceAddrTypes,
1070 SmallVectorImpl<Type> &hostEvalTypes,
1072 SmallVectorImpl<Type> &inReductionTypes,
1073 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1075 SmallVectorImpl<Type> &mapTypes,
1077 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1078 UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
1079 AllRegionParseArgs args;
1080 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1081 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1082 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1083 inReductionByref, inReductionSyms);
1084 args.mapArgs.emplace(mapVars, mapTypes);
1085 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1086 privateNeedsBarrier, &privateMaps);
1087 return parseBlockArgRegion(parser, region, args);
1088}
1089
1091 OpAsmParser &parser, Region &region,
1093 SmallVectorImpl<Type> &inReductionTypes,
1094 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1096 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1097 UnitAttr &privateNeedsBarrier) {
// Only the in_reduction and private clauses are enabled. Any other
// block-argument clause keyword makes parseBlockArgRegion report an invalid
// format for that clause.
1098 AllRegionParseArgs args;
1099 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1100 inReductionByref, inReductionSyms);
1101 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1102 privateNeedsBarrier);
1103 return parseBlockArgRegion(parser, region, args);
1104}
1105
1107 OpAsmParser &parser, Region &region,
1109 SmallVectorImpl<Type> &inReductionTypes,
1110 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1112 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1113 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1115 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1116 ArrayAttr &reductionSyms) {
// Enables the in_reduction, private and reduction clauses. Only the reduction
// clause receives a destination for a reduction modifier (`reductionMod`).
1117 AllRegionParseArgs args;
1118 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1119 inReductionByref, inReductionSyms);
1120 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1121 privateNeedsBarrier);
1122 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1123 reductionSyms, &reductionMod);
1124 return parseBlockArgRegion(parser, region, args);
1125}
1126
// Parses a region whose only enabled block-argument clause is `private`.
1127static ParseResult parsePrivateRegion(
1128 OpAsmParser &parser, Region &region,
1130 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1131 UnitAttr &privateNeedsBarrier) {
1132 AllRegionParseArgs args;
1133 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1134 privateNeedsBarrier);
1135 return parseBlockArgRegion(parser, region, args);
1136}
1137
1139 OpAsmParser &parser, Region &region,
1141 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1142 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1144 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1145 ArrayAttr &reductionSyms) {
// Enables the private and reduction clauses. Any reduction modifier is
// stored into `reductionMod`.
1146 AllRegionParseArgs args;
1147 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1148 privateNeedsBarrier);
1149 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1150 reductionSyms, &reductionMod);
1151 return parseBlockArgRegion(parser, region, args);
1152}
1153
// Parses a region whose only enabled block-argument clause is
// `task_reduction`.
1154static ParseResult parseTaskReductionRegion(
1155 OpAsmParser &parser, Region &region,
1157 SmallVectorImpl<Type> &taskReductionTypes,
1158 DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
1159 AllRegionParseArgs args;
1160 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1161 taskReductionByref, taskReductionSyms);
1162 return parseBlockArgRegion(parser, region, args);
1163}
1164
1166 OpAsmParser &parser, Region &region,
1168 SmallVectorImpl<Type> &useDeviceAddrTypes,
1170 SmallVectorImpl<Type> &useDevicePtrTypes) {
// Enables the use_device_addr and use_device_ptr clauses. parseBlockArgRegion
// reads them in that order.
1171 AllRegionParseArgs args;
1172 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1173 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1174 return parseBlockArgRegion(parser, region, args);
1175}
1176
1177//===----------------------------------------------------------------------===//
1178// Printers for operations including clauses that define entry block arguments.
1179//===----------------------------------------------------------------------===//
1180
1181namespace {
1182struct MapPrintArgs {
1183 ValueRange vars;
1184 TypeRange types;
1185 MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
1186};
1187struct PrivatePrintArgs {
1188 ValueRange vars;
1189 TypeRange types;
1190 ArrayAttr syms;
1191 UnitAttr needsBarrier;
1192 DenseI64ArrayAttr mapIndices;
1193 PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
1194 UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
1195 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1196 mapIndices(mapIndices) {}
1197};
1198struct ReductionPrintArgs {
1199 ValueRange vars;
1200 TypeRange types;
1201 DenseBoolArrayAttr byref;
1202 ArrayAttr syms;
1203 ReductionModifierAttr modifier;
1204 ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
1205 ArrayAttr syms, ReductionModifierAttr mod = nullptr)
1206 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1207};
1208struct AllRegionPrintArgs {
1209 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1210 std::optional<MapPrintArgs> hostEvalArgs;
1211 std::optional<ReductionPrintArgs> inReductionArgs;
1212 std::optional<MapPrintArgs> mapArgs;
1213 std::optional<PrivatePrintArgs> privateArgs;
1214 std::optional<ReductionPrintArgs> reductionArgs;
1215 std::optional<ReductionPrintArgs> taskReductionArgs;
1216 std::optional<MapPrintArgs> useDeviceAddrArgs;
1217 std::optional<MapPrintArgs> useDevicePtrArgs;
1218};
1219} // namespace
1220
1222 OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1223 ValueRange argsSubrange, ValueRange operands, TypeRange types,
1224 ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
1225 DenseBoolArrayAttr byref = nullptr,
1226 ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
// Prints the clause in this shape:
//   `<clauseName>(` [`mod: <modifier>, `] entries ` : ` types `) `
// followed by the needs-barrier keyword when `needsBarrier` is set.
// Each entry is printed as:
//   `[byref ][<symbol> ]<operand> -> <block arg>[ [map_idx=N]]`
// Nothing is printed when the clause introduces no block arguments.
1227 if (argsSubrange.empty())
1228 return;
1229
1230 p << clauseName << "(";
1231
1232 if (modifier)
1233 p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
1234
// Fill in absent per-entry attributes with neutral defaults (no symbol, map
// index -1, not byref) so that every list can be zipped together below.
1235 if (!symbols) {
1236 llvm::SmallVector<Attribute> values(operands.size(), nullptr);
1237 symbols = ArrayAttr::get(ctx, values);
1238 }
1239
1240 if (!mapIndices) {
1241 llvm::SmallVector<int64_t> values(operands.size(), -1);
1242 mapIndices = DenseI64ArrayAttr::get(ctx, values);
1243 }
1244
1245 if (!byref) {
1246 mlir::SmallVector<bool> values(operands.size(), false);
1247 byref = DenseBoolArrayAttr::get(ctx, values);
1248 }
1249
// zip_equal requires all five lists to have the same length.
1250 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1251 mapIndices.asArrayRef(),
1252 byref.asArrayRef()),
1253 p, [&p](auto t) {
1254 auto [op, arg, sym, map, isByRef] = t;
1255 if (isByRef)
1256 p << "byref ";
1257 if (sym)
1258 p << sym << " ";
1259
1260 p << op << " -> " << arg;
1261
// -1 is the "no map index" default filled in above.
1262 if (map != -1)
1263 p << " [map_idx=" << map << "]";
1264 });
1265 p << " : ";
1266 llvm::interleaveComma(types, p);
1267 p << ") ";
1268
1269 if (needsBarrier)
1270 p << getPrivateNeedsBarrierSpelling() << " ";
1271}
1272
1274 StringRef clauseName, ValueRange argsSubrange,
1275 std::optional<MapPrintArgs> mapArgs) {
// Prints a map-like clause (operands and types only). Prints nothing when
// `mapArgs` is empty.
1276 if (mapArgs)
1277 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
1278 mapArgs->types);
1279}
1280
1282 StringRef clauseName, ValueRange argsSubrange,
1283 std::optional<PrivatePrintArgs> privateArgs) {
// Prints a private clause with its symbols, map indices and needs-barrier
// marker. Byref flags and a modifier never apply. Prints nothing when
// `privateArgs` is empty.
1284 if (privateArgs)
1286 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1287 privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1288 /*modifier=*/nullptr, privateArgs->needsBarrier);
1289}
1290
1291static void
1292printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1293 ValueRange argsSubrange,
1294 std::optional<ReductionPrintArgs> reductionArgs) {
1295 if (reductionArgs)
1296 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1297 reductionArgs->vars, reductionArgs->types,
1298 reductionArgs->syms, /*mapIndices=*/nullptr,
1299 reductionArgs->byref, reductionArgs->modifier);
1300}
1301
1303 const AllRegionPrintArgs &args) {
// Prints each enabled clause and then the region. Clauses are emitted in the
// same keyword order that parseBlockArgRegion reads them. Each clause pairs
// its operands with the matching block-argument subrange from
// BlockArgOpenMPOpInterface. The entry block arguments are therefore already
// spelled out by the clauses, so the region is printed without them.
1304 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1305 MLIRContext *ctx = op->getContext();
1306
1307 printBlockArgClause(p, ctx, "has_device_addr",
1308 iface.getHasDeviceAddrBlockArgs(),
1309 args.hasDeviceAddrArgs);
1310 printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
1311 args.hostEvalArgs);
1312 printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
1313 args.inReductionArgs);
1314 printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
1315 args.mapArgs);
1316 printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
1317 args.privateArgs);
1318 printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
1319 args.reductionArgs);
1320 printBlockArgClause(p, ctx, "task_reduction",
1321 iface.getTaskReductionBlockArgs(),
1322 args.taskReductionArgs);
1323 printBlockArgClause(p, ctx, "use_device_addr",
1324 iface.getUseDeviceAddrBlockArgs(),
1325 args.useDeviceAddrArgs);
1326 printBlockArgClause(p, ctx, "use_device_ptr",
1327 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1328
1329 p.printRegion(region, /*printEntryBlockArgs=*/false);
1330}
1331
1332// These printXyz functions correspond to the custom<Xyz> definitions
1333// in the .td file(s).
1335 OpAsmPrinter &p, Operation *op, Region &region,
1336 ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
1337 ValueRange hostEvalVars, TypeRange hostEvalTypes,
1338 ValueRange inReductionVars, TypeRange inReductionTypes,
1339 DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
1340 ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
1341 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1342 DenseI64ArrayAttr privateMaps) {
// Printer counterpart of parseTargetOpRegion. Prints the has_device_addr,
// host_eval, in_reduction, map_entries and private clauses. The private
// clause includes its map indices (`privateMaps`).
1343 AllRegionPrintArgs args;
1344 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1345 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1346 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1347 inReductionByref, inReductionSyms);
1348 args.mapArgs.emplace(mapVars, mapTypes);
1349 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1350 privateNeedsBarrier, privateMaps);
1351 printBlockArgRegion(p, op, region, args);
1352}
1353
1355 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1356 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1357 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1358 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
// Prints the in_reduction and private clauses. Private map indices are never
// printed here (a null attribute is passed).
1359 AllRegionPrintArgs args;
1360 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1361 inReductionByref, inReductionSyms);
1362 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1363 privateNeedsBarrier,
1364 /*mapIndices=*/nullptr);
1365 printBlockArgRegion(p, op, region, args);
1366}
1367
1369 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1370 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1371 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1372 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1373 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1374 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1375 ArrayAttr reductionSyms) {
// Prints the in_reduction, private (without map indices) and reduction
// clauses. The reduction clause includes its optional modifier.
1376 AllRegionPrintArgs args;
1377 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1378 inReductionByref, inReductionSyms);
1379 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1380 privateNeedsBarrier,
1381 /*mapIndices=*/nullptr);
1382 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1383 reductionSyms, reductionMod);
1384 printBlockArgRegion(p, op, region, args);
1385}
1386
1388 ValueRange privateVars, TypeRange privateTypes,
1389 ArrayAttr privateSyms,
1390 UnitAttr privateNeedsBarrier) {
// Prints only the private clause, without map indices.
1391 AllRegionPrintArgs args;
1392 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1393 privateNeedsBarrier,
1394 /*mapIndices=*/nullptr);
1395 printBlockArgRegion(p, op, region, args);
1396}
1397
1399 OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1400 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1401 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1402 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1403 ArrayAttr reductionSyms) {
// Prints the private clause (without map indices) and the reduction clause
// (with its optional modifier).
1404 AllRegionPrintArgs args;
1405 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1406 privateNeedsBarrier,
1407 /*mapIndices=*/nullptr);
1408 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1409 reductionSyms, reductionMod);
1410 printBlockArgRegion(p, op, region, args);
1411}
1412
1414 Region &region,
1415 ValueRange taskReductionVars,
1416 TypeRange taskReductionTypes,
1417 DenseBoolArrayAttr taskReductionByref,
1418 ArrayAttr taskReductionSyms) {
// Prints only the task_reduction clause.
1419 AllRegionPrintArgs args;
1420 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1421 taskReductionByref, taskReductionSyms);
1422 printBlockArgRegion(p, op, region, args);
1423}
1424
1426 Region &region,
1427 ValueRange useDeviceAddrVars,
1428 TypeRange useDeviceAddrTypes,
1429 ValueRange useDevicePtrVars,
1430 TypeRange useDevicePtrTypes) {
// Prints the use_device_addr and use_device_ptr clauses.
1431 AllRegionPrintArgs args;
1432 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1433 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1434 printBlockArgRegion(p, op, region, args);
1435}
1436
/// Parses a comma-separated list of `<prefix> <operand> : <type>` entries.
/// `parsePrefix` consumes the per-entry prefix before each operand.
/// Entries whose type is an omp::IteratedType go to `iteratedVars` /
/// `iteratedTypes`; all others go to `plainVars` / `plainTypes`. Relative
/// order within each pair of lists follows the input.
1437template <typename ParsePrefixFn>
1438static ParseResult parseSplitIteratedList(
1439 OpAsmParser &parser,
1441 SmallVectorImpl<Type> &iteratedTypes,
1443 SmallVectorImpl<Type> &plainTypes, ParsePrefixFn &&parsePrefix) {
1444
1445 return parser.parseCommaSeparatedList([&]() -> ParseResult {
1446 if (failed(parsePrefix()))
1447 return failure();
1448
1450 Type ty;
1451 if (parser.parseOperand(v) || parser.parseColonType(ty))
1452 return failure();
1453
// Route the entry by its type: iterated entries and plain entries are kept
// in separate lists.
1454 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1455 iteratedVars.push_back(v);
1456 iteratedTypes.push_back(ty);
1457 } else {
1458 plainVars.push_back(v);
1459 plainTypes.push_back(ty);
1460 }
1461 return success();
1462 });
1463}
1464
/// Prints operand lists that were split by type, as
/// `<prefix><operand> : <type>` entries separated by ", ".
/// All iterated entries are printed before all plain entries, so an
/// interleaved mixed list is not reproduced; order within each group is kept.
///
/// NOTE(review): both prefix printers share the single template parameter
/// `PrintPrefixFn`, so callers must pass callables of the same type. Two
/// distinct lambdas would fail template deduction. Confirm this is intended.
1465template <typename PrintPrefixFn>
1467 TypeRange iteratedTypes,
1468 ValueRange plainVars, TypeRange plainTypes,
1469 PrintPrefixFn &&printPrefixForPlain,
1470 PrintPrefixFn &&printPrefixForIterated) {
1471
// Emits one entry, inserting ", " before every entry except the first.
1472 bool first = true;
1473 auto emit = [&](Value v, Type t, auto &&printPrefix) {
1474 if (!first)
1475 p << ", ";
1476 printPrefix(v, t);
1477 p << v << " : " << t;
1478 first = false;
1479 };
1480
1481 for (unsigned i = 0; i < iteratedVars.size(); ++i)
1482 emit(iteratedVars[i], iteratedTypes[i], printPrefixForIterated);
1483 for (unsigned i = 0; i < plainVars.size(); ++i)
1484 emit(plainVars[i], plainTypes[i], printPrefixForPlain);
1485}
1486
1487/// Verifies Reduction Clause
///
/// When reduction variables are present:
/// - there must be one symbol per variable;
/// - when byref flags are given, there must be one flag per variable;
/// - each variable may appear only once;
/// - each symbol must resolve to a reduction declaration whose accumulator
///   type, if set, equals the variable's type.
/// When there are no variables, any reduction symbol attribute is rejected.
1488static LogicalResult
1489verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1490 OperandRange reductionVars,
1491 std::optional<ArrayRef<bool>> reductionByref) {
1492 if (!reductionVars.empty()) {
1493 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1494 return op->emitOpError()
1495 << "expected as many reduction symbol references "
1496 "as reduction variables";
// NOTE(review): this check uses emitError, while the other diagnostics in
// this function use emitOpError. Confirm whether the difference is
// intended.
1497 if (reductionByref && reductionByref->size() != reductionVars.size())
1498 return op->emitError() << "expected as many reduction variable by "
1499 "reference attributes as reduction variables";
1500 } else {
1501 if (reductionSyms)
1502 return op->emitOpError() << "unexpected reduction symbol references";
1503 return success();
1504 }
1505
1506 // TODO: The followings should be done in
1507 // SymbolUserOpInterface::verifySymbolUses.
1508 DenseSet<Value> accumulators;
1509 for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1510 Value accum = std::get<0>(args);
1511
1512 if (!accumulators.insert(accum).second)
1513 return op->emitOpError() << "accumulator variable used more than once";
1514
1515 Type varType = accum.getType();
1516 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
// Resolve `symbolRef` to the reduction declaration it names. A null result
// means it does not name one.
1517 auto decl =
1519 if (!decl)
1520 return op->emitOpError() << "expected symbol reference " << symbolRef
1521 << " to point to a reduction declaration";
1522
1523 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1524 return op->emitOpError()
1525 << "expected accumulator (" << varType
1526 << ") to be the same type as reduction declaration ("
1527 << decl.getAccumulatorType() << ")";
1528 }
1529
1530 return success();
1531}
1532
1533//===----------------------------------------------------------------------===//
1534// Parser, printer and verifier for Copyprivate
1535//===----------------------------------------------------------------------===//
1536
1537/// copyprivate-entry-list ::= copyprivate-entry
1538/// | copyprivate-entry-list `,` copyprivate-entry
1539/// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
///
/// Operands and types are appended to `copyprivateVars` and
/// `copyprivateTypes`. The symbol references are collected in entry order and
/// stored as an ArrayAttr in `copyprivateSyms`.
1540static ParseResult parseCopyprivate(
1541 OpAsmParser &parser,
1543 SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1545 if (failed(parser.parseCommaSeparatedList([&]() {
1546 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1547 parser.parseArrow() ||
1548 parser.parseAttribute(symsVec.emplace_back()) ||
1549 parser.parseColonType(copyprivateTypes.emplace_back()))
1550 return failure();
1551 return success();
1552 })))
1553 return failure();
1554 SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1555 copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1556 return success();
1557}
1558
1559/// Print Copyprivate clause
///
/// Prints `var -> symbol : type` entries separated by commas, or nothing when
/// `copyprivateSyms` is absent. llvm::zip stops at the shortest of the three
/// ranges, so mismatched lengths are silently truncated here.
/// verifyCopyprivateVarList checks the var/symbol counts.
1561 OperandRange copyprivateVars,
1562 TypeRange copyprivateTypes,
1563 std::optional<ArrayAttr> copyprivateSyms) {
1564 if (!copyprivateSyms.has_value())
1565 return;
1566 llvm::interleaveComma(
1567 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1568 [&](const auto &args) {
1569 p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1570 << std::get<2>(args);
1571 });
1572}
1573
1574/// Verifies CopyPrivate Clause
///
/// Requires one symbol per copyprivate variable. Each symbol must resolve to
/// an mlir::func::FuncOp or an mlir::LLVM::LLVMFuncOp (the copy function).
/// That function must take exactly two arguments of the same type, and that
/// type must equal the variable's type.
1575static LogicalResult
1577 std::optional<ArrayAttr> copyprivateSyms) {
1578 size_t copyprivateSymsSize =
1579 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1580 if (copyprivateSymsSize != copyprivateVars.size())
1581 return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1582 << copyprivateVars.size()
1583 << ") and functions (= " << copyprivateSymsSize
1584 << "), both must be equal";
1585 if (!copyprivateSyms.has_value())
1586 return success();
1587
1588 for (auto copyprivateVarAndSym :
1589 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1590 auto symbolRef =
1591 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
// Holds whichever function kind the symbol resolves to; it stays empty when
// it resolves to neither.
1592 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1593 funcOp;
1594 if (mlir::func::FuncOp mlirFuncOp =
1596 symbolRef))
1597 funcOp = mlirFuncOp;
1598 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1600 op, symbolRef))
1601 funcOp = llvmFuncOp;
1602
// These helpers dispatch over the stored function kind. They are only
// invoked after the `!funcOp` check below.
1603 auto getNumArguments = [&] {
1604 return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1605 };
1606
1607 auto getArgumentType = [&](unsigned i) {
1608 return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1609 *funcOp);
1610 };
1611
1612 if (!funcOp)
1613 return op->emitOpError() << "expected symbol reference " << symbolRef
1614 << " to point to a copy function";
1615
1616 if (getNumArguments() != 2)
1617 return op->emitOpError()
1618 << "expected copy function " << symbolRef << " to have 2 operands";
1619
1620 Type argTy = getArgumentType(0);
1621 if (argTy != getArgumentType(1))
1622 return op->emitOpError() << "expected copy function " << symbolRef
1623 << " arguments to have the same type";
1624
1625 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1626 if (argTy != varType)
1627 return op->emitOpError()
1628 << "expected copy function arguments' type (" << argTy
1629 << ") to be the same as copyprivate variable's type (" << varType
1630 << ")";
1631 }
1632
1633 return success();
1634}
1635
1636//===----------------------------------------------------------------------===//
1637// Parser, printer and verifier for DependVarList
1638//===----------------------------------------------------------------------===//
1639
1640/// depend-entry-list ::= depend-entry
1641/// | depend-entry-list `,` depend-entry
1642/// depend-entry ::= depend-kind `->` ssa-id `:` type
///
/// `depend-kind` must be a keyword accepted by symbolizeClauseTaskDepend.
/// The kinds are stored, in entry order, as ClauseTaskDependAttr elements of
/// `dependKinds`.
/// NOTE(review): an unrecognized kind makes the parse fail without this
/// function emitting a diagnostic of its own.
1643static ParseResult
1646 SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
1648 if (failed(parser.parseCommaSeparatedList([&]() {
1649 StringRef keyword;
1650 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1651 parser.parseOperand(dependVars.emplace_back()) ||
1652 parser.parseColonType(dependTypes.emplace_back()))
1653 return failure();
1654 if (std::optional<ClauseTaskDepend> keywordDepend =
1655 (symbolizeClauseTaskDepend(keyword)))
1656 kindsVec.emplace_back(
1657 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1658 else
1659 return failure();
1660 return success();
1661 })))
1662 return failure();
1663 SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1664 dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1665 return success();
1666}
1667
1668/// Print Depend clause
///
/// Prints `kind -> var : type` entries separated by ", ".
/// NOTE(review): unlike printCopyprivate, `dependKinds` is dereferenced
/// without a has_value() check, and its size drives indexing into
/// `dependVars` and `dependTypes`. This assumes callers only invoke it when
/// the attribute is present and all three lists have equal length — confirm.
1670 OperandRange dependVars, TypeRange dependTypes,
1671 std::optional<ArrayAttr> dependKinds) {
1672
1673 for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1674 if (i != 0)
1675 p << ", ";
1676 p << stringifyClauseTaskDepend(
1677 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1678 .getValue())
1679 << " -> " << dependVars[i] << " : " << dependTypes[i];
1680 }
1681}
1682
1683/// Verifies Depend clause
1684static LogicalResult verifyDependVarList(Operation *op,
1685 std::optional<ArrayAttr> dependKinds,
1686 OperandRange dependVars) {
1687 if (!dependVars.empty()) {
1688 if (!dependKinds || dependKinds->size() != dependVars.size())
1689 return op->emitOpError() << "expected as many depend values"
1690 " as depend variables";
1691 } else {
1692 if (dependKinds && !dependKinds->empty())
1693 return op->emitOpError() << "unexpected depend values";
1694 return success();
1695 }
1696
1697 return success();
1698}
1699
1700//===----------------------------------------------------------------------===//
1701// Parser, printer and verifier for Synchronization Hint (2.17.12)
1702//===----------------------------------------------------------------------===//
1703
1704/// Parses a Synchronization Hint clause. The value of hint is an integer
1705/// which is a combination of different hints from `omp_sync_hint_t`.
1706///
1707/// hint-clause = `hint` `(` hint-value `)`
1708static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1709 IntegerAttr &hintAttr) {
1710 StringRef hintKeyword;
1711 int64_t hint = 0;
1712 if (succeeded(parser.parseOptionalKeyword("none"))) {
1713 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1714 return success();
1715 }
1716 auto parseKeyword = [&]() -> ParseResult {
1717 if (failed(parser.parseKeyword(&hintKeyword)))
1718 return failure();
1719 if (hintKeyword == "uncontended")
1720 hint |= 1;
1721 else if (hintKeyword == "contended")
1722 hint |= 2;
1723 else if (hintKeyword == "nonspeculative")
1724 hint |= 4;
1725 else if (hintKeyword == "speculative")
1726 hint |= 8;
1727 else
1728 return parser.emitError(parser.getCurrentLocation())
1729 << hintKeyword << " is not a valid hint";
1730 return success();
1731 };
1732 if (parser.parseCommaSeparatedList(parseKeyword))
1733 return failure();
1734 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1735 return success();
1736}
1737
1738/// Prints a Synchronization Hint clause
///
/// Prints `none` for 0. Otherwise prints, separated by commas, the names of
/// whichever of bits 0-3 are set: uncontended, contended, nonspeculative,
/// speculative.
/// NOTE(review): a nonzero value with none of bits 0-3 set prints an empty
/// list. Also, `bitn` takes an `int`, so the int64_t hint is narrowed before
/// testing; this is harmless for bits 0-3.
1740 IntegerAttr hintAttr) {
1741 int64_t hint = hintAttr.getInt();
1742
1743 if (hint == 0) {
1744 p << "none";
1745 return;
1746 }
1747
1748 // Helper function to get n-th bit from the right end of `value`
1749 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1750
1751 bool uncontended = bitn(hint, 0);
1752 bool contended = bitn(hint, 1);
1753 bool nonspeculative = bitn(hint, 2);
1754 bool speculative = bitn(hint, 3);
1755
1757 if (uncontended)
1758 hints.push_back("uncontended");
1759 if (contended)
1760 hints.push_back("contended");
1761 if (nonspeculative)
1762 hints.push_back("nonspeculative");
1763 if (speculative)
1764 hints.push_back("speculative");
1765
1766 llvm::interleaveComma(hints, p);
1767}
1768
1769/// Verifies a synchronization hint clause
1770static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1771
1772 // Helper function to get n-th bit from the right end of `value`
1773 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1774
1775 bool uncontended = bitn(hint, 0);
1776 bool contended = bitn(hint, 1);
1777 bool nonspeculative = bitn(hint, 2);
1778 bool speculative = bitn(hint, 3);
1779
1780 if (uncontended && contended)
1781 return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1782 "omp_sync_hint_contended cannot be combined";
1783 if (nonspeculative && speculative)
1784 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1785 "omp_sync_hint_speculative cannot be combined.";
1786 return success();
1787}
1788
1789//===----------------------------------------------------------------------===//
1790// Parser, printer and verifier for Target
1791//===----------------------------------------------------------------------===//
1792
1793// Helper function to get bitwise AND of `value` and 'flag' then return it as a
1794// boolean
1795static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag) {
1796 return (value & flag) == flag;
1797}
1798
1799/// Parses a map_entries map type from a string format back into its numeric
1800/// value.
1801///
1802/// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
1803/// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
1804static ParseResult parseMapClause(OpAsmParser &parser,
1805 ClauseMapFlagsAttr &mapType) {
1806 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
1807 // This simply verifies the correct keyword is read in, the
1808 // keyword itself is stored inside of the operation
1809 auto parseTypeAndMod = [&]() -> ParseResult {
1810 StringRef mapTypeMod;
1811 if (parser.parseKeyword(&mapTypeMod))
1812 return failure();
1813
1814 if (mapTypeMod == "always")
1815 mapTypeBits |= ClauseMapFlags::always;
1816
1817 if (mapTypeMod == "implicit")
1818 mapTypeBits |= ClauseMapFlags::implicit;
1819
1820 if (mapTypeMod == "ompx_hold")
1821 mapTypeBits |= ClauseMapFlags::ompx_hold;
1822
1823 if (mapTypeMod == "close")
1824 mapTypeBits |= ClauseMapFlags::close;
1825
1826 if (mapTypeMod == "present")
1827 mapTypeBits |= ClauseMapFlags::present;
1828
1829 if (mapTypeMod == "to")
1830 mapTypeBits |= ClauseMapFlags::to;
1831
1832 if (mapTypeMod == "from")
1833 mapTypeBits |= ClauseMapFlags::from;
1834
1835 if (mapTypeMod == "tofrom")
1836 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
1837
1838 if (mapTypeMod == "delete")
1839 mapTypeBits |= ClauseMapFlags::del;
1840
1841 if (mapTypeMod == "storage")
1842 mapTypeBits |= ClauseMapFlags::storage;
1843
1844 if (mapTypeMod == "return_param")
1845 mapTypeBits |= ClauseMapFlags::return_param;
1846
1847 if (mapTypeMod == "private")
1848 mapTypeBits |= ClauseMapFlags::priv;
1849
1850 if (mapTypeMod == "literal")
1851 mapTypeBits |= ClauseMapFlags::literal;
1852
1853 if (mapTypeMod == "attach")
1854 mapTypeBits |= ClauseMapFlags::attach;
1855
1856 if (mapTypeMod == "attach_always")
1857 mapTypeBits |= ClauseMapFlags::attach_always;
1858
1859 if (mapTypeMod == "attach_never")
1860 mapTypeBits |= ClauseMapFlags::attach_never;
1861
1862 if (mapTypeMod == "attach_auto")
1863 mapTypeBits |= ClauseMapFlags::attach_auto;
1864
1865 if (mapTypeMod == "ref_ptr")
1866 mapTypeBits |= ClauseMapFlags::ref_ptr;
1867
1868 if (mapTypeMod == "ref_ptee")
1869 mapTypeBits |= ClauseMapFlags::ref_ptee;
1870
1871 if (mapTypeMod == "ref_ptr_ptee")
1872 mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
1873
1874 if (mapTypeMod == "is_device_ptr")
1875 mapTypeBits |= ClauseMapFlags::is_device_ptr;
1876
1877 return success();
1878 };
1879
1880 if (parser.parseCommaSeparatedList(parseTypeAndMod))
1881 return failure();
1882
1883 mapType =
1884 parser.getBuilder().getAttr<mlir::omp::ClauseMapFlagsAttr>(mapTypeBits);
1885
1886 return success();
1887}
1888
1889/// Prints a map_entries map type from its numeric value out into its string
1890/// format.
1891static void printMapClause(OpAsmPrinter &p, Operation *op,
1892 ClauseMapFlagsAttr mapType) {
1894 ClauseMapFlags mapFlags = mapType.getValue();
1895
1896 // handling of always, close, present placed at the beginning of the string
1897 // to aid readability
1898 if (mapTypeToBool(mapFlags, ClauseMapFlags::always))
1899 mapTypeStrs.push_back("always");
1900 if (mapTypeToBool(mapFlags, ClauseMapFlags::implicit))
1901 mapTypeStrs.push_back("implicit");
1902 if (mapTypeToBool(mapFlags, ClauseMapFlags::ompx_hold))
1903 mapTypeStrs.push_back("ompx_hold");
1904 if (mapTypeToBool(mapFlags, ClauseMapFlags::close))
1905 mapTypeStrs.push_back("close");
1906 if (mapTypeToBool(mapFlags, ClauseMapFlags::present))
1907 mapTypeStrs.push_back("present");
1908
1909 // special handling of to/from/tofrom/delete and release/alloc, release +
1910 // alloc are the abscense of one of the other flags, whereas tofrom requires
1911 // both the to and from flag to be set.
1912 bool to = mapTypeToBool(mapFlags, ClauseMapFlags::to);
1913 bool from = mapTypeToBool(mapFlags, ClauseMapFlags::from);
1914
1915 if (to && from)
1916 mapTypeStrs.push_back("tofrom");
1917 else if (from)
1918 mapTypeStrs.push_back("from");
1919 else if (to)
1920 mapTypeStrs.push_back("to");
1921
1922 if (mapTypeToBool(mapFlags, ClauseMapFlags::del))
1923 mapTypeStrs.push_back("delete");
1924 if (mapTypeToBool(mapFlags, ClauseMapFlags::return_param))
1925 mapTypeStrs.push_back("return_param");
1926 if (mapTypeToBool(mapFlags, ClauseMapFlags::storage))
1927 mapTypeStrs.push_back("storage");
1928 if (mapTypeToBool(mapFlags, ClauseMapFlags::priv))
1929 mapTypeStrs.push_back("private");
1930 if (mapTypeToBool(mapFlags, ClauseMapFlags::literal))
1931 mapTypeStrs.push_back("literal");
1932 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach))
1933 mapTypeStrs.push_back("attach");
1934 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_always))
1935 mapTypeStrs.push_back("attach_always");
1936 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_never))
1937 mapTypeStrs.push_back("attach_never");
1938 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_auto))
1939 mapTypeStrs.push_back("attach_auto");
1940 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr))
1941 mapTypeStrs.push_back("ref_ptr");
1942 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptee))
1943 mapTypeStrs.push_back("ref_ptee");
1944 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
1945 mapTypeStrs.push_back("ref_ptr_ptee");
1946 if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr))
1947 mapTypeStrs.push_back("is_device_ptr");
1948 if (mapFlags == ClauseMapFlags::none)
1949 mapTypeStrs.push_back("none");
1950
1951 for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1952 p << mapTypeStrs[i];
1953 if (i + 1 < mapTypeStrs.size()) {
1954 p << ", ";
1955 }
1956 }
1957}
1958
1959static ParseResult parseMembersIndex(OpAsmParser &parser,
1960 ArrayAttr &membersIdx) {
1961 SmallVector<Attribute> values, memberIdxs;
1962
1963 auto parseIndices = [&]() -> ParseResult {
1964 int64_t value;
1965 if (parser.parseInteger(value))
1966 return failure();
1967 values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1968 APInt(64, value, /*isSigned=*/false)));
1969 return success();
1970 };
1971
1972 do {
1973 if (failed(parser.parseLSquare()))
1974 return failure();
1975
1976 if (parser.parseCommaSeparatedList(parseIndices))
1977 return failure();
1978
1979 if (failed(parser.parseRSquare()))
1980 return failure();
1981
1982 memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1983 values.clear();
1984 } while (succeeded(parser.parseOptionalComma()));
1985
1986 if (!memberIdxs.empty())
1987 membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
1988
1989 return success();
1990}
1991
1992static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1993 ArrayAttr membersIdx) {
1994 if (!membersIdx)
1995 return;
1996
1997 llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
1998 p << "[";
1999 auto memberIdx = cast<ArrayAttr>(v);
2000 llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
2001 p << cast<IntegerAttr>(v2).getInt();
2002 });
2003 p << "]";
2004 });
2005}
2006
                             VariableCaptureKindAttr mapCaptureType) {
  // Print the capture kind with the same spelling that parseCaptureType
  // accepts: "ByRef", "ByCopy", "VLAType" or "This". Any other kind value
  // prints nothing.
  std::string typeCapStr;
  llvm::raw_string_ostream typeCap(typeCapStr);
  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
    typeCap << "ByRef";
  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
    typeCap << "ByCopy";
  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
    typeCap << "VLAType";
  if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
    typeCap << "This";
  p << typeCapStr;
}
2021
2022static ParseResult parseCaptureType(OpAsmParser &parser,
2023 VariableCaptureKindAttr &mapCaptureType) {
2024 StringRef mapCaptureKey;
2025 if (parser.parseKeyword(&mapCaptureKey))
2026 return failure();
2027
2028 if (mapCaptureKey == "This")
2029 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2030 parser.getContext(), mlir::omp::VariableCaptureKind::This);
2031 if (mapCaptureKey == "ByRef")
2032 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2033 parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
2034 if (mapCaptureKey == "ByCopy")
2035 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2036 parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
2037 if (mapCaptureKey == "VLAType")
2038 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2039 parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
2040
2041 return success();
2042}
2043
/// Verifies the map operands \p mapVars of \p op against the map-type
/// restrictions of the construct that \p op represents.
///
/// Every operand must have a defining operation. Operands defined by a
/// MapInfoOp are checked according to the kind of \p op:
///  - TargetDataOp / TargetOp: the `delete` flag is rejected.
///  - TargetEnterDataOp: `from` and `delete` are rejected.
///  - TargetExitDataOp: `to` is rejected.
///  - TargetUpdateOp:
///    - each entry needs exactly one of `to`/`from`;
///    - `delete`, `always`, `close` and `implicit` are rejected;
///    - the same var_ptr may not be moved in both directions across entries.
/// Operands defined by any other op are only accepted when \p op is a
/// DeclareMapperInfoOp.
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
  // `updateToVars` / `updateFromVars` (used for TargetUpdateOp below) record
  // the var_ptr values already mapped `to` / `from` by earlier entries.

  for (auto mapOp : mapVars) {
    if (!mapOp.getDefiningOp())
      return emitError(op->getLoc(), "missing map operation");

    if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
      mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();

      // Data-motion flags.
      bool to = mapTypeToBool(mapTypeBits, ClauseMapFlags::to);
      bool from = mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
      bool del = mapTypeToBool(mapTypeBits, ClauseMapFlags::del);

      // Modifiers that are rejected on TargetUpdateOp.
      bool always = mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
      bool close = mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
      bool implicit = mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);

      if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
        return emitError(op->getLoc(),
                         "to, from, tofrom and alloc map types are permitted");

      if (isa<TargetEnterDataOp>(op) && (from || del))
        return emitError(op->getLoc(), "to and alloc map types are permitted");

      if (isa<TargetExitDataOp>(op) && to)
        return emitError(op->getLoc(),
                         "from, release and delete map types are permitted");

      if (isa<TargetUpdateOp>(op)) {
        if (del) {
          return emitError(op->getLoc(),
                           "at least one of to or from map types must be "
                           "specified, other map types are not permitted");
        }

        if (!to && !from) {
          return emitError(op->getLoc(),
                           "at least one of to or from map types must be "
                           "specified, other map types are not permitted");
        }

        auto updateVar = mapInfoOp.getVarPtr();

        // Reject `tofrom` on one entry, and a variable that was already
        // moved in the opposite direction by an earlier entry.
        if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
            (from && updateToVars.contains(updateVar))) {
          return emitError(
              op->getLoc(),
              "either to or from map types can be specified, not both");
        }

        if (always || close || implicit) {
          return emitError(
              op->getLoc(),
              "present, mapper and iterator map type modifiers are permitted");
        }

        // Exactly one of `to`/`from` is set here; remember the direction.
        to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
      }
    } else if (!isa<DeclareMapperInfoOp>(op)) {
      return emitError(op->getLoc(),
                       "map argument is not a map entry operation");
    }
  }

  return success();
}
2112
2113static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
2114 std::optional<DenseI64ArrayAttr> privateMapIndices =
2115 targetOp.getPrivateMapsAttr();
2116
2117 // None of the private operands are mapped.
2118 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2119 return success();
2120
2121 OperandRange privateVars = targetOp.getPrivateVars();
2122
2123 if (privateMapIndices.value().size() !=
2124 static_cast<int64_t>(privateVars.size()))
2125 return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
2126 "`private_maps` attribute mismatch");
2127
2128 return success();
2129}
2130
2131//===----------------------------------------------------------------------===//
2132// MapInfoOp
2133//===----------------------------------------------------------------------===//
2134
2135static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
2136 StringRef clauseName,
2137 OperandRange vars) {
2138 for (Value var : vars)
2139 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2140 return op->emitOpError()
2141 << "'" << clauseName
2142 << "' arguments must be defined by 'omp.map.info' ops";
2143 return success();
2144}
2145
/// Verifies a MapInfoOp.
///
/// When a mapper id is present, it must pass the symbol check below;
/// otherwise "invalid mapper id" is reported. NOTE(review): the check
/// presumably requires the id to resolve to a mapper declaration; confirm.
/// All `members` operands must be defined by MapInfoOps.
LogicalResult MapInfoOp::verify() {
  if (getMapperId() &&
          *this, getMapperIdAttr())) {
    return emitError("invalid mapper id");
  }

  if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
    return failure();

  return success();
}
2158
2159//===----------------------------------------------------------------------===//
2160// TargetDataOp
2161//===----------------------------------------------------------------------===//
2162
2163void TargetDataOp::build(OpBuilder &builder, OperationState &state,
2164 const TargetDataOperands &clauses) {
2165 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2166 clauses.mapVars, clauses.useDeviceAddrVars,
2167 clauses.useDevicePtrVars);
2168}
2169
2170LogicalResult TargetDataOp::verify() {
2171 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2172 getUseDeviceAddrVars().empty()) {
2173 return ::emitError(this->getLoc(),
2174 "At least one of map, use_device_ptr_vars, or "
2175 "use_device_addr_vars operand must be present");
2176 }
2177
2178 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
2179 getUseDevicePtrVars())))
2180 return failure();
2181
2182 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
2183 getUseDeviceAddrVars())))
2184 return failure();
2185
2186 return verifyMapClause(*this, getMapVars());
2187}
2188
2189//===----------------------------------------------------------------------===//
2190// TargetEnterDataOp
2191//===----------------------------------------------------------------------===//
2192
2193void TargetEnterDataOp::build(
2194 OpBuilder &builder, OperationState &state,
2195 const TargetEnterExitUpdateDataOperands &clauses) {
2196 MLIRContext *ctx = builder.getContext();
2197 TargetEnterDataOp::build(builder, state,
2198 makeArrayAttr(ctx, clauses.dependKinds),
2199 clauses.dependVars, clauses.device, clauses.ifExpr,
2200 clauses.mapVars, clauses.nowait);
2201}
2202
2203LogicalResult TargetEnterDataOp::verify() {
2204 LogicalResult verifyDependVars =
2205 verifyDependVarList(*this, getDependKinds(), getDependVars());
2206 return failed(verifyDependVars) ? verifyDependVars
2207 : verifyMapClause(*this, getMapVars());
2208}
2209
2210//===----------------------------------------------------------------------===//
2211// TargetExitDataOp
2212//===----------------------------------------------------------------------===//
2213
2214void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
2215 const TargetEnterExitUpdateDataOperands &clauses) {
2216 MLIRContext *ctx = builder.getContext();
2217 TargetExitDataOp::build(builder, state,
2218 makeArrayAttr(ctx, clauses.dependKinds),
2219 clauses.dependVars, clauses.device, clauses.ifExpr,
2220 clauses.mapVars, clauses.nowait);
2221}
2222
2223LogicalResult TargetExitDataOp::verify() {
2224 LogicalResult verifyDependVars =
2225 verifyDependVarList(*this, getDependKinds(), getDependVars());
2226 return failed(verifyDependVars) ? verifyDependVars
2227 : verifyMapClause(*this, getMapVars());
2228}
2229
2230//===----------------------------------------------------------------------===//
2231// TargetUpdateOp
2232//===----------------------------------------------------------------------===//
2233
2234void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
2235 const TargetEnterExitUpdateDataOperands &clauses) {
2236 MLIRContext *ctx = builder.getContext();
2237 TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2238 clauses.dependVars, clauses.device, clauses.ifExpr,
2239 clauses.mapVars, clauses.nowait);
2240}
2241
2242LogicalResult TargetUpdateOp::verify() {
2243 LogicalResult verifyDependVars =
2244 verifyDependVarList(*this, getDependKinds(), getDependVars());
2245 return failed(verifyDependVars) ? verifyDependVars
2246 : verifyMapClause(*this, getMapVars());
2247}
2248
2249//===----------------------------------------------------------------------===//
2250// TargetOp
2251//===----------------------------------------------------------------------===//
2252
2253void TargetOp::build(OpBuilder &builder, OperationState &state,
2254 const TargetOperands &clauses) {
2255 MLIRContext *ctx = builder.getContext();
2256 // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
2257 // inReductionByref, inReductionSyms.
2258 TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2259 clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
2260 clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
2261 clauses.hostEvalVars, clauses.ifExpr,
2262 /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
2263 /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
2264 clauses.mapVars, clauses.nowait, clauses.privateVars,
2265 makeArrayAttr(ctx, clauses.privateSyms),
2266 clauses.privateNeedsBarrier, clauses.threadLimitVars,
2267 /*private_maps=*/nullptr);
2268}
2269
2270LogicalResult TargetOp::verify() {
2271 if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
2272 return failure();
2273
2274 if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
2275 getHasDeviceAddrVars())))
2276 return failure();
2277
2278 if (failed(verifyMapClause(*this, getMapVars())))
2279 return failure();
2280
2281 return verifyPrivateVarsMapping(*this);
2282}
2283
2284LogicalResult TargetOp::verifyRegions() {
2285 auto teamsOps = getOps<TeamsOp>();
2286 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2287 return emitError("target containing multiple 'omp.teams' nested ops");
2288
2289 // Check that host_eval values are only used in legal ways.
2290 Operation *capturedOp = getInnermostCapturedOmpOp();
2291 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
2292 for (Value hostEvalArg :
2293 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
2294 for (Operation *user : hostEvalArg.getUsers()) {
2295 if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
2296 // Check if used in num_teams_lower or any of num_teams_upper_vars
2297 if (hostEvalArg == teamsOp.getNumTeamsLower() ||
2298 llvm::is_contained(teamsOp.getNumTeamsUpperVars(), hostEvalArg) ||
2299 llvm::is_contained(teamsOp.getThreadLimitVars(), hostEvalArg))
2300 continue;
2301
2302 return emitOpError() << "host_eval argument only legal as 'num_teams' "
2303 "and 'thread_limit' in 'omp.teams'";
2304 }
2305 if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
2306 if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
2307 parallelOp->isAncestor(capturedOp) &&
2308 llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg))
2309 continue;
2310
2311 return emitOpError()
2312 << "host_eval argument only legal as 'num_threads' in "
2313 "'omp.parallel' when representing target SPMD";
2314 }
2315 if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2316 if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2317 loopNestOp.getOperation() == capturedOp &&
2318 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2319 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2320 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2321 continue;
2322
2323 return emitOpError() << "host_eval argument only legal as loop bounds "
2324 "and steps in 'omp.loop_nest' when trip count "
2325 "must be evaluated in the host";
2326 }
2327
2328 return emitOpError() << "host_eval argument illegal use in '"
2329 << user->getName() << "' operation";
2330 }
2331 }
2332 return success();
2333}
2334
/// Walks the operations nested in \p rootOp in pre-order. Returns the
/// innermost one that can be "captured", or nullptr if none qualifies.
///
/// A candidate must belong to the same dialect as \p rootOp and have at least
/// one region. Other operations are skipped without visiting their nested
/// operations. A candidate must also pass these checks:
///  - every other operation in its parent region satisfies
///    \p siblingAllowedFn;
///  - when \p checkSingleMandatoryExec is set, its block cannot be reached
///    again from any of the block's successors;
///  - when \p checkSingleMandatoryExec is set, its block dominates every
///    reachable block with no successors in its parent region.
/// The first candidate that fails a check stops the whole walk, and an
/// accepted LoopNestOp also stops it. The result is therefore the last
/// candidate accepted before the walk stopped.
static Operation *
findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
                  llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
  assert(rootOp && "expected valid operation");

  Dialect *ompDialect = rootOp->getDialect();
  Operation *capturedOp = nullptr;
  DominanceInfo domInfo;

  // Process in pre-order to check operations from outermost to innermost,
  // ensuring we only enter the region of an operation if it meets the criteria
  // for being captured. We stop the exploration of nested operations as soon as
  // we process a region holding no operations to be captured.
  rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (op == rootOp)
      return WalkResult::advance();

    // Ignore operations of other dialects or omp operations with no regions,
    // because these will only be checked if they are siblings of an omp
    // operation that can potentially be captured.
    bool isOmpDialect = op->getDialect() == ompDialect;
    bool hasRegions = op->getNumRegions() > 0;
    if (!isOmpDialect || !hasRegions)
      return WalkResult::skip();

    // This operation cannot be captured if it can be executed more than once
    // (i.e. its block's successors can reach it) or if it's not guaranteed to
    // be executed before all exits of the region (i.e. it doesn't dominate all
    // blocks with no successors reachable from the entry block).
    if (checkSingleMandatoryExec) {
      Region *parentRegion = op->getParentRegion();
      Block *parentBlock = op->getBlock();

      for (Block *successor : parentBlock->getSuccessors())
        if (successor->isReachable(parentBlock))
          return WalkResult::interrupt();

      for (Block &block : *parentRegion)
        if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
            !domInfo.dominates(parentBlock, &block))
          return WalkResult::interrupt();
    }

    // Don't capture this op if it has a not-allowed sibling, and stop recursing
    // into nested operations.
    for (Operation &sibling : op->getParentRegion()->getOps())
      if (&sibling != op && !siblingAllowedFn(&sibling))
        return WalkResult::interrupt();

    // Don't continue capturing nested operations if we reach an omp.loop_nest.
    // Otherwise, process the contents of this operation.
    capturedOp = op;
    return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
  });

  return capturedOp;
}
2393
/// Returns the innermost OpenMP operation nested in this TargetOp that is
/// executed exactly once on every path through the region and can be
/// captured. Returns nullptr if there is none. See findCapturedOmpOp.
Operation *TargetOp::getInnermostCapturedOmpOp() {
  auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();

  // A sibling of a candidate is allowed when it is one of:
  //  - an OpenMP terminator;
  //  - a non-OpenMP op that does not implement MemoryEffectOpInterface;
  //  - a non-OpenMP op whose reported effects include no write to an
  //    AutomaticAllocationScopeResource.
  // A null sibling is never allowed.
  return findCapturedOmpOp(
      *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
        if (!sibling)
          return false;

        if (ompDialect == sibling->getDialect())
          return sibling->hasTrait<OpTrait::IsTerminator>();

        if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
              effects;
          memOp.getEffects(effects);
          return !llvm::any_of(
              effects, [&](MemoryEffects::EffectInstance &effect) {
                return isa<MemoryEffects::Write>(effect.getEffect()) &&
                       isa<SideEffects::AutomaticAllocationScopeResource>(
                           effect.getResource());
              });
        }
        return true;
      });
}
2421
2422/// Check if we can promote SPMD kernel to No-Loop kernel.
2423static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp,
2424 WsloopOp *wsLoopOp) {
2425 // num_teams clause can break no-loop teams/threads assumption.
2426 if (!teamsOp.getNumTeamsUpperVars().empty())
2427 return false;
2428
2429 // Reduction kernels are slower in no-loop mode.
2430 if (teamsOp.getNumReductionVars())
2431 return false;
2432 if (wsLoopOp->getNumReductionVars())
2433 return false;
2434
2435 // Check if the user allows the promotion of kernels to no-loop mode.
2436 OffloadModuleInterface offloadMod =
2437 capturedOp->getParentOfType<omp::OffloadModuleInterface>();
2438 if (!offloadMod)
2439 return false;
2440 auto ompFlags = offloadMod.getFlags();
2441 if (!ompFlags)
2442 return false;
2443 return ompFlags.getAssumeTeamsOversubscription() &&
2444 ompFlags.getAssumeThreadsOversubscription();
2445}
2446
/// Determines the execution-mode flags of the kernel for the target region
/// that contains \p capturedOp. \p capturedOp must be null or the result of
/// getInnermostCapturedOmpOp() on its parent TargetOp.
///
/// Returns `generic` unless \p capturedOp is a LoopNestOp whose loop wrappers
/// match one of the shapes below. An innermost SimdOp wrapper is ignored when
/// matching, and the outermost construct must be nested directly in the
/// TargetOp.
///  - teams > parallel > distribute > wsloop: `spmd | trip_count`, plus
///    `no_loop` when canPromoteToNoLoop() allows it.
///  - teams > distribute: `generic | trip_count`. `spmd` is added when an op
///    captured inside the loop nest is rooted, directly under the loop nest,
///    at a ParallelOp.
///  - teams > loop: `spmd | trip_count`.
///  - parallel > wsloop: `spmd`.
TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
  // A non-null captured op is only valid if it resides inside of a TargetOp
  // and is the result of calling getInnermostCapturedOmpOp() on it.
  TargetOp targetOp =
      capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
  assert((!capturedOp ||
          (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
         "unexpected captured op");

  // If it's not capturing a loop, it's a default target region.
  if (!isa_and_present<LoopNestOp>(capturedOp))
    return TargetRegionFlags::generic;

  // Get the innermost non-simd loop wrapper. The code below treats
  // `loopWrappers.begin()` as the innermost wrapper.
  cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
  assert(!loopWrappers.empty());

  LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
  if (isa<SimdOp>(innermostWrapper))
    innermostWrapper = std::next(innermostWrapper);

  // Only one or two non-simd wrappers can match a recognized pattern.
  auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
  if (numWrappers != 1 && numWrappers != 2)
    return TargetRegionFlags::generic;

  // Detect target-teams-distribute-parallel-wsloop[-simd].
  if (numWrappers == 2) {
    WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
    if (!wsloopOp)
      return TargetRegionFlags::generic;

    innermostWrapper = std::next(innermostWrapper);
    if (!isa<DistributeOp>(innermostWrapper))
      return TargetRegionFlags::generic;

    Operation *parallelOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return TargetRegionFlags::generic;

    TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp());
    if (!teamsOp)
      return TargetRegionFlags::generic;

    if (teamsOp->getParentOp() == targetOp.getOperation()) {
      TargetRegionFlags result =
          TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
      if (canPromoteToNoLoop(capturedOp, teamsOp, wsloopOp))
        result = result | TargetRegionFlags::no_loop;
      return result;
    }
  }
  // Detect target-teams-distribute[-simd] and target-teams-loop.
  else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
    Operation *teamsOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<TeamsOp>(teamsOp))
      return TargetRegionFlags::generic;

    if (teamsOp->getParentOp() != targetOp.getOperation())
      return TargetRegionFlags::generic;

    if (isa<LoopOp>(innermostWrapper))
      return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;

    // Find single immediately nested captured omp.parallel and add spmd flag
    // (generic-spmd case).
    //
    // TODO: This shouldn't have to be done here, as it is too easy to break.
    // The openmp-opt pass should be updated to be able to promote kernels like
    // this from "Generic" to "Generic-SPMD". However, the use of the
    // `kmpc_distribute_static_loop` family of functions produced by the
    // OMPIRBuilder for these kernels prevents that from working.
    Dialect *ompDialect = targetOp->getDialect();
    Operation *nestedCapture = findCapturedOmpOp(
        capturedOp, /*checkSingleMandatoryExec=*/false,
        [&](Operation *sibling) {
          return sibling && (ompDialect != sibling->getDialect() ||
                             sibling->hasTrait<OpTrait::IsTerminator>());
        });

    TargetRegionFlags result =
        TargetRegionFlags::generic | TargetRegionFlags::trip_count;

    if (!nestedCapture)
      return result;

    // Climb to the ancestor of the nested capture that sits directly inside
    // the captured loop nest.
    while (nestedCapture->getParentOp() != capturedOp)
      nestedCapture = nestedCapture->getParentOp();

    return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
                                          : result;
  }
  // Detect target-parallel-wsloop[-simd].
  else if (isa<WsloopOp>(innermostWrapper)) {
    Operation *parallelOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return TargetRegionFlags::generic;

    if (parallelOp->getParentOp() == targetOp.getOperation())
      return TargetRegionFlags::spmd;
  }

  return TargetRegionFlags::generic;
}
2551
2552//===----------------------------------------------------------------------===//
2553// ParallelOp
2554//===----------------------------------------------------------------------===//
2555
2556void ParallelOp::build(OpBuilder &builder, OperationState &state,
2557 ArrayRef<NamedAttribute> attributes) {
2558 ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
2559 /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2560 /*num_threads_vars=*/ValueRange(),
2561 /*private_vars=*/ValueRange(),
2562 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2563 /*proc_bind_kind=*/nullptr,
2564 /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2565 /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2566 state.addAttributes(attributes);
2567}
2568
2569void ParallelOp::build(OpBuilder &builder, OperationState &state,
2570 const ParallelOperands &clauses) {
2571 MLIRContext *ctx = builder.getContext();
2572 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2573 clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars,
2574 makeArrayAttr(ctx, clauses.privateSyms),
2575 clauses.privateNeedsBarrier, clauses.procBindKind,
2576 clauses.reductionMod, clauses.reductionVars,
2577 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2578 makeArrayAttr(ctx, clauses.reductionSyms));
2579}
2580
/// Verifies the private clause of \p op:
///  - the number of private variables matches the number of privatizer
///    symbols in `private_syms`;
///  - a PrivateClauseOp is found for every symbol;
///  - when the privatizer declares an argument type, that type equals the
///    type of the corresponding variable.
/// Succeeds immediately when there are neither variables nor symbols.
template <typename OpType>
static LogicalResult verifyPrivateVarList(OpType &op) {
  auto privateVars = op.getPrivateVars();
  auto privateSyms = op.getPrivateSymsAttr();

  if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
    return success();

  // A null symbol attribute counts as zero symbols.
  auto numPrivateVars = privateVars.size();
  auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();

  if (numPrivateVars != numPrivateSyms)
    return op.emitError() << "inconsistent number of private variables and "
                             "privatizer op symbols, private vars: "
                          << numPrivateVars
                          << " vs. privatizer op symbols: " << numPrivateSyms;

  // Pair each private variable with its privatizer symbol, in order.
  for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
    Type varType = std::get<0>(privateVarInfo).getType();
    SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
    PrivateClauseOp privatizerOp =

    if (privatizerOp == nullptr)
      return op.emitError() << "failed to lookup privatizer op with symbol: '"
                            << privateSym << "'";

    Type privatizerType = privatizerOp.getArgType();

    // A null privatizer type skips the type check.
    if (privatizerType && (varType != privatizerType))
      return op.emitError()
             << "type mismatch between a "
             << (privatizerOp.getDataSharingType() ==
                         DataSharingClauseType::Private
                     ? "private"
                     : "firstprivate")
             << " variable and its privatizer op, var type: " << varType
             << " vs. privatizer op type: " << privatizerType;
  }

  return success();
}
2623
2624LogicalResult ParallelOp::verify() {
2625 if (getAllocateVars().size() != getAllocatorVars().size())
2626 return emitError(
2627 "expected equal sizes for allocate and allocator variables");
2628
2629 if (failed(verifyPrivateVarList(*this)))
2630 return failure();
2631
2632 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2633 getReductionByref());
2634}
2635
2636LogicalResult ParallelOp::verifyRegions() {
2637 auto distChildOps = getOps<DistributeOp>();
2638 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2639 if (numDistChildOps > 1)
2640 return emitError()
2641 << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2642
2643 if (numDistChildOps == 1) {
2644 if (!isComposite())
2645 return emitError()
2646 << "'omp.composite' attribute missing from composite operation";
2647
2648 auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2649 Operation &distributeOp = **distChildOps.begin();
2650 for (Operation &childOp : getOps()) {
2651 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2652 continue;
2653
2654 if (!childOp.hasTrait<OpTrait::IsTerminator>())
2655 return emitError() << "unexpected OpenMP operation inside of composite "
2656 "'omp.parallel': "
2657 << childOp.getName();
2658 }
2659 } else if (isComposite()) {
2660 return emitError()
2661 << "'omp.composite' attribute present in non-composite operation";
2662 }
2663 return success();
2664}
2665
2666//===----------------------------------------------------------------------===//
2667// TeamsOp
2668//===----------------------------------------------------------------------===//
2669
  // Walk every ancestor of `op`. Any ancestor belonging to the OpenMP dialect
  // yields false; reaching the top of the parent chain yields true.
  while ((op = op->getParentOp()))
    if (isa<OpenMPDialect>(op->getDialect()))
      return false;
  return true;
}
2676
2677void TeamsOp::build(OpBuilder &builder, OperationState &state,
2678 const TeamsOperands &clauses) {
2679 MLIRContext *ctx = builder.getContext();
2680 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2681 TeamsOp::build(
2682 builder, state, clauses.allocateVars, clauses.allocatorVars,
2683 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpperVars,
2684 /*private_vars=*/{}, /*private_syms=*/nullptr,
2685 /*private_needs_barrier=*/nullptr, clauses.reductionMod,
2686 clauses.reductionVars,
2687 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2688 makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars);
2689}
2690
2691// Verify num_teams clause
2692static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower,
2693 OperandRange numTeamsUpperVars) {
2694 // If lower is specified, upper must have exactly one value
2695 if (numTeamsLower) {
2696 if (numTeamsUpperVars.size() != 1)
2697 return op->emitError(
2698 "expected exactly one num_teams upper bound when lower bound is "
2699 "specified");
2700 if (numTeamsLower.getType() != numTeamsUpperVars[0].getType())
2701 return op->emitError(
2702 "expected num_teams upper bound and lower bound to be "
2703 "the same type");
2704 }
2705
2706 return success();
2707}
2708
2709LogicalResult TeamsOp::verify() {
2710 // Check parent region
2711 // TODO If nested inside of a target region, also check that it does not
2712 // contain any statements, declarations or directives other than this
2713 // omp.teams construct. The issue is how to support the initialization of
2714 // this operation's own arguments (allow SSA values across omp.target?).
2715 Operation *op = getOperation();
2716 if (!isa<TargetOp>(op->getParentOp()) &&
2718 return emitError("expected to be nested inside of omp.target or not nested "
2719 "in any OpenMP dialect operations");
2720
2721 // Check for num_teams clause restrictions
2722 if (failed(verifyNumTeamsClause(op, this->getNumTeamsLower(),
2723 this->getNumTeamsUpperVars())))
2724 return failure();
2725
2726 // Check for allocate clause restrictions
2727 if (getAllocateVars().size() != getAllocatorVars().size())
2728 return emitError(
2729 "expected equal sizes for allocate and allocator variables");
2730
2731 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2732 getReductionByref());
2733}
2734
2735//===----------------------------------------------------------------------===//
2736// SectionOp
2737//===----------------------------------------------------------------------===//
2738
2739OperandRange SectionOp::getPrivateVars() {
2740 return getParentOp().getPrivateVars();
2741}
2742
2743OperandRange SectionOp::getReductionVars() {
2744 return getParentOp().getReductionVars();
2745}
2746
2747//===----------------------------------------------------------------------===//
2748// SectionsOp
2749//===----------------------------------------------------------------------===//
2750
2751void SectionsOp::build(OpBuilder &builder, OperationState &state,
2752 const SectionsOperands &clauses) {
2753 MLIRContext *ctx = builder.getContext();
2754 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2755 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2756 clauses.nowait, /*private_vars=*/{},
2757 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2758 clauses.reductionMod, clauses.reductionVars,
2759 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2760 makeArrayAttr(ctx, clauses.reductionSyms));
2761}
2762
2763LogicalResult SectionsOp::verify() {
2764 if (getAllocateVars().size() != getAllocatorVars().size())
2765 return emitError(
2766 "expected equal sizes for allocate and allocator variables");
2767
2768 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2769 getReductionByref());
2770}
2771
2772LogicalResult SectionsOp::verifyRegions() {
2773 for (auto &inst : *getRegion().begin()) {
2774 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2775 return emitOpError()
2776 << "expected omp.section op or terminator op inside region";
2777 }
2778 }
2779
2780 return success();
2781}
2782
2783//===----------------------------------------------------------------------===//
2784// SingleOp
2785//===----------------------------------------------------------------------===//
2786
2787void SingleOp::build(OpBuilder &builder, OperationState &state,
2788 const SingleOperands &clauses) {
2789 MLIRContext *ctx = builder.getContext();
2790 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2791 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2792 clauses.copyprivateVars,
2793 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2794 /*private_vars=*/{}, /*private_syms=*/nullptr,
2795 /*private_needs_barrier=*/nullptr);
2796}
2797
2798LogicalResult SingleOp::verify() {
2799 // Check for allocate clause restrictions
2800 if (getAllocateVars().size() != getAllocatorVars().size())
2801 return emitError(
2802 "expected equal sizes for allocate and allocator variables");
2803
2804 return verifyCopyprivateVarList(*this, getCopyprivateVars(),
2805 getCopyprivateSyms());
2806}
2807
2808//===----------------------------------------------------------------------===//
2809// WorkshareOp
2810//===----------------------------------------------------------------------===//
2811
2812void WorkshareOp::build(OpBuilder &builder, OperationState &state,
2813 const WorkshareOperands &clauses) {
2814 WorkshareOp::build(builder, state, clauses.nowait);
2815}
2816
2817//===----------------------------------------------------------------------===//
2818// WorkshareLoopWrapperOp
2819//===----------------------------------------------------------------------===//
2820
2821LogicalResult WorkshareLoopWrapperOp::verify() {
2822 if (!(*this)->getParentOfType<WorkshareOp>())
2823 return emitOpError() << "must be nested in an omp.workshare";
2824 return success();
2825}
2826
2827LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2828 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2829 getNestedWrapper())
2830 return emitOpError() << "expected to be a standalone loop wrapper";
2831
2832 return success();
2833}
2834
2835//===----------------------------------------------------------------------===//
2836// LoopWrapperInterface
2837//===----------------------------------------------------------------------===//
2838
2839LogicalResult LoopWrapperInterface::verifyImpl() {
2840 Operation *op = this->getOperation();
2841 if (!op->hasTrait<OpTrait::NoTerminator>() ||
2843 return emitOpError() << "loop wrapper must also have the `NoTerminator` "
2844 "and `SingleBlock` traits";
2845
2846 if (op->getNumRegions() != 1)
2847 return emitOpError() << "loop wrapper does not contain exactly one region";
2848
2849 Region &region = op->getRegion(0);
2850 if (range_size(region.getOps()) != 1)
2851 return emitOpError()
2852 << "loop wrapper does not contain exactly one nested op";
2853
2854 Operation &firstOp = *region.op_begin();
2855 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2856 return emitOpError() << "nested in loop wrapper is not another loop "
2857 "wrapper or `omp.loop_nest`";
2858
2859 return success();
2860}
2861
2862//===----------------------------------------------------------------------===//
2863// LoopOp
2864//===----------------------------------------------------------------------===//
2865
2866void LoopOp::build(OpBuilder &builder, OperationState &state,
2867 const LoopOperands &clauses) {
2868 MLIRContext *ctx = builder.getContext();
2869
2870 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2871 makeArrayAttr(ctx, clauses.privateSyms),
2872 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2873 clauses.reductionMod, clauses.reductionVars,
2874 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2875 makeArrayAttr(ctx, clauses.reductionSyms));
2876}
2877
2878LogicalResult LoopOp::verify() {
2879 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2880 getReductionByref());
2881}
2882
2883LogicalResult LoopOp::verifyRegions() {
2884 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2885 getNestedWrapper())
2886 return emitOpError() << "expected to be a standalone loop wrapper";
2887
2888 return success();
2889}
2890
2891//===----------------------------------------------------------------------===//
2892// WsloopOp
2893//===----------------------------------------------------------------------===//
2894
2895void WsloopOp::build(OpBuilder &builder, OperationState &state,
2896 ArrayRef<NamedAttribute> attributes) {
2897 build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2898 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
2899 /*linear_var_types*/ nullptr,
2900 /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
2901 /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
2902 /*private_needs_barrier=*/false,
2903 /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
2904 /*reduction_byref=*/nullptr,
2905 /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
2906 /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
2907 /*schedule_simd=*/false);
2908 state.addAttributes(attributes);
2909}
2910
2911void WsloopOp::build(OpBuilder &builder, OperationState &state,
2912 const WsloopOperands &clauses) {
2913 MLIRContext *ctx = builder.getContext();
2914 // TODO: Store clauses in op: allocateVars, allocatorVars
2915 WsloopOp::build(
2916 builder, state,
2917 /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
2918 clauses.linearStepVars, clauses.linearVarTypes, clauses.nowait,
2919 clauses.order, clauses.orderMod, clauses.ordered, clauses.privateVars,
2920 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2921 clauses.reductionMod, clauses.reductionVars,
2922 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2923 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
2924 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
2925}
2926
2927LogicalResult WsloopOp::verify() {
2928 if (getLinearVars().size() &&
2929 getLinearVarTypes().value().size() != getLinearVars().size())
2930 return emitError() << "Ill-formed type attributes for linear variables";
2931 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2932 getReductionByref());
2933}
2934
2935LogicalResult WsloopOp::verifyRegions() {
2936 bool isCompositeChildLeaf =
2937 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2938
2939 if (LoopWrapperInterface nested = getNestedWrapper()) {
2940 if (!isComposite())
2941 return emitError()
2942 << "'omp.composite' attribute missing from composite wrapper";
2943
2944 // Check for the allowed leaf constructs that may appear in a composite
2945 // construct directly after DO/FOR.
2946 if (!isa<SimdOp>(nested))
2947 return emitError() << "only supported nested wrapper is 'omp.simd'";
2948
2949 } else if (isComposite() && !isCompositeChildLeaf) {
2950 return emitError()
2951 << "'omp.composite' attribute present in non-composite wrapper";
2952 } else if (!isComposite() && isCompositeChildLeaf) {
2953 return emitError()
2954 << "'omp.composite' attribute missing from composite wrapper";
2955 }
2956
2957 return success();
2958}
2959
2960//===----------------------------------------------------------------------===//
2961// Simd construct [2.9.3.1]
2962//===----------------------------------------------------------------------===//
2963
2964void SimdOp::build(OpBuilder &builder, OperationState &state,
2965 const SimdOperands &clauses) {
2966 MLIRContext *ctx = builder.getContext();
2967 SimdOp::build(
2968 builder, state, clauses.alignedVars,
2969 makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
2970 clauses.linearVars, clauses.linearStepVars, clauses.linearVarTypes,
2971 clauses.nontemporalVars, clauses.order, clauses.orderMod,
2972 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
2973 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
2974 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2975 makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
2976 clauses.simdlen);
2977}
2978
2979LogicalResult SimdOp::verify() {
2980 if (getSimdlen().has_value() && getSafelen().has_value() &&
2981 getSimdlen().value() > getSafelen().value())
2982 return emitOpError()
2983 << "simdlen clause and safelen clause are both present, but the "
2984 "simdlen value is not less than or equal to safelen value";
2985
2986 if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
2987 return failure();
2988
2989 if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
2990 return failure();
2991
2992 bool isCompositeChildLeaf =
2993 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2994
2995 if (!isComposite() && isCompositeChildLeaf)
2996 return emitError()
2997 << "'omp.composite' attribute missing from composite wrapper";
2998
2999 if (isComposite() && !isCompositeChildLeaf)
3000 return emitError()
3001 << "'omp.composite' attribute present in non-composite wrapper";
3002
3003 // Firstprivate is not allowed for SIMD in the standard. Check that none of
3004 // the private decls are for firstprivate.
3005 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
3006 if (privateSyms) {
3007 for (const Attribute &sym : *privateSyms) {
3008 auto symRef = cast<SymbolRefAttr>(sym);
3009 omp::PrivateClauseOp privatizer =
3011 getOperation(), symRef);
3012 if (!privatizer)
3013 return emitError() << "Cannot find privatizer '" << symRef << "'";
3014 if (privatizer.getDataSharingType() ==
3015 DataSharingClauseType::FirstPrivate)
3016 return emitError() << "FIRSTPRIVATE cannot be used with SIMD";
3017 }
3018 }
3019
3020 if (getLinearVars().size() &&
3021 getLinearVarTypes().value().size() != getLinearVars().size())
3022 return emitError() << "Ill-formed type attributes for linear variables";
3023 return success();
3024}
3025
3026LogicalResult SimdOp::verifyRegions() {
3027 if (getNestedWrapper())
3028 return emitOpError() << "must wrap an 'omp.loop_nest' directly";
3029
3030 return success();
3031}
3032
3033//===----------------------------------------------------------------------===//
3034// Distribute construct [2.9.4.1]
3035//===----------------------------------------------------------------------===//
3036
3037void DistributeOp::build(OpBuilder &builder, OperationState &state,
3038 const DistributeOperands &clauses) {
3039 DistributeOp::build(builder, state, clauses.allocateVars,
3040 clauses.allocatorVars, clauses.distScheduleStatic,
3041 clauses.distScheduleChunkSize, clauses.order,
3042 clauses.orderMod, clauses.privateVars,
3043 makeArrayAttr(builder.getContext(), clauses.privateSyms),
3044 clauses.privateNeedsBarrier);
3045}
3046
3047LogicalResult DistributeOp::verify() {
3048 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
3049 return emitOpError() << "chunk size set without "
3050 "dist_schedule_static being present";
3051
3052 if (getAllocateVars().size() != getAllocatorVars().size())
3053 return emitError(
3054 "expected equal sizes for allocate and allocator variables");
3055
3056 return success();
3057}
3058
3059LogicalResult DistributeOp::verifyRegions() {
3060 if (LoopWrapperInterface nested = getNestedWrapper()) {
3061 if (!isComposite())
3062 return emitError()
3063 << "'omp.composite' attribute missing from composite wrapper";
3064 // Check for the allowed leaf constructs that may appear in a composite
3065 // construct directly after DISTRIBUTE.
3066 if (isa<WsloopOp>(nested)) {
3067 Operation *parentOp = (*this)->getParentOp();
3068 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
3069 !cast<ComposableOpInterface>(parentOp).isComposite()) {
3070 return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
3071 "when a composite 'omp.parallel' is the direct "
3072 "parent";
3073 }
3074 } else if (!isa<SimdOp>(nested))
3075 return emitError() << "only supported nested wrappers are 'omp.simd' and "
3076 "'omp.wsloop'";
3077 } else if (isComposite()) {
3078 return emitError()
3079 << "'omp.composite' attribute present in non-composite wrapper";
3080 }
3081
3082 return success();
3083}
3084
3085//===----------------------------------------------------------------------===//
3086// DeclareMapperOp / DeclareMapperInfoOp
3087//===----------------------------------------------------------------------===//
3088
3089LogicalResult DeclareMapperInfoOp::verify() {
3090 return verifyMapClause(*this, getMapVars());
3091}
3092
3093LogicalResult DeclareMapperOp::verifyRegions() {
3094 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3095 getRegion().getBlocks().front().getTerminator()))
3096 return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
3097
3098 return success();
3099}
3100
3101//===----------------------------------------------------------------------===//
3102// DeclareReductionOp
3103//===----------------------------------------------------------------------===//
3104
3105LogicalResult DeclareReductionOp::verifyRegions() {
3106 if (!getAllocRegion().empty()) {
3107 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3108 if (yieldOp.getResults().size() != 1 ||
3109 yieldOp.getResults().getTypes()[0] != getType())
3110 return emitOpError() << "expects alloc region to yield a value "
3111 "of the reduction type";
3112 }
3113 }
3114
3115 if (getInitializerRegion().empty())
3116 return emitOpError() << "expects non-empty initializer region";
3117 Block &initializerEntryBlock = getInitializerRegion().front();
3118
3119 if (initializerEntryBlock.getNumArguments() == 1) {
3120 if (!getAllocRegion().empty())
3121 return emitOpError() << "expects two arguments to the initializer region "
3122 "when an allocation region is used";
3123 } else if (initializerEntryBlock.getNumArguments() == 2) {
3124 if (getAllocRegion().empty())
3125 return emitOpError() << "expects one argument to the initializer region "
3126 "when no allocation region is used";
3127 } else {
3128 return emitOpError()
3129 << "expects one or two arguments to the initializer region";
3130 }
3131
3132 for (mlir::Value arg : initializerEntryBlock.getArguments())
3133 if (arg.getType() != getType())
3134 return emitOpError() << "expects initializer region argument to match "
3135 "the reduction type";
3136
3137 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3138 if (yieldOp.getResults().size() != 1 ||
3139 yieldOp.getResults().getTypes()[0] != getType())
3140 return emitOpError() << "expects initializer region to yield a value "
3141 "of the reduction type";
3142 }
3143
3144 if (getReductionRegion().empty())
3145 return emitOpError() << "expects non-empty reduction region";
3146 Block &reductionEntryBlock = getReductionRegion().front();
3147 if (reductionEntryBlock.getNumArguments() != 2 ||
3148 reductionEntryBlock.getArgumentTypes()[0] !=
3149 reductionEntryBlock.getArgumentTypes()[1] ||
3150 reductionEntryBlock.getArgumentTypes()[0] != getType())
3151 return emitOpError() << "expects reduction region with two arguments of "
3152 "the reduction type";
3153 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3154 if (yieldOp.getResults().size() != 1 ||
3155 yieldOp.getResults().getTypes()[0] != getType())
3156 return emitOpError() << "expects reduction region to yield a value "
3157 "of the reduction type";
3158 }
3159
3160 if (!getAtomicReductionRegion().empty()) {
3161 Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
3162 if (atomicReductionEntryBlock.getNumArguments() != 2 ||
3163 atomicReductionEntryBlock.getArgumentTypes()[0] !=
3164 atomicReductionEntryBlock.getArgumentTypes()[1])
3165 return emitOpError() << "expects atomic reduction region with two "
3166 "arguments of the same type";
3167 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3168 atomicReductionEntryBlock.getArgumentTypes()[0]);
3169 if (!ptrType ||
3170 (ptrType.getElementType() && ptrType.getElementType() != getType()))
3171 return emitOpError() << "expects atomic reduction region arguments to "
3172 "be accumulators containing the reduction type";
3173 }
3174
3175 if (getCleanupRegion().empty())
3176 return success();
3177 Block &cleanupEntryBlock = getCleanupRegion().front();
3178 if (cleanupEntryBlock.getNumArguments() != 1 ||
3179 cleanupEntryBlock.getArgument(0).getType() != getType())
3180 return emitOpError() << "expects cleanup region with one argument "
3181 "of the reduction type";
3182
3183 return success();
3184}
3185
3186//===----------------------------------------------------------------------===//
3187// TaskOp
3188//===----------------------------------------------------------------------===//
3189
3190void TaskOp::build(OpBuilder &builder, OperationState &state,
3191 const TaskOperands &clauses) {
3192 MLIRContext *ctx = builder.getContext();
3193 TaskOp::build(builder, state, clauses.iterated, clauses.affinityVars,
3194 clauses.allocateVars, clauses.allocatorVars,
3195 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3196 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3197 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3198 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3199 clauses.priority, /*private_vars=*/clauses.privateVars,
3200 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3201 clauses.privateNeedsBarrier, clauses.untied,
3202 clauses.eventHandle);
3203}
3204
3205LogicalResult TaskOp::verify() {
3206 LogicalResult verifyDependVars =
3207 verifyDependVarList(*this, getDependKinds(), getDependVars());
3208 return failed(verifyDependVars)
3209 ? verifyDependVars
3210 : verifyReductionVarList(*this, getInReductionSyms(),
3211 getInReductionVars(),
3212 getInReductionByref());
3213}
3214
3215//===----------------------------------------------------------------------===//
3216// TaskgroupOp
3217//===----------------------------------------------------------------------===//
3218
3219void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
3220 const TaskgroupOperands &clauses) {
3221 MLIRContext *ctx = builder.getContext();
3222 TaskgroupOp::build(builder, state, clauses.allocateVars,
3223 clauses.allocatorVars, clauses.taskReductionVars,
3224 makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
3225 makeArrayAttr(ctx, clauses.taskReductionSyms));
3226}
3227
3228LogicalResult TaskgroupOp::verify() {
3229 return verifyReductionVarList(*this, getTaskReductionSyms(),
3230 getTaskReductionVars(),
3231 getTaskReductionByref());
3232}
3233
3234//===----------------------------------------------------------------------===//
3235// TaskloopOp
3236//===----------------------------------------------------------------------===//
3237
3238void TaskloopOp::build(OpBuilder &builder, OperationState &state,
3239 const TaskloopOperands &clauses) {
3240 MLIRContext *ctx = builder.getContext();
3241 TaskloopOp::build(
3242 builder, state, clauses.allocateVars, clauses.allocatorVars,
3243 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3244 clauses.inReductionVars,
3245 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3246 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3247 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3248 /*private_vars=*/clauses.privateVars,
3249 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3250 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3251 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3252 makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
3253}
3254
3255LogicalResult TaskloopOp::verify() {
3256 if (getAllocateVars().size() != getAllocatorVars().size())
3257 return emitError(
3258 "expected equal sizes for allocate and allocator variables");
3259 if (failed(verifyReductionVarList(*this, getReductionSyms(),
3260 getReductionVars(), getReductionByref())) ||
3261 failed(verifyReductionVarList(*this, getInReductionSyms(),
3262 getInReductionVars(),
3263 getInReductionByref())))
3264 return failure();
3265
3266 if (!getReductionVars().empty() && getNogroup())
3267 return emitError("if a reduction clause is present on the taskloop "
3268 "directive, the nogroup clause must not be specified");
3269 for (auto var : getReductionVars()) {
3270 if (llvm::is_contained(getInReductionVars(), var))
3271 return emitError("the same list item cannot appear in both a reduction "
3272 "and an in_reduction clause");
3273 }
3274
3275 if (getGrainsize() && getNumTasks()) {
3276 return emitError(
3277 "the grainsize clause and num_tasks clause are mutually exclusive and "
3278 "may not appear on the same taskloop directive");
3279 }
3280
3281 return success();
3282}
3283
3284LogicalResult TaskloopOp::verifyRegions() {
3285 if (LoopWrapperInterface nested = getNestedWrapper()) {
3286 if (!isComposite())
3287 return emitError()
3288 << "'omp.composite' attribute missing from composite wrapper";
3289
3290 // Check for the allowed leaf constructs that may appear in a composite
3291 // construct directly after TASKLOOP.
3292 if (!isa<SimdOp>(nested))
3293 return emitError() << "only supported nested wrapper is 'omp.simd'";
3294 } else if (isComposite()) {
3295 return emitError()
3296 << "'omp.composite' attribute present in non-composite wrapper";
3297 }
3298
3299 return success();
3300}
3301
3302//===----------------------------------------------------------------------===//
3303// LoopNestOp
3304//===----------------------------------------------------------------------===//
3305
3306ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
3307 // Parse an opening `(` followed by induction variables followed by `)`
3310 Type loopVarType;
3312 parser.parseColonType(loopVarType) ||
3313 // Parse loop bounds.
3314 parser.parseEqual() ||
3315 parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
3316 parser.parseKeyword("to") ||
3317 parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
3318 return failure();
3319
3320 for (auto &iv : ivs)
3321 iv.type = loopVarType;
3322
3323 auto *ctx = parser.getBuilder().getContext();
3324 // Parse "inclusive" flag.
3325 if (succeeded(parser.parseOptionalKeyword("inclusive")))
3326 result.addAttribute("loop_inclusive", UnitAttr::get(ctx));
3327
3328 // Parse step values.
3330 if (parser.parseKeyword("step") ||
3331 parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
3332 return failure();
3333
3334 // Parse collapse
3335 int64_t value = 0;
3336 if (!parser.parseOptionalKeyword("collapse") &&
3337 (parser.parseLParen() || parser.parseInteger(value) ||
3338 parser.parseRParen()))
3339 return failure();
3340 if (value > 1)
3341 result.addAttribute(
3342 "collapse_num_loops",
3343 IntegerAttr::get(parser.getBuilder().getI64Type(), value));
3344
3345 // Parse tiles
3347 auto parseTiles = [&]() -> ParseResult {
3348 int64_t tile;
3349 if (parser.parseInteger(tile))
3350 return failure();
3351 tiles.push_back(tile);
3352 return success();
3353 };
3354
3355 if (!parser.parseOptionalKeyword("tiles") &&
3356 (parser.parseLParen() || parser.parseCommaSeparatedList(parseTiles) ||
3357 parser.parseRParen()))
3358 return failure();
3359
3360 if (tiles.size() > 0)
3361 result.addAttribute("tile_sizes", DenseI64ArrayAttr::get(ctx, tiles));
3362
3363 // Parse the body.
3364 Region *region = result.addRegion();
3365 if (parser.parseRegion(*region, ivs))
3366 return failure();
3367
3368 // Resolve operands.
3369 if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
3370 parser.resolveOperands(ubs, loopVarType, result.operands) ||
3371 parser.resolveOperands(steps, loopVarType, result.operands))
3372 return failure();
3373
3374 // Parse the optional attribute list.
3375 return parser.parseOptionalAttrDict(result.attributes);
3376}
3377
3378void LoopNestOp::print(OpAsmPrinter &p) {
3379 Region &region = getRegion();
3380 auto args = region.getArguments();
3381 p << " (" << args << ") : " << args[0].getType() << " = ("
3382 << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
3383 if (getLoopInclusive())
3384 p << "inclusive ";
3385 p << "step (" << getLoopSteps() << ") ";
3386 if (int64_t numCollapse = getCollapseNumLoops())
3387 if (numCollapse > 1)
3388 p << "collapse(" << numCollapse << ") ";
3389
3390 if (const auto tiles = getTileSizes())
3391 p << "tiles(" << tiles.value() << ") ";
3392
3393 p.printRegion(region, /*printEntryBlockArgs=*/false);
3394}
3395
3396void LoopNestOp::build(OpBuilder &builder, OperationState &state,
3397 const LoopNestOperands &clauses) {
3398 MLIRContext *ctx = builder.getContext();
3399 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3400 clauses.loopLowerBounds, clauses.loopUpperBounds,
3401 clauses.loopSteps, clauses.loopInclusive,
3402 makeDenseI64ArrayAttr(ctx, clauses.tileSizes));
3403}
3404
3405LogicalResult LoopNestOp::verify() {
3406 if (getLoopLowerBounds().empty())
3407 return emitOpError() << "must represent at least one loop";
3408
3409 if (getLoopLowerBounds().size() != getIVs().size())
3410 return emitOpError() << "number of range arguments and IVs do not match";
3411
3412 for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3413 if (lb.getType() != iv.getType())
3414 return emitOpError()
3415 << "range argument type does not match corresponding IV type";
3416 }
3417
3418 uint64_t numIVs = getIVs().size();
3419
3420 if (const auto &numCollapse = getCollapseNumLoops())
3421 if (numCollapse > numIVs)
3422 return emitOpError()
3423 << "collapse value is larger than the number of loops";
3424
3425 if (const auto &tiles = getTileSizes())
3426 if (tiles.value().size() > numIVs)
3427 return emitOpError() << "too few canonical loops for tile dimensions";
3428
3429 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3430 return emitOpError() << "expects parent op to be a loop wrapper";
3431
3432 return success();
3433}
3434
3435void LoopNestOp::gatherWrappers(
3437 Operation *parent = (*this)->getParentOp();
3438 while (auto wrapper =
3439 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3440 wrappers.push_back(wrapper);
3441 parent = parent->getParentOp();
3442 }
3443}
3444
3445//===----------------------------------------------------------------------===//
3446// OpenMP canonical loop handling
3447//===----------------------------------------------------------------------===//
3448
3449std::tuple<NewCliOp, OpOperand *, OpOperand *>
3450mlir::omp ::decodeCli(Value cli) {
3451
3452 // Defining a CLI for a generated loop is optional; if there is none then
3453 // there is no followup-tranformation
3454 if (!cli)
3455 return {{}, nullptr, nullptr};
3456
3457 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3458 "Unexpected type of cli");
3459
3460 NewCliOp create = cast<NewCliOp>(cli.getDefiningOp());
3461 OpOperand *gen = nullptr;
3462 OpOperand *cons = nullptr;
3463 for (OpOperand &use : cli.getUses()) {
3464 auto op = cast<LoopTransformationInterface>(use.getOwner());
3465
3466 unsigned opnum = use.getOperandNumber();
3467 if (op.isGeneratee(opnum)) {
3468 assert(!gen && "Each CLI may have at most one def");
3469 gen = &use;
3470 } else if (op.isApplyee(opnum)) {
3471 assert(!cons && "Each CLI may have at most one consumer");
3472 cons = &use;
3473 } else {
3474 llvm_unreachable("Unexpected operand for a CLI");
3475 }
3476 }
3477
3478 return {create, gen, cons};
3479}
3480
3481void NewCliOp::build(::mlir::OpBuilder &odsBuilder,
3482 ::mlir::OperationState &odsState) {
3483 odsState.addTypes(CanonicalLoopInfoType::get(odsBuilder.getContext()));
3484}
3485
/// Assigns a readable SSA name to the CLI value produced by omp.new_cli.
///
/// The name is derived from the operation that generates the loop the CLI
/// refers to (the generator found by decodeCli); "cli" is used when the CLI
/// has no generator.
void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
  Value result = getResult();
  auto [newCli, gen, cons] = decodeCli(result);

  // Structured binding `gen` cannot be captured in lambdas before C++20
  OpOperand *generator = gen;

  // Derive the CLI variable name from its generator:
  // * "canonloop" for omp.canonical_loop
  // * custom name for loop transformation generatees
  // * "cli" as fallback if no generator
  // * "_r<idx>" suffix for nested loops, where <idx> is the sequential order
  // at that level
  // * "_s<idx>" suffix for operations with multiple regions, where <idx> is
  // the index of that region
  std::string cliName{"cli"};
  if (gen) {
    cliName =
        .Case([&](CanonicalLoopOp op) {
          return generateLoopNestingName("canonloop", op);
        })
        .Case([&](UnrollHeuristicOp op) -> std::string {
          llvm_unreachable("heuristic unrolling does not generate a loop");
        })
        .Case([&](FuseOp op) -> std::string {
          unsigned opnum = generator->getOperandNumber();
          // The position of the first loop to be fused is the same position
          // as the resulting fused loop
          // NOTE(review): `opnum` is an absolute, 0-based operand number while
          // `first` is a 1-based looprange index (compare the TileOp case
          // below, which rebases by the generatee start and adds 1). Confirm
          // this comparison is not off by one.
          if (op.getFirst().has_value() && opnum != op.getFirst().value())
            return "canonloop_fuse";
          else
            return "fused";
        })
        .Case([&](TileOp op) -> std::string {
          // Generatees are split in two equal halves: first the grid loops,
          // then the intra-tile loops.
          auto [generateesFirst, generateesCount] =
              op.getGenerateesODSOperandIndexAndLength();
          unsigned firstGrid = generateesFirst;
          unsigned firstIntratile = generateesFirst + generateesCount / 2;
          unsigned end = generateesFirst + generateesCount;
          unsigned opnum = generator->getOperandNumber();
          // In the OpenMP apply and looprange clauses, indices are 1-based
          if (firstGrid <= opnum && opnum < firstIntratile) {
            unsigned gridnum = opnum - firstGrid + 1;
            return ("grid" + Twine(gridnum)).str();
          }
          if (firstIntratile <= opnum && opnum < end) {
            unsigned intratilenum = opnum - firstIntratile + 1;
            return ("intratile" + Twine(intratilenum)).str();
          }
          llvm_unreachable("Unexpected generatee argument");
        })
        .DefaultUnreachable("TODO: Custom name for this operation");
  }

  setNameFn(result, cliName);
}
3543
3544LogicalResult NewCliOp::verify() {
3545 Value cli = getResult();
3546
3547 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3548 "Unexpected type of cli");
3549
3550 // Check that the CLI is used in at most generator and one consumer
3551 OpOperand *gen = nullptr;
3552 OpOperand *cons = nullptr;
3553 for (mlir::OpOperand &use : cli.getUses()) {
3554 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3555
3556 unsigned opnum = use.getOperandNumber();
3557 if (op.isGeneratee(opnum)) {
3558 if (gen) {
3559 InFlightDiagnostic error =
3560 emitOpError("CLI must have at most one generator");
3561 error.attachNote(gen->getOwner()->getLoc())
3562 .append("first generator here:");
3563 error.attachNote(use.getOwner()->getLoc())
3564 .append("second generator here:");
3565 return error;
3566 }
3567
3568 gen = &use;
3569 } else if (op.isApplyee(opnum)) {
3570 if (cons) {
3571 InFlightDiagnostic error =
3572 emitOpError("CLI must have at most one consumer");
3573 error.attachNote(cons->getOwner()->getLoc())
3574 .append("first consumer here:")
3575 .appendOp(*cons->getOwner(),
3576 OpPrintingFlags().printGenericOpForm());
3577 error.attachNote(use.getOwner()->getLoc())
3578 .append("second consumer here:")
3579 .appendOp(*use.getOwner(), OpPrintingFlags().printGenericOpForm());
3580 return error;
3581 }
3582
3583 cons = &use;
3584 } else {
3585 llvm_unreachable("Unexpected operand for a CLI");
3586 }
3587 }
3588
3589 // If the CLI is source of a transformation, it must have a generator
3590 if (cons && !gen) {
3591 InFlightDiagnostic error = emitOpError("CLI has no generator");
3592 error.attachNote(cons->getOwner()->getLoc())
3593 .append("see consumer here: ")
3594 .appendOp(*cons->getOwner(), OpPrintingFlags().printGenericOpForm());
3595 return error;
3596 }
3597
3598 return success();
3599}
3600
3601void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3602 Value tripCount) {
3603 odsState.addOperands(tripCount);
3604 odsState.addOperands(Value());
3605 (void)odsState.addRegion();
3606}
3607
3608void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3609 Value tripCount, ::mlir::Value cli) {
3610 odsState.addOperands(tripCount);
3611 odsState.addOperands(cli);
3612 (void)odsState.addRegion();
3613}
3614
3615void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
3616 setNameFn(&getRegion().front(), "body_entry");
3617}
3618
3619void CanonicalLoopOp::getAsmBlockArgumentNames(Region &region,
3620 OpAsmSetValueNameFn setNameFn) {
3621 std::string ivName = generateLoopNestingName("iv", *this);
3622 setNameFn(region.getArgument(0), ivName);
3623}
3624
3625void CanonicalLoopOp::print(OpAsmPrinter &p) {
3626 if (getCli())
3627 p << '(' << getCli() << ')';
3628 p << ' ' << getInductionVar() << " : " << getInductionVar().getType()
3629 << " in range(" << getTripCount() << ") ";
3630
3631 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3632 /*printBlockTerminators=*/true);
3633
3634 p.printOptionalAttrDict((*this)->getAttrs());
3635}
3636
/// Parses the custom assembly form produced by CanonicalLoopOp::print:
///   (`(` $cli `)`)? $iv `:` type($iv) `in` `range` `(` $tripCount `)`
///   region attr-dict
/// The trip count is resolved using the type written for the induction
/// variable.
mlir::ParseResult CanonicalLoopOp::parse(::mlir::OpAsmParser &parser,
  CanonicalLoopInfoType cliType =
      CanonicalLoopInfoType::get(parser.getContext());

  // Parse (optional) omp.cli identifier
  SmallVector<mlir::Value, 1> cliOperand;
  if (!parser.parseOptionalLParen()) {
    if (parser.parseOperand(cli) ||
        parser.resolveOperand(cli, cliType, cliOperand) || parser.parseRParen())
      return failure();
  }

  // We derive the type of tripCount from inductionVariable. MLIR requires the
  // type of tripCount to be known when calling resolveOperand, so we have to
  // parse the type before processing the inductionVariable.
  OpAsmParser::Argument inductionVariable;
  if (parser.parseArgument(inductionVariable, /*allowType*/ true) ||
      parser.parseKeyword("in") || parser.parseKeyword("range") ||
      parser.parseLParen() || parser.parseOperand(tripcount) ||
      parser.parseRParen() ||
      parser.resolveOperand(tripcount, inductionVariable.type, result.operands))
    return failure();

  // Parse the loop body.
  Region *region = result.addRegion();
  if (parser.parseRegion(*region, {inductionVariable}))
    return failure();

  // We parsed the cli operand first, but because it is optional, it must be
  // last in the operand list.
  result.operands.append(cliOperand);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return mlir::success();
}
3678
3679LogicalResult CanonicalLoopOp::verify() {
3680 // The region's entry must accept the induction variable
3681 // It can also be empty if just created
3682 if (!getRegion().empty()) {
3683 Region &region = getRegion();
3684 if (region.getNumArguments() != 1)
3685 return emitOpError(
3686 "Canonical loop region must have exactly one argument");
3687
3688 if (getInductionVar().getType() != getTripCount().getType())
3689 return emitOpError(
3690 "Region argument must be the same type as the trip count");
3691 }
3692
3693 return success();
3694}
3695
3696Value CanonicalLoopOp::getInductionVar() { return getRegion().getArgument(0); }
3697
3698std::pair<unsigned, unsigned>
3699CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3700 // No applyees
3701 return {0, 0};
3702}
3703
3704std::pair<unsigned, unsigned>
3705CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3706 return getODSOperandIndexAndLength(odsIndex_cli);
3707}
3708
3709//===----------------------------------------------------------------------===//
3710// UnrollHeuristicOp
3711//===----------------------------------------------------------------------===//
3712
3713void UnrollHeuristicOp::build(::mlir::OpBuilder &odsBuilder,
3714 ::mlir::OperationState &odsState,
3715 ::mlir::Value cli) {
3716 odsState.addOperands(cli);
3717}
3718
3719void UnrollHeuristicOp::print(OpAsmPrinter &p) {
3720 p << '(' << getApplyee() << ')';
3721
3722 p.printOptionalAttrDict((*this)->getAttrs());
3723}
3724
/// Parses the custom assembly form:
///   `(` $applyee `)` (`->` `(` `)`)? attr-dict
/// The optional `-> ()` is accepted but adds no operands (see the comment
/// below).
mlir::ParseResult UnrollHeuristicOp::parse(::mlir::OpAsmParser &parser,
  auto cliType = CanonicalLoopInfoType::get(parser.getContext());

  if (parser.parseLParen())
    return failure();

  if (parser.parseOperand(applyee) ||
      parser.resolveOperand(applyee, cliType, result.operands))
    return failure();

  if (parser.parseRParen())
    return failure();

  // Optional output loop (full unrolling has none)
  if (!parser.parseOptionalArrow()) {
    if (parser.parseLParen() || parser.parseRParen())
      return failure();
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return mlir::success();
}
3752
3753std::pair<unsigned, unsigned>
3754UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3755 return getODSOperandIndexAndLength(odsIndex_applyee);
3756}
3757
3758std::pair<unsigned, unsigned>
3759UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3760 return {0, 0};
3761}
3762
3763//===----------------------------------------------------------------------===//
3764// TileOp
3765//===----------------------------------------------------------------------===//
3766
3767static void printLoopTransformClis(OpAsmPrinter &p, TileOp op,
3768 OperandRange generatees,
3769 OperandRange applyees) {
3770 if (!generatees.empty())
3771 p << '(' << llvm::interleaved(generatees) << ')';
3772
3773 if (!applyees.empty())
3774 p << " <- (" << llvm::interleaved(applyees) << ')';
3775}
3776
/// Parses the CLI operands of a loop transformation in one of two forms:
///   `(` $generatees `)` `<-` `(` $applyees `)`   (generatees present)
///   `<-` `(` $applyees `)`                       (generatees omitted)
static ParseResult parseLoopTransformClis(
    OpAsmParser &parser,
  // parseOptionalLess() reports failure when the next token is not `<`. In
  // that case the parenthesized generatee list must come first.
  if (parser.parseOptionalLess()) {
    // Syntax 1: generatees present

    if (parser.parseOperandList(generateesOperands,
      return failure();

    if (parser.parseLess())
      return failure();
  } else {
    // Syntax 2: generatees omitted
  }

  // Parse `<-` (`<` has already been parsed)
  if (parser.parseMinus())
    return failure();

  if (parser.parseOperandList(applyeesOperands,
    return failure();

  return success();
}
3804
/// Check properties of the loop nest consisting of the transformation's
/// applyees:
/// 1. They are nested inside each other
/// 2. They are perfectly nested
///    (no code with side-effects in-between the loops)
/// 3. They are rectangular
///    (loop bounds are invariant in respect to the outer loops)
///
/// Every applyee CLI must have a generator. The structural checks are only
/// performed when every generator is an omp.canonical_loop; otherwise the
/// nest is accepted as-is.
///
/// TODO: Generalize for LoopTransformationInterface.
static LogicalResult checkApplyeesNesting(TileOp op) {
  // Collect the loops from the nest
  bool isOnlyCanonLoops = true;
  for (Value applyee : op.getApplyees()) {
    auto [create, gen, cons] = decodeCli(applyee);

    if (!gen)
      return op.emitOpError() << "applyee CLI has no generator";

    // A null entry is pushed for non-canonical-loop generators; the list is
    // only inspected below when there are no such entries.
    auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
    canonLoops.push_back(loop);
    if (!loop)
      isOnlyCanonLoops = false;
  }

  // FIXME: We currently can only verify non-rectangularity and perfect nest of
  // omp.canonical_loop.
  if (!isOnlyCanonLoops)
    return success();

  // Induction variables of all loops enclosing the one currently checked.
  DenseSet<Value> parentIVs;
  for (auto i : llvm::seq<int>(1, canonLoops.size())) {
    auto parentLoop = canonLoops[i - 1];
    auto loop = canonLoops[i];

    // Applyee i must be directly nested in applyee i-1.
    if (parentLoop.getOperation() != loop.getOperation()->getParentOp())
      return op.emitOpError()
             << "tiled loop nest must be nested within each other";

    parentIVs.insert(parentLoop.getInductionVar());

    // Canonical loop must be perfectly nested, i.e. the body of the parent must
    // only contain the omp.canonical_loop of the nested loops, and
    // omp.terminator
    bool isPerfectlyNested = [&]() {
      auto &parentBody = parentLoop.getRegion();
      if (!parentBody.hasOneBlock())
        return false;
      auto &parentBlock = parentBody.getBlocks().front();

      // First op of the body must be the nested loop ...
      auto nestedLoopIt = parentBlock.begin();
      if (nestedLoopIt == parentBlock.end() ||
          (&*nestedLoopIt != loop.getOperation()))
        return false;

      // ... immediately followed by omp.terminator ...
      auto termIt = std::next(nestedLoopIt);
      if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
        return false;

      // ... and nothing else.
      if (std::next(termIt) != parentBlock.end())
        return false;

      return true;
    }();
    if (!isPerfectlyNested)
      return op.emitOpError() << "tiled loop nest must be perfectly nested";

    // NOTE(review): this only catches a trip count that *is* an enclosing
    // induction variable; a trip count computed from one cannot appear here
    // anyway because the nest must be perfect, but values defined further
    // outside are not inspected.
    if (parentIVs.contains(loop.getTripCount()))
      return op.emitOpError() << "tiled loop nest must be rectangular";
  }

  // TODO: The tile sizes must be computed before the loop, but checking this
  // requires dominance analysis. For instance:
  //
  //      %canonloop = omp.new_cli
  //      omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
  //        // write to %x
  //        omp.terminator
  //      }
  //      %ts = llvm.load %x
  //      omp.tile <- (%canonloop) sizes(%ts : i32)

  return success();
}
3889
3890LogicalResult TileOp::verify() {
3891 if (getApplyees().empty())
3892 return emitOpError() << "must apply to at least one loop";
3893
3894 if (getSizes().size() != getApplyees().size())
3895 return emitOpError() << "there must be one tile size for each applyee";
3896
3897 if (!getGeneratees().empty() &&
3898 2 * getSizes().size() != getGeneratees().size())
3899 return emitOpError()
3900 << "expecting two times the number of generatees than applyees";
3901
3902 return checkApplyeesNesting(*this);
3903}
3904
3905std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
3906 return getODSOperandIndexAndLength(odsIndex_applyees);
3907}
3908
3909std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
3910 return getODSOperandIndexAndLength(odsIndex_generatees);
3911}
3912
3913//===----------------------------------------------------------------------===//
3914// FuseOp
3915//===----------------------------------------------------------------------===//
3916
3917static void printLoopTransformClis(OpAsmPrinter &p, FuseOp op,
3918 OperandRange generatees,
3919 OperandRange applyees) {
3920 if (!generatees.empty())
3921 p << '(' << llvm::interleaved(generatees) << ')';
3922
3923 if (!applyees.empty())
3924 p << " <- (" << llvm::interleaved(applyees) << ')';
3925}
3926
3927LogicalResult FuseOp::verify() {
3928 if (getApplyees().size() < 2)
3929 return emitOpError() << "must apply to at least two loops";
3930
3931 if (getFirst().has_value() && getCount().has_value()) {
3932 int64_t first = getFirst().value();
3933 int64_t count = getCount().value();
3934 if ((unsigned)(first + count - 1) > getApplyees().size())
3935 return emitOpError() << "the numbers of applyees must be at least first "
3936 "minus one plus count attributes";
3937 if (!getGeneratees().empty() &&
3938 getGeneratees().size() != getApplyees().size() + 1 - count)
3939 return emitOpError() << "the number of generatees must be the number of "
3940 "aplyees plus one minus count";
3941
3942 } else {
3943 if (!getGeneratees().empty() && getGeneratees().size() != 1)
3944 return emitOpError()
3945 << "in a complete fuse the number of generatees must be exactly 1";
3946 }
3947 for (auto &&applyee : getApplyees()) {
3948 auto [create, gen, cons] = decodeCli(applyee);
3949
3950 if (!gen)
3951 return emitOpError() << "applyee CLI has no generator";
3952 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
3953 if (!loop)
3954 return emitOpError()
3955 << "currently only supports omp.canonical_loop as applyee";
3956 }
3957 return success();
3958}
3959std::pair<unsigned, unsigned> FuseOp::getApplyeesODSOperandIndexAndLength() {
3960 return getODSOperandIndexAndLength(odsIndex_applyees);
3961}
3962
3963std::pair<unsigned, unsigned> FuseOp::getGenerateesODSOperandIndexAndLength() {
3964 return getODSOperandIndexAndLength(odsIndex_generatees);
3965}
3966
3967//===----------------------------------------------------------------------===//
3968// Critical construct (2.17.1)
3969//===----------------------------------------------------------------------===//
3970
3971void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
3972 const CriticalDeclareOperands &clauses) {
3973 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
3974}
3975
3976LogicalResult CriticalDeclareOp::verify() {
3977 return verifySynchronizationHint(*this, getHint());
3978}
3979
3980LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3981 if (getNameAttr()) {
3982 SymbolRefAttr symbolRef = getNameAttr();
3983 auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
3984 *this, symbolRef);
3985 if (!decl) {
3986 return emitOpError() << "expected symbol reference " << symbolRef
3987 << " to point to a critical declaration";
3988 }
3989 }
3990
3991 return success();
3992}
3993
3994//===----------------------------------------------------------------------===//
3995// Ordered construct
3996//===----------------------------------------------------------------------===//
3997
3998static LogicalResult verifyOrderedParent(Operation &op) {
3999 bool hasRegion = op.getNumRegions() > 0;
4000 auto loopOp = op.getParentOfType<LoopNestOp>();
4001 if (!loopOp) {
4002 if (hasRegion)
4003 return success();
4004
4005 // TODO: Consider if this needs to be the case only for the standalone
4006 // variant of the ordered construct.
4007 return op.emitOpError() << "must be nested inside of a loop";
4008 }
4009
4010 Operation *wrapper = loopOp->getParentOp();
4011 if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
4012 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
4013 if (!orderedAttr)
4014 return op.emitOpError() << "the enclosing worksharing-loop region must "
4015 "have an ordered clause";
4016
4017 if (hasRegion && orderedAttr.getInt() != 0)
4018 return op.emitOpError() << "the enclosing loop's ordered clause must not "
4019 "have a parameter present";
4020
4021 if (!hasRegion && orderedAttr.getInt() == 0)
4022 return op.emitOpError() << "the enclosing loop's ordered clause must "
4023 "have a parameter present";
4024 } else if (!isa<SimdOp>(wrapper)) {
4025 return op.emitOpError() << "must be nested inside of a worksharing, simd "
4026 "or worksharing simd loop";
4027 }
4028 return success();
4029}
4030
4031void OrderedOp::build(OpBuilder &builder, OperationState &state,
4032 const OrderedOperands &clauses) {
4033 OrderedOp::build(builder, state, clauses.doacrossDependType,
4034 clauses.doacrossNumLoops, clauses.doacrossDependVars);
4035}
4036
4037LogicalResult OrderedOp::verify() {
4038 if (failed(verifyOrderedParent(**this)))
4039 return failure();
4040
4041 auto wrapper = (*this)->getParentOfType<WsloopOp>();
4042 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
4043 return emitOpError() << "number of variables in depend clause does not "
4044 << "match number of iteration variables in the "
4045 << "doacross loop";
4046
4047 return success();
4048}
4049
4050void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
4051 const OrderedRegionOperands &clauses) {
4052 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
4053}
4054
4055LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
4056
4057//===----------------------------------------------------------------------===//
4058// TaskwaitOp
4059//===----------------------------------------------------------------------===//
4060
4061void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
4062 const TaskwaitOperands &clauses) {
4063 // TODO Store clauses in op: dependKinds, dependVars, nowait.
4064 TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
4065 /*depend_vars=*/{}, /*nowait=*/nullptr);
4066}
4067
4068//===----------------------------------------------------------------------===//
4069// Verifier for AtomicReadOp
4070//===----------------------------------------------------------------------===//
4071
4072LogicalResult AtomicReadOp::verify() {
4073 if (verifyCommon().failed())
4074 return mlir::failure();
4075
4076 if (auto mo = getMemoryOrder()) {
4077 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4078 *mo == ClauseMemoryOrderKind::Release) {
4079 return emitError(
4080 "memory-order must not be acq_rel or release for atomic reads");
4081 }
4082 }
4083 return verifySynchronizationHint(*this, getHint());
4084}
4085
4086//===----------------------------------------------------------------------===//
4087// Verifier for AtomicWriteOp
4088//===----------------------------------------------------------------------===//
4089
4090LogicalResult AtomicWriteOp::verify() {
4091 if (verifyCommon().failed())
4092 return mlir::failure();
4093
4094 if (auto mo = getMemoryOrder()) {
4095 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4096 *mo == ClauseMemoryOrderKind::Acquire) {
4097 return emitError(
4098 "memory-order must not be acq_rel or acquire for atomic writes");
4099 }
4100 }
4101 return verifySynchronizationHint(*this, getHint());
4102}
4103
4104//===----------------------------------------------------------------------===//
4105// Verifier for AtomicUpdateOp
4106//===----------------------------------------------------------------------===//
4107
4108LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4109 PatternRewriter &rewriter) {
4110 if (op.isNoOp()) {
4111 rewriter.eraseOp(op);
4112 return success();
4113 }
4114 if (Value writeVal = op.getWriteOpVal()) {
4115 rewriter.replaceOpWithNewOp<AtomicWriteOp>(
4116 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
4117 return success();
4118 }
4119 return failure();
4120}
4121
4122LogicalResult AtomicUpdateOp::verify() {
4123 if (verifyCommon().failed())
4124 return mlir::failure();
4125
4126 if (auto mo = getMemoryOrder()) {
4127 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4128 *mo == ClauseMemoryOrderKind::Acquire) {
4129 return emitError(
4130 "memory-order must not be acq_rel or acquire for atomic updates");
4131 }
4132 }
4133
4134 return verifySynchronizationHint(*this, getHint());
4135}
4136
4137LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4138
4139//===----------------------------------------------------------------------===//
4140// Verifier for AtomicCaptureOp
4141//===----------------------------------------------------------------------===//
4142
4143AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4144 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4145 return op;
4146 return dyn_cast<AtomicReadOp>(getSecondOp());
4147}
4148
4149AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4150 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4151 return op;
4152 return dyn_cast<AtomicWriteOp>(getSecondOp());
4153}
4154
4155AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4156 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4157 return op;
4158 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4159}
4160
4161LogicalResult AtomicCaptureOp::verify() {
4162 return verifySynchronizationHint(*this, getHint());
4163}
4164
4165LogicalResult AtomicCaptureOp::verifyRegions() {
4166 if (verifyRegionsCommon().failed())
4167 return mlir::failure();
4168
4169 if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
4170 return emitOpError(
4171 "operations inside capture region must not have hint clause");
4172
4173 if (getFirstOp()->getAttr("memory_order") ||
4174 getSecondOp()->getAttr("memory_order"))
4175 return emitOpError(
4176 "operations inside capture region must not have memory_order clause");
4177 return success();
4178}
4179
4180//===----------------------------------------------------------------------===//
4181// CancelOp
4182//===----------------------------------------------------------------------===//
4183
4184void CancelOp::build(OpBuilder &builder, OperationState &state,
4185 const CancelOperands &clauses) {
4186 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4187}
4188
  // Walk up the chain of ancestors of `thisOp` and return the closest one that
  // belongs to the same dialect as `thisOp`, or nullptr if there is none.
  Operation *parent = thisOp->getParentOp();
  while (parent) {
    if (parent->getDialect() == thisOp->getDialect())
      return parent;
    parent = parent->getParentOp();
  }
  return nullptr;
}
4198
4199LogicalResult CancelOp::verify() {
4200 ClauseCancellationConstructType cct = getCancelDirective();
4201 // The next OpenMP operation in the chain of parents
4202 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4203 if (!structuralParent)
4204 return emitOpError() << "Orphaned cancel construct";
4205
4206 if ((cct == ClauseCancellationConstructType::Parallel) &&
4207 !mlir::isa<ParallelOp>(structuralParent)) {
4208 return emitOpError() << "cancel parallel must appear "
4209 << "inside a parallel region";
4210 }
4211 if (cct == ClauseCancellationConstructType::Loop) {
4212 // structural parent will be omp.loop_nest, directly nested inside
4213 // omp.wsloop
4214 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
4215
4216 if (!wsloopOp) {
4217 return emitOpError()
4218 << "cancel loop must appear inside a worksharing-loop region";
4219 }
4220 if (wsloopOp.getNowaitAttr()) {
4221 return emitError() << "A worksharing construct that is canceled "
4222 << "must not have a nowait clause";
4223 }
4224 if (wsloopOp.getOrderedAttr()) {
4225 return emitError() << "A worksharing construct that is canceled "
4226 << "must not have an ordered clause";
4227 }
4228
4229 } else if (cct == ClauseCancellationConstructType::Sections) {
4230 // structural parent will be an omp.section, directly nested inside
4231 // omp.sections
4232 auto sectionsOp =
4233 mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
4234 if (!sectionsOp) {
4235 return emitOpError() << "cancel sections must appear "
4236 << "inside a sections region";
4237 }
4238 if (sectionsOp.getNowait()) {
4239 return emitError() << "A sections construct that is canceled "
4240 << "must not have a nowait clause";
4241 }
4242 }
4243 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4244 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4245 !mlir::isa<omp::TaskloopOp>(structuralParent->getParentOp()))) {
4246 return emitOpError() << "cancel taskgroup must appear "
4247 << "inside a task region";
4248 }
4249 return success();
4250}
4251
4252//===----------------------------------------------------------------------===//
4253// CancellationPointOp
4254//===----------------------------------------------------------------------===//
4255
4256void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
4257 const CancellationPointOperands &clauses) {
4258 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4259}
4260
4261LogicalResult CancellationPointOp::verify() {
4262 ClauseCancellationConstructType cct = getCancelDirective();
4263 // The next OpenMP operation in the chain of parents
4264 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4265 if (!structuralParent)
4266 return emitOpError() << "Orphaned cancellation point";
4267
4268 if ((cct == ClauseCancellationConstructType::Parallel) &&
4269 !mlir::isa<ParallelOp>(structuralParent)) {
4270 return emitOpError() << "cancellation point parallel must appear "
4271 << "inside a parallel region";
4272 }
4273 // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
4274 // find the wsloop
4275 if ((cct == ClauseCancellationConstructType::Loop) &&
4276 !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
4277 return emitOpError() << "cancellation point loop must appear "
4278 << "inside a worksharing-loop region";
4279 }
4280 if ((cct == ClauseCancellationConstructType::Sections) &&
4281 !mlir::isa<omp::SectionOp>(structuralParent)) {
4282 return emitOpError() << "cancellation point sections must appear "
4283 << "inside a sections region";
4284 }
4285 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4286 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4287 !mlir::isa<omp::TaskloopOp>(structuralParent->getParentOp()))) {
4288 return emitOpError() << "cancellation point taskgroup must appear "
4289 << "inside a task region";
4290 }
4291 return success();
4292}
4293
4294//===----------------------------------------------------------------------===//
4295// MapBoundsOp
4296//===----------------------------------------------------------------------===//
4297
4298LogicalResult MapBoundsOp::verify() {
4299 auto extent = getExtent();
4300 auto upperbound = getUpperBound();
4301 if (!extent && !upperbound)
4302 return emitError("expected extent or upperbound.");
4303 return success();
4304}
4305
4306void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4307 TypeRange /*result_types*/, StringAttr symName,
4308 TypeAttr type) {
4309 PrivateClauseOp::build(
4310 odsBuilder, odsState, symName, type,
4311 DataSharingClauseTypeAttr::get(odsBuilder.getContext(),
4312 DataSharingClauseType::Private));
4313}
4314
/// Verifies the regions of an omp.private operation:
/// - every argument of every region has the op's argument type;
/// - `init` (optional): 2 arguments, and each exit block yields exactly one
///   value of the argument type;
/// - `copy`: must be empty for `private` and non-empty for `firstprivate`. For
///   `firstprivate` it takes 2 arguments and each exit block yields exactly
///   one value of the argument type;
/// - `dealloc` (optional): 1 argument, and exit blocks yield nothing.
LogicalResult PrivateClauseOp::verifyRegions() {
  Type argType = getArgType();
  // Checks one block terminator. Terminators with successors are skipped. Any
  // other terminator must be an omp.yield: it yields nothing when
  // `yieldsValue` is false, and exactly one `argType` value otherwise.
  auto verifyTerminator = [&](Operation *terminator,
                              bool yieldsValue) -> LogicalResult {
    if (!terminator->getBlock()->getSuccessors().empty())
      return success();

    if (!llvm::isa<YieldOp>(terminator))
      return mlir::emitError(terminator->getLoc())
             << "expected exit block terminator to be an `omp.yield` op.";

    YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
    TypeRange yieldedTypes = yieldOp.getResults().getTypes();

    if (!yieldsValue) {
      if (yieldedTypes.empty())
        return success();

      return mlir::emitError(terminator->getLoc())
             << "Did not expect any values to be yielded.";
    }

    if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
      return success();

    auto error = mlir::emitError(yieldOp.getLoc())
                 << "Invalid yielded value. Expected type: " << argType
                 << ", got: ";

    if (yieldedTypes.empty())
      error << "None";
    else
      error << yieldedTypes;

    return error;
  };

  // Checks a (non-empty) region's argument count, then the terminator of every
  // block that may have one.
  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
                          StringRef regionName,
                          bool yieldsValue) -> LogicalResult {
    assert(!region.empty());

    if (region.getNumArguments() != expectedNumArgs)
      return mlir::emitError(region.getLoc())
             << "`" << regionName << "`: "
             << "expected " << expectedNumArgs
             << " region arguments, got: " << region.getNumArguments();

    for (Block &block : region) {
      // MLIR will verify the absence of the terminator for us.
      if (!block.mightHaveTerminator())
        continue;

      if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
        return failure();
    }

    return success();
  };

  // Ensure all of the region arguments have the same type
  for (Region *region : getRegions())
    for (Type ty : region->getArgumentTypes())
      if (ty != argType)
        return emitError() << "Region argument type mismatch: got " << ty
                           << " expected " << argType << ".";

  mlir::Region &initRegion = getInitRegion();
  if (!initRegion.empty() &&
      failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
                          /*yieldsValue=*/true)))
    return failure();

  DataSharingClauseType dsType = getDataSharingType();

  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
    return emitError("`private` clauses do not require a `copy` region.");

  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
    return emitError(
        "`firstprivate` clauses require at least a `copy` region.");

  // The copy region is known to be non-empty here (checked just above).
  if (dsType == DataSharingClauseType::FirstPrivate &&
      failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
                          /*yieldsValue=*/true)))
    return failure();

  if (!getDeallocRegion().empty() &&
      failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
                          /*yieldsValue=*/false)))
    return failure();

  return success();
}
4409
4410//===----------------------------------------------------------------------===//
4411// Spec 5.2: Masked construct (10.5)
4412//===----------------------------------------------------------------------===//
4413
4414void MaskedOp::build(OpBuilder &builder, OperationState &state,
4415 const MaskedOperands &clauses) {
4416 MaskedOp::build(builder, state, clauses.filteredThreadId);
4417}
4418
4419//===----------------------------------------------------------------------===//
4420// Spec 5.2: Scan construct (5.6)
4421//===----------------------------------------------------------------------===//
4422
4423void ScanOp::build(OpBuilder &builder, OperationState &state,
4424 const ScanOperands &clauses) {
4425 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4426}
4427
4428LogicalResult ScanOp::verify() {
4429 if (hasExclusiveVars() == hasInclusiveVars())
4430 return emitError(
4431 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4432 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4433 if (parentWsLoopOp.getReductionModAttr() &&
4434 parentWsLoopOp.getReductionModAttr().getValue() ==
4435 ReductionModifier::inscan)
4436 return success();
4437 }
4438 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4439 if (parentSimdOp.getReductionModAttr() &&
4440 parentSimdOp.getReductionModAttr().getValue() ==
4441 ReductionModifier::inscan)
4442 return success();
4443 }
4444 return emitError("SCAN directive needs to be enclosed within a parent "
4445 "worksharing loop construct or SIMD construct with INSCAN "
4446 "reduction modifier");
4447}
4448
4449/// Verifies align clause in allocate directive
4450
4451LogicalResult AllocateDirOp::verify() {
4452 std::optional<uint64_t> align = this->getAlign();
4453
4454 if (align.has_value()) {
4455 if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
4456 return emitError() << "ALIGN value : " << align.value()
4457 << " must be power of 2";
4458 }
4459
4460 return success();
4461}
4462
4463//===----------------------------------------------------------------------===//
4464// TargetAllocMemOp
4465//===----------------------------------------------------------------------===//
4466
4467mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
4468 return getInTypeAttr().getValue();
4469}
4470
4471/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype,
4472/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
4473/// attr-dict-without-keyword
4474static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
4476 auto &builder = parser.getBuilder();
4477 bool hasOperands = false;
4478 std::int32_t typeparamsSize = 0;
4479
4480 // Parse device number as a new operand
4482 mlir::Type deviceType;
4483 if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
4484 return mlir::failure();
4485 if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
4486 return mlir::failure();
4487 if (parser.parseComma())
4488 return mlir::failure();
4489
4490 mlir::Type intype;
4491 if (parser.parseType(intype))
4492 return mlir::failure();
4493 result.addAttribute("in_type", mlir::TypeAttr::get(intype));
4496 if (!parser.parseOptionalLParen()) {
4497 // parse the LEN params of the derived type. (<params> : <types>)
4499 parser.parseColonTypeList(typeVec) || parser.parseRParen())
4500 return mlir::failure();
4501 typeparamsSize = operands.size();
4502 hasOperands = true;
4503 }
4504 std::int32_t shapeSize = 0;
4505 if (!parser.parseOptionalComma()) {
4506 // parse size to scale by, vector of n dimensions of type index
4508 return mlir::failure();
4509 shapeSize = operands.size() - typeparamsSize;
4510 auto idxTy = builder.getIndexType();
4511 for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
4512 typeVec.push_back(idxTy);
4513 hasOperands = true;
4514 }
4515 if (hasOperands &&
4516 parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
4517 result.operands))
4518 return mlir::failure();
4519
4520 mlir::Type restype = builder.getIntegerType(64);
4521 if (!restype) {
4522 parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
4523 return mlir::failure();
4524 }
4525 llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
4526 result.addAttribute("operandSegmentSizes",
4527 builder.getDenseI32ArrayAttr(segmentSizes));
4528 if (parser.parseOptionalAttrDict(result.attributes) ||
4529 parser.addTypeToList(restype, result.types))
4530 return mlir::failure();
4531 return mlir::success();
4532}
4533
4534mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
4536 return parseTargetAllocMemOp(parser, result);
4537}
4538
4539void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
4540 p << " ";
4542 p << " : ";
4543 p << getDevice().getType();
4544 p << ", ";
4545 p << getInType();
4546 if (!getTypeparams().empty()) {
4547 p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
4548 }
4549 for (auto sh : getShape()) {
4550 p << ", ";
4551 p.printOperand(sh);
4552 }
4553 p.printOptionalAttrDict((*this)->getAttrs(),
4554 {"in_type", "operandSegmentSizes"});
4555}
4556
4557llvm::LogicalResult omp::TargetAllocMemOp::verify() {
4558 mlir::Type outType = getType();
4559 if (!mlir::dyn_cast<IntegerType>(outType))
4560 return emitOpError("must be a integer type");
4561 return mlir::success();
4562}
4563
4564//===----------------------------------------------------------------------===//
4565// WorkdistributeOp
4566//===----------------------------------------------------------------------===//
4567
4568LogicalResult WorkdistributeOp::verify() {
4569 // Check that region exists and is not empty
4570 Region &region = getRegion();
4571 if (region.empty())
4572 return emitOpError("region cannot be empty");
4573 // Verify single entry point.
4574 Block &entryBlock = region.front();
4575 if (entryBlock.empty())
4576 return emitOpError("region must contain a structured block");
4577 // Verify single exit point.
4578 bool hasTerminator = false;
4579 for (Block &block : region) {
4580 if (isa<TerminatorOp>(block.back())) {
4581 if (hasTerminator) {
4582 return emitOpError("region must have exactly one terminator");
4583 }
4584 hasTerminator = true;
4585 }
4586 }
4587 if (!hasTerminator) {
4588 return emitOpError("region must be terminated with omp.terminator");
4589 }
4590 auto walkResult = region.walk([&](Operation *op) -> WalkResult {
4591 // No implicit barrier at end
4592 if (isa<BarrierOp>(op)) {
4593 return emitOpError(
4594 "explicit barriers are not allowed in workdistribute region");
4595 }
4596 // Check for invalid nested constructs
4597 if (isa<ParallelOp>(op)) {
4598 return emitOpError(
4599 "nested parallel constructs not allowed in workdistribute");
4600 }
4601 if (isa<TeamsOp>(op)) {
4602 return emitOpError(
4603 "nested teams constructs not allowed in workdistribute");
4604 }
4605 return WalkResult::advance();
4606 });
4607 if (walkResult.wasInterrupted())
4608 return failure();
4609
4610 Operation *parentOp = (*this)->getParentOp();
4611 if (!llvm::dyn_cast<TeamsOp>(parentOp))
4612 return emitOpError("workdistribute must be nested under teams");
4613 return success();
4614}
4615
4616//===----------------------------------------------------------------------===//
4617// Declare simd [7.7]
4618//===----------------------------------------------------------------------===//
4619
4620LogicalResult DeclareSimdOp::verify() {
4621 // Must be nested inside a function-like op
4622 auto func =
4623 dyn_cast_if_present<mlir::FunctionOpInterface>((*this)->getParentOp());
4624 if (!func)
4625 return emitOpError() << "must be nested inside a function";
4626
4627 if (getInbranch() && getNotinbranch())
4628 return emitOpError("cannot have both 'inbranch' and 'notinbranch'");
4629
4630 return verifyAlignedClause(*this, getAlignments(), getAlignedVars());
4631}
4632
4633void DeclareSimdOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4634 const DeclareSimdOperands &clauses) {
4635 MLIRContext *ctx = odsBuilder.getContext();
4636 DeclareSimdOp::build(odsBuilder, odsState, clauses.alignedVars,
4637 makeArrayAttr(ctx, clauses.alignments), clauses.inbranch,
4638 clauses.linearVars, clauses.linearStepVars,
4639 clauses.linearVarTypes, clauses.notinbranch,
4640 clauses.simdlen, clauses.uniformVars);
4641}
4642
4643//===----------------------------------------------------------------------===//
4644// Parser and printer for Uniform Clause
4645//===----------------------------------------------------------------------===//
4646
4647/// uniform ::= `uniform` `(` uniform-list `)`
4648/// uniform-list := uniform-val (`,` uniform-val)*
4649/// uniform-val := ssa-id `:` type
4650static ParseResult
4653 SmallVectorImpl<Type> &uniformTypes) {
4654 return parser.parseCommaSeparatedList([&]() -> mlir::ParseResult {
4655 if (parser.parseOperand(uniformVars.emplace_back()) ||
4656 parser.parseColonType(uniformTypes.emplace_back()))
4657 return mlir::failure();
4658 return mlir::success();
4659 });
4660}
4661
4662/// Print Uniform Clauses
4664 ValueRange uniformVars, TypeRange uniformTypes) {
4665 for (unsigned i = 0; i < uniformVars.size(); ++i) {
4666 if (i != 0)
4667 p << ", ";
4668 p << uniformVars[i] << " : " << uniformTypes[i];
4669 }
4670}
4671
4672//===----------------------------------------------------------------------===//
4673// Parser and printer for Affinity Clause
4674//===----------------------------------------------------------------------===//
4675
4676static ParseResult parseAffinityClause(
4677 OpAsmParser &parser,
4680 SmallVectorImpl<Type> &iteratedTypes,
4681 SmallVectorImpl<Type> &affinityVarTypes) {
4682 if (failed(parseSplitIteratedList(
4683 parser, iterated, iteratedTypes, affinityVars, affinityVarTypes,
4684 /*parsePrefix=*/[&]() -> ParseResult { return success(); })))
4685 return failure();
4686 return success();
4687}
4688
4690 ValueRange iterated, ValueRange affinityVars,
4691 TypeRange iteratedTypes,
4692 TypeRange affinityVarTypes) {
4693 auto nop = [&](Value, Type) {};
4694 printSplitIteratedList(p, iterated, iteratedTypes, affinityVars,
4695 affinityVarTypes,
4696 /*plain prefix*/ nop,
4697 /*iterated prefix*/ nop);
4698}
4699
4700//===----------------------------------------------------------------------===//
4701// Parser, printer, and verifier for Iterator modifier
4702//===----------------------------------------------------------------------===//
4703
4704static ParseResult
4709 SmallVectorImpl<Type> &lbTypes,
4710 SmallVectorImpl<Type> &ubTypes,
4711 SmallVectorImpl<Type> &stepTypes) {
4712
4713 llvm::SMLoc ivLoc = parser.getCurrentLocation();
4715
4716 // Parse induction variables: %i : i32, %j : i32
4717 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
4718 OpAsmParser::Argument &arg = ivArgs.emplace_back();
4719 if (parser.parseArgument(arg))
4720 return failure();
4721
4722 // Optional type, default to Index if not provided
4723 if (succeeded(parser.parseOptionalColon())) {
4724 if (parser.parseType(arg.type))
4725 return failure();
4726 } else {
4727 arg.type = parser.getBuilder().getIndexType();
4728 }
4729 return success();
4730 }))
4731 return failure();
4732
4733 // ) = (
4734 if (parser.parseRParen() || parser.parseEqual() || parser.parseLParen())
4735 return failure();
4736
4737 // Parse Ranges: (%lb to %ub step %st, ...)
4738 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
4739 OpAsmParser::UnresolvedOperand lb, ub, st;
4740 if (parser.parseOperand(lb) || parser.parseKeyword("to") ||
4741 parser.parseOperand(ub) || parser.parseKeyword("step") ||
4742 parser.parseOperand(st))
4743 return failure();
4744
4745 lbs.push_back(lb);
4746 ubs.push_back(ub);
4747 steps.push_back(st);
4748 return success();
4749 }))
4750 return failure();
4751
4752 if (parser.parseRParen())
4753 return failure();
4754
4755 if (ivArgs.size() != lbs.size())
4756 return parser.emitError(ivLoc)
4757 << "mismatch: " << ivArgs.size() << " variables but " << lbs.size()
4758 << " ranges";
4759
4760 for (auto &arg : ivArgs) {
4761 lbTypes.push_back(arg.type);
4762 ubTypes.push_back(arg.type);
4763 stepTypes.push_back(arg.type);
4764 }
4765
4766 return parser.parseRegion(region, ivArgs);
4767}
4768
4770 ValueRange lbs, ValueRange ubs,
4772 TypeRange) {
4773 Block &entry = region.front();
4774
4775 for (unsigned i = 0, e = entry.getNumArguments(); i < e; ++i) {
4776 if (i != 0)
4777 p << ", ";
4778 p.printRegionArgument(entry.getArgument(i));
4779 }
4780 p << ") = (";
4781
4782 // (%lb0 to %ub0 step %step0, %lb1 to %ub1 step %step1, ...)
4783 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
4784 if (i)
4785 p << ", ";
4786 p << lbs[i] << " to " << ubs[i] << " step " << steps[i];
4787 }
4788 p << ") ";
4789
4790 p.printRegion(region, /*printEntryBlockArgs=*/false,
4791 /*printBlockTerminators=*/true);
4792}
4793
4794LogicalResult IteratorOp::verify() {
4795 auto iteratedTy = llvm::dyn_cast<omp::IteratedType>(getIterated().getType());
4796 if (!iteratedTy)
4797 return emitOpError() << "result must be omp.iterated<entry_ty>";
4798
4799 Block &b = getRegion().front();
4800 auto yield = llvm::dyn_cast<omp::YieldOp>(b.getTerminator());
4801
4802 if (!yield)
4803 return emitOpError() << "region must be terminated by omp.yield";
4804
4805 if (yield.getNumOperands() != 1)
4806 return emitOpError()
4807 << "omp.yield in omp.iterator region must yield exactly one value";
4808
4809 mlir::Type yieldedTy = yield.getOperand(0).getType();
4810 mlir::Type elemTy = iteratedTy.getElementType();
4811
4812 if (yieldedTy != elemTy)
4813 return emitOpError() << "omp.iterated element type (" << elemTy
4814 << ") does not match omp.yield operand type ("
4815 << yieldedTy << ")";
4816
4817 return success();
4818}
4819
4820#define GET_ATTRDEF_CLASSES
4821#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
4822
4823#define GET_OP_CLASSES
4824#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
4825
4826#define GET_TYPEDEF_CLASSES
4827#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition EmitC.cpp:1497
static Type getElementType(Type type)
Determine the element type of type.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool > > reductionByref)
Verifies Reduction Clause.
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars, TypeRange stepVarTypes)
Print Linear Clause.
static ParseResult parseAffinityClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iterated, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &affinityVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< Type > &affinityVarTypes)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printSplitIteratedList(OpAsmPrinter &p, ValueRange iteratedVars, TypeRange iteratedTypes, ValueRange plainVars, TypeRange plainTypes, PrintPrefixFn &&printPrefixForPlain, PrintPrefixFn &&printPrefixForIterated)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printAffinityClause(OpAsmPrinter &p, Operation *op, ValueRange iterated, ValueRange affinityVars, TypeRange iteratedTypes, TypeRange affinityVarTypes)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static void printIteratorHeader(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lbs, ValueRange ubs, ValueRange steps, TypeRange, TypeRange, TypeRange)
static ParseResult parseIteratorHeader(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lbs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &ubs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &lbTypes, SmallVectorImpl< Type > &ubTypes, SmallVectorImpl< Type > &stepTypes)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars, SmallVectorImpl< Type > &linearStepTypes)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static ParseResult parseUniformClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &uniformVars, SmallVectorImpl< Type > &uniformTypes)
uniform ::= uniform ( uniform-list ) uniform-list := uniform-val (, uniform-val)* uniform-val := ssa-...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
operation ::= res = (omp.target_alloc_mem) $device : devicetype, $in_type ( ( $typeparams ) )?...
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static Operation * getParentInSameDialect(Operation *thisOp)
static void printUniformClause(OpAsmPrinter &p, Operation *op, ValueRange uniformVars, TypeRange uniformTypes)
Print Uniform Clauses.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static ParseResult parseSplitIteratedList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &plainVars, SmallVectorImpl< Type > &plainTypes, ParsePrefixFn &&parsePrefix)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
return success()
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower, OperandRange numTeamsUpperVars)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static LogicalResult checkApplyeesNesting(TileOp op)
Check properties of the loop nest consisting of the transformation's applyees:
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 > > &modifiers)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static LogicalResult emit(SolverOp solver, const SMTEmissionOptions &options, mlir::raw_indented_ostream &stream)
Emit the SMT operations in the given 'solver' to the 'stream'.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
SuccessorRange getSuccessors()
Definition Block.h:280
BlockArgListType getArguments()
Definition Block.h:97
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerType getI64Type()
Definition Builders.cpp:69
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
Definition Builders.h:209
This class represents an operand of an operation.
Definition Value.h:254
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:241
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:715
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:778
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:234
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:703
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:255
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:259
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:706
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:826
user_range getUsers()
Returns a range of all users.
Definition Operation.h:902
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:251
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:237
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition Region.h:170
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
Location getLoc()
Return a location for this region.
Definition Region.cpp:31
BlockArgument getArgument(unsigned i)
Definition Region.h:124
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
SideEffects::EffectInstance< Effect > EffectInstance
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1306
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
function_ref< void(Block *, StringRef)> OpAsmSetBlockNameFn
A functor used to set the name of blocks in regions directly nested under an operation.
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.